feat/nnx-post-train-fixes

| Metric | Linen 7f06c99ac | NNX 7f06c99ac | Diff (NNX − Linen) |
|---|---|---|---|
| Parameters | 0.000 billion | 0.000 billion | — |
| Final loss | 5.7660 | 5.6700 | -0.096 |
| TFLOP/s | 0.014 | 0.011 | -0.003 |
| Tok/s | 67135.0 | 52847.1 | -14287.857 |
| Avg s/step | 3.974 | 3.921 | -0.053 |
| Memory % | 0.03 | 0.03 | 0 |
| JAX | 0.8.3 | 0.8.3 | — |
Diff = NNX value − Linen value. For loss, a negative diff means NNX improved; for throughput (TFLOP/s, Tok/s), a negative diff means NNX regressed.
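The Diff column is a straight subtraction of the two runs' metrics; a minimal sketch that recomputes it from the values in the table (the dict layout here is illustrative, not part of the benchmark tooling):

```python
# Recompute the Diff (NNX − Linen) column from the metric pairs in the table.
metrics = {
    "Final loss": (5.7660, 5.6700),
    "TFLOP/s": (0.014, 0.011),
    "Avg s/step": (3.974, 3.921),
}

diffs = {name: round(nnx - linen, 3) for name, (linen, nnx) in metrics.items()}
print(diffs)  # {'Final loss': -0.096, 'TFLOP/s': -0.003, 'Avg s/step': -0.053}
```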
XPK Start: Thu Apr 23 12:50:13 UTC 2026
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
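The `rope_scaling` warnings above complain about an unrecognized `rope_theta` key and integer values where floats are expected. A hedged sketch of a normalization pass that would silence them; the field names come from the warning messages, but the concrete input values are hypothetical:

```python
# Normalize a YaRN `rope_scaling` config so it passes the validation that
# produced the warnings above: drop the unrecognized `rope_theta` key and
# coerce `factor`, `beta_fast`, and `beta_slow` to floats.
def normalize_rope_scaling(cfg: dict) -> dict:
    cfg = dict(cfg)  # avoid mutating the caller's config
    cfg.pop("rope_theta", None)  # unrecognized for rope_type='yarn'
    for key in ("factor", "beta_fast", "beta_slow"):
        if key in cfg:
            cfg[key] = float(cfg[key])
    return cfg

raw = {"rope_type": "yarn", "factor": 40, "beta_fast": 32, "beta_slow": 1,
       "rope_theta": 10000}  # hypothetical values matching the warnings
clean = normalize_rope_scaling(raw)
print(clean)  # {'rope_type': 'yarn', 'factor': 40.0, 'beta_fast': 32.0, 'beta_slow': 1.0}
```

Whether this belongs in the checkpoint-conversion config or upstream in the model card is a separate question; the warnings are non-fatal here, so the runs proceed.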
2026-04-23 12:51:09.477411: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 12:51:09.772345 137877435512640 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-23 12:51:18,813:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0423 12:51:18.813870 137877435512640 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-23 12:51:18,816:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-pdd8m-slice-job-0-0.mt-01-sft-smoke-pdd8m:8482
I0423 12:51:18.816180 137877435512640 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-pdd8m-slice-job-0-0.mt-01-sft-smoke-pdd8m:8482
I0423 12:51:39.635349 137877435512640 max_utils.py:284] Jax distributed system initialized!
I0423 12:51:45.625696 137877435512640 max_utils.py:800] System Information: Jax Version: 0.8.3
I0423 12:51:45.625801 137877435512640 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0423 12:51:45.625841 137877435512640 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0423 12:51:45.629241 137877435512640 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 12:51:45.818412 137877435512640 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 12:51:46.974288 137877435512640 config.py:112] TensorFlow version 2.20.0 available.
I0423 12:51:46.974801 137877435512640 config.py:125] JAX version 0.8.3 available.
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
warnings.warn(
E0423 12:51:52.328268 137877435512640 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0423 12:51:52.328494 137877435512640 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0423 12:51:52.707172 137877435512640 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 12:51:52.707632 137877435512640 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7d6561bdb500>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 12:51:52.707680 137877435512640 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 12:51:52.707721 137877435512640 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7d6561bdb500>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 12:51:52.707770 137877435512640 checkpoint_manager.py:702] [process=3][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d489406fce0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d47d8167a70>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d47d81674d0>}, handler_registry=None
I0423 12:51:52.707998 137877435512640 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d489406fce0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 12:51:52.708044 137877435512640 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d47d8167a70>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 12:51:52.708073 137877435512640 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d47d81674d0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 12:51:52.708098 137877435512640 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d47f820d640>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 12:51:52.708127 137877435512640 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d489406fce0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d489406fce0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d47d8167a70>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d47d8167a70>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d47d81674d0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d47d81674d0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d47f820d640>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d47f820d640>}).
I0423 12:51:52.708372 137877435512640 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
I0423 12:51:52.708425 137877435512640 async_checkpointer.py:177] [process=3][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7d4c6b523060> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0423 12:51:56.295186 137877435512640 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints
I0423 12:51:56.713753 137877435512640 checkpoint_manager.py:921] [process=3][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7d47d8165fa0>
I0423 12:51:56.714137 137877435512640 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0423 12:51:57.123275 137877435512640 peft_trainer.py:594] Compiled train_step cache size: 0
I0423 12:51:57.125353 137877435512640 metric_logger.py:301] number parameters: 0.000 billion
I0423 12:51:57.127718 137724593940224 grain_pool.py:367] Grain pool will use 1 processes.
I0423 12:51:57.157135 137724593940224 grain_pool.py:440] Grain pool will start child processes.
Per train step:
Total TFLOPs: 0.00
split as 54.29% learnable weight flops and 45.71% attention flops
I0423 12:51:57.162610 137724593940224 grain_pool.py:448] Grain pool started all child processes.
2026-04-23 12:52:01.197405: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-23 12:52:01.243029: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-23 12:52:02.410366: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
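The oneDNN notices above explain that custom ops can perturb numerics via computation-order round-off. If bit-exact reproducibility matters for a comparison run, they can be disabled with the environment variable the notice names, set before launch:

```shell
# Disable oneDNN custom operations, per the TensorFlow notice above.
export TF_ENABLE_ONEDNN_OPTS=0
echo "$TF_ENABLE_ONEDNN_OPTS"  # prints 0
```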
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
2026-04-23 12:52:06.651257: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 12:52:15.300107 137877435512640 checkpoint_manager.py:1983] [process=3][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0423 12:52:15.302035 137877435512640 checkpoint_manager.py:1501] [process=3] Saving checkpoint at step 1
I0423 12:52:15.305342 137877435512640 async_checkpointer.py:452] [process=3] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1.
I0423 12:52:15.847815 137877435512640 signaling_client.py:364] Using JaxDistributedSignalingClient
I0423 12:52:15.849104 137877435512640 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0423 12:52:15.849175 137877435512640 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 12:52:16.556372 137877435512640 base_pytree_checkpoint_handler.py:153] [process=3][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.708692s
I0423 12:52:16.557897 137877435512640 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/blocking_gbytes_per_sec: 80.076 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 956 milliseconds) (per-host)
I0423 12:52:16.557980 137877435512640 base_pytree_checkpoint_handler.py:732] [process=3][thread=MainThread] Initiated Pytree async_save. Time taken: 0.956638s (batch_requests_ready=0.241058s, total_serialization_initiated=0.715452s, others=0.000129s)
I0423 12:52:16.559186 137877435512640 jax_array_handlers.py:347] Scheduling D2H of 22 prioritized jax.Array.
I0423 12:52:16.559255 137877435512640 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 12:52:16.564334 137877435512640 base_pytree_checkpoint_handler.py:153] [process=3][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.006195s
I0423 12:52:16.564464 137877435512640 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/blocking_gbytes_per_sec: 26.453 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 965 milliseconds) (per-host)
I0423 12:52:16.564523 137877435512640 base_pytree_checkpoint_handler.py:732] [process=3][thread=MainThread] Initiated Pytree async_save. Time taken: 0.965239s (batch_requests_ready=0.957292s, total_serialization_initiated=0.007858s, others=0.000089s)
I0423 12:52:16.564634 137877435512640 composite_checkpoint_handler.py:715] [process=3][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.969350s (all_items=0.000021s, per_item={'model_params': '0.00001740', 'optimizer_state': '0.00000381'}, temp_paths=0.969329)
I0423 12:52:16.565697 137721934763776 async_checkpointer.py:79] [process=3][thread=async_save] Background save thread started.
I0423 12:52:16.565876 137877435512640 async_checkpointer.py:561] Finished blocking save. Time taken: 1.263766s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1.
I0423 12:52:16.595017 137877435512640 checkpoint_manager.py:1549] [process=3][thread=MainThread][step=1] Starting CheckpointManager Save Finalize thread=save_finalize
I0423 12:52:16.595308 137722463241984 async_checkpointer.py:265] [process=3][thread=save_finalize] Waiting for background save thread=async_save.
I0423 12:52:16.595468 137877435512640 standard_logger.py:34] {'step': 1, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776948735.300077, 'wait_for_prev_duration_secs': 0.00012302398681640625, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776948735.3020787, 'checkpointer_blocking_duration_secs': 1.2639145851135254, 'get_old_steps_start_time': 1776948736.5660164, 'get_old_steps_duration_secs': 7.867813110351562e-05, 'checkpoint_manager_blocking_start_time': 1776948735.2495546, 'checkpoint_manager_blocking_duration_secs': 1.3458783626556396}
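The per-host write-throughput figures in the save logs above are simply bytes over elapsed time; a quick sanity check using the rounded values from the `blocking_gbytes_per_sec` line (the small mismatch with the logged 80.076 KiB/s comes from the log using unrounded inputs):

```python
# Sanity-check the reported per-host blocking write throughput:
# throughput = total KiB / elapsed seconds, values copied from the log line.
total_kib = 76.6
elapsed_s = 0.956
rate = round(total_kib / elapsed_s, 1)
print(rate)  # 80.1, close to the logged 80.076 KiB/s
```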
I0423 12:52:16.734617 137877435512640 peft_trainer.py:474] Train step 1 training loss: 5.894745 - training perplexity: 363.124207
I0423 12:52:16.734892 137877435512640 max_utils.py:750]
Memstats: After params initialized:
I0423 12:52:16.734972 137877435512640 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_10(process=3,(2,2,0,0))
I0423 12:52:16.735011 137877435512640 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_11(process=3,(3,2,0,0))
I0423 12:52:16.735043 137877435512640 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_14(process=3,(2,3,0,0))
I0423 12:52:16.735071 137877435512640 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_15(process=3,(3,3,0,0))
I0423 12:52:16.864545 137877435512640 metric_logger.py:196] completed step: 1, seconds: 19.610, TFLOP/s/device: 0.000, Tokens/s/device: 52.219, total_weights: 21054, loss: 5.895, lm_loss: 0.000, perplexity: 0.000
I0423 12:52:16.873330 137877435512640 peft_trainer.py:474] Train step 2 training loss: 5.603691 - training perplexity: 271.426270
I0423 12:52:16.874102 137877435512640 metric_logger.py:196] completed step: 2, seconds: 0.138, TFLOP/s/device: 0.002, Tokens/s/device: 7397.917, total_weights: 21455, loss: 5.604, lm_loss: 0.000, perplexity: 0.000
I0423 12:52:16.961464 137877435512640 peft_trainer.py:474] Train step 3 training loss: 5.511911 - training perplexity: 247.623871
I0423 12:52:16.962353 137877435512640 metric_logger.py:196] completed step: 3, seconds: 0.088, TFLOP/s/device: 0.002, Tokens/s/device: 11610.749, total_weights: 22025, loss: 5.512, lm_loss: 0.000, perplexity: 0.000
I0423 12:52:16.980087 137877435512640 peft_trainer.py:474] Train step 4 training loss: 5.686435 - training perplexity: 294.840698
I0423 12:52:16.980765 137877435512640 metric_logger.py:196] completed step: 4, seconds: 0.019, TFLOP/s/device: 0.012, Tokens/s/device: 55211.800, total_weights: 23787, loss: 5.686, lm_loss: 0.000, perplexity: 0.000
I0423 12:52:16.994415 137877435512640 peft_trainer.py:733] Train loop finished in: 19.8679 seconds
I0423 12:52:16.995357 137877435512640 peft_trainer.py:474] Train step 5 training loss: 5.766005 - training perplexity: 319.259766
I0423 12:52:16.996008 137877435512640 metric_logger.py:196] completed step: 5, seconds: 0.015, TFLOP/s/device: 0.014, Tokens/s/device: 67134.994, total_weights: 20141, loss: 5.766, lm_loss: 0.000, perplexity: 0.000
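The summary table's "Avg s/step" of 3.974 for Linen is likely the mean of the five per-step times logged above, dominated by step 1's compile-inclusive 19.610 s; the "Train loop finished in: 19.8679 seconds" line matches their sum to within logging rounding:

```python
# Reproduce the Linen "Avg s/step" (3.974) from the per-step times above.
step_seconds = [19.610, 0.138, 0.088, 0.019, 0.015]
avg = round(sum(step_seconds) / len(step_seconds), 3)
print(avg)  # 3.974
```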
I0423 12:52:18.896984 2994 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0423 12:52:20.735967 137722438063872 array_metadata_store.py:203] [process=3][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1/optimizer_state/array_metadatas/process_3
I0423 12:52:21.247812 137721943156480 array_metadata_store.py:203] [process=3][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1/model_params/array_metadatas/process_3
I0423 12:52:21.248944 137721934763776 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/gbytes_per_sec: 4.519 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 5 seconds) (per-host)
I0423 12:52:21.249095 137721934763776 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/gbytes_per_sec: 13.562 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 5 seconds) (per-host)
I0423 12:52:21.249134 137721934763776 async_checkpointer.py:90] [process=3][thread=async_save] 4 Handler Commit operations completed. Time taken: 4.683319s.
I0423 12:52:26.143213 137877435512640 checkpoint_manager.py:1994] [process=3][thread=MainThread][step=1][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0423 12:52:32.271582 137721934763776 async_checkpointer.py:144] [process=3][thread=async_save] Background save thread done. Time taken: 15.705750s.
I0423 12:52:32.271867 137722463241984 async_checkpointer.py:273] [process=3][thread=save_finalize] Done with waiting for background save thread=async_save.
I0423 12:52:32.272135 137722463241984 async_checkpointer.py:283] [process=3][thread=save_finalize] No errors found in background save thread=async_save.
I0423 12:52:32.272259 137722463241984 checkpoint_manager.py:2103] [process=3][thread=save_finalize][step=1] CheckpointManager Save Finalize is syncing with other hosts...
I0423 12:52:32.273701 137722463241984 checkpoint_manager.py:2112] [process=3][thread=save_finalize][step=1] CheckpointManager Save Finalize is done on all hosts.
I0423 12:52:32.273879 137877435512640 checkpoint_manager.py:2006] [process=3][thread=MainThread][step=1][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=1.
W0423 12:52:32.274023 137877435512640 checkpoint_manager.py:1441] Waiting for previous save to complete took 6.130828 seconds. If this number is high, consider checkpointing less frequently.
I0423 12:52:32.275604 137877435512640 checkpoint_manager.py:1501] [process=3] Saving checkpoint at step 5
I0423 12:52:32.278993 137877435512640 async_checkpointer.py:452] [process=3] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5.
I0423 12:52:32.823299 137877435512640 jax_array_handlers.py:347] Scheduling D2H of 22 prioritized jax.Array.
I0423 12:52:32.823391 137877435512640 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 12:52:32.828223 137877435512640 base_pytree_checkpoint_handler.py:153] [process=3][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.006021s
I0423 12:52:32.830003 137877435512640 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/blocking_gbytes_per_sec: 96.566 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 264 milliseconds) (per-host)
I0423 12:52:32.830060 137877435512640 base_pytree_checkpoint_handler.py:732] [process=3][thread=MainThread] Initiated Pytree async_save. Time taken: 0.264483s (batch_requests_ready=0.255038s, total_serialization_initiated=0.009351s, others=0.000095s)
I0423 12:52:32.830904 137877435512640 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0423 12:52:32.830970 137877435512640 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 12:52:32.842113 137877435512640 base_pytree_checkpoint_handler.py:153] [process=3][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.011967s
I0423 12:52:32.842226 137877435512640 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/blocking_gbytes_per_sec: 278.314 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 275 milliseconds) (per-host)
I0423 12:52:32.842269 137877435512640 base_pytree_checkpoint_handler.py:732] [process=3][thread=MainThread] Initiated Pytree async_save. Time taken: 0.275268s (batch_requests_ready=0.261331s, total_serialization_initiated=0.013870s, others=0.000066s)
I0423 12:52:32.842391 137877435512640 composite_checkpoint_handler.py:715] [process=3][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.280574s (all_items=0.000012s, per_item={'model_params': '0.00000954', 'optimizer_state': '0.00000262'}, temp_paths=0.280562)
I0423 12:52:32.843281 137721934763776 async_checkpointer.py:79] [process=3][thread=async_save] Background save thread started.
I0423 12:52:32.843435 137877435512640 async_checkpointer.py:561] Finished blocking save. Time taken: 0.567760s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5.
I0423 12:52:33.262471 137877435512640 checkpoint_manager.py:1549] [process=3][thread=MainThread][step=5] Starting CheckpointManager Save Finalize thread=save_finalize
I0423 12:52:33.262832 137722463241984 async_checkpointer.py:265] [process=3][thread=save_finalize] Waiting for background save thread=async_save.
I0423 12:52:33.262989 137877435512640 standard_logger.py:34] {'step': 5, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776948746.1431715, 'wait_for_prev_duration_secs': 6.130828142166138, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776948752.2756457, 'checkpointer_blocking_duration_secs': 0.5678951740264893, 'get_old_steps_start_time': 1776948752.8435628, 'get_old_steps_duration_secs': 7.891654968261719e-05, 'checkpoint_manager_blocking_start_time': 1776948736.9995844, 'checkpoint_manager_blocking_duration_secs': 16.263370990753174}
I0423 12:52:33.263232 137877435512640 checkpoint_manager.py:1994] [process=3][thread=MainThread][step=5][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0423 12:52:37.933129 137721901192960 array_metadata_store.py:203] [process=3][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5/optimizer_state/array_metadatas/process_3
I0423 12:52:37.947113 137722438063872 array_metadata_store.py:203] [process=3][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5/model_params/array_metadatas/process_3
I0423 12:52:37.948221 137721934763776 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/gbytes_per_sec: 4.743 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 5 seconds) (per-host)
I0423 12:52:37.948351 137721934763776 base_pytree_checkpoint_handler.py:128] [process=3] /jax/checkpoint/write/gbytes_per_sec: 14.233 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 5 seconds) (per-host)
I0423 12:52:37.948385 137721934763776 async_checkpointer.py:90] [process=3][thread=async_save] 4 Handler Commit operations completed. Time taken: 5.104994s.
I0423 12:52:48.095706 137721934763776 async_checkpointer.py:144] [process=3][thread=async_save] Background save thread done. Time taken: 15.252298s.
I0423 12:52:48.096037 137722463241984 async_checkpointer.py:273] [process=3][thread=save_finalize] Done with waiting for background save thread=async_save.
I0423 12:52:48.096152 137722463241984 async_checkpointer.py:283] [process=3][thread=save_finalize] No errors found in background save thread=async_save.
I0423 12:52:48.096201 137722463241984 checkpoint_manager.py:2103] [process=3][thread=save_finalize][step=5] CheckpointManager Save Finalize is syncing with other hosts...
I0423 12:52:48.097750 137722463241984 checkpoint_manager.py:2112] [process=3][thread=save_finalize][step=5] CheckpointManager Save Finalize is done on all hosts.
I0423 12:52:48.097949 137877435512640 checkpoint_manager.py:2006] [process=3][thread=MainThread][step=5][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=5.
I0423 12:52:48.098099 137877435512640 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=137877435512640 count=1 at 0x7d4b9c341340>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7d47d8171670>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7d47d8173620>, _write_futures=[])
I0423 12:52:48.098536 137877435512640 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=137877435512640 count=1 at 0x7d4b9c341340>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7d47d8171670>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7d47d8173620>, _write_futures=[])
I0423 12:52:48.098570 137877435512640 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=137877435512640 count=1 at 0x7d4b9c341340>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7d47d8171670>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7d47d8173620>, _write_futures=[])
I0423 12:52:48.489856 137724593940224 grain_pool.py:542] Grain pool is exiting.
I0423 12:52:48.489983 137724593940224 grain_pool.py:547] Shutting down multiprocessing system.
I0423 12:52:50.612835 137724593940224 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
XPK End: Thu Apr 23 12:53:03 UTC 2026
EXIT_CODE=0
XPK Start: Thu Apr 23 13:04:29 UTC 2026
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
2026-04-23 13:04:58.440627: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 13:04:58.685689 133087290353472 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-23 13:05:07,725:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0423 13:05:07.725977 133087290353472 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-23 13:05:07,728:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-vzjhu-slice-job-0-0.mt-01-sft-smoke-vzjhu:8482
I0423 13:05:07.728435 133087290353472 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-vzjhu-slice-job-0-0.mt-01-sft-smoke-vzjhu:8482
I0423 13:05:09.019104 133087290353472 max_utils.py:284] Jax distributed system initialized!
I0423 13:05:14.283369 133087290353472 max_utils.py:800] System Information: Jax Version: 0.8.3
I0423 13:05:14.283476 133087290353472 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0423 13:05:14.283516 133087290353472 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0423 13:05:14.286912 133087290353472 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 13:05:14.381633 133087290353472 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 13:05:14.482711 133087290353472 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 13:05:15.596841 133087290353472 config.py:112] TensorFlow version 2.20.0 available.
I0423 13:05:15.597347 133087290353472 config.py:125] JAX version 0.8.3 available.
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than the number of hosts loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
warnings.warn(
E0423 13:05:20.863260 133087290353472 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0423 13:05:20.863575 133087290353472 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0423 13:05:21.256447 133087290353472 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 13:05:21.256964 133087290353472 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x790a16dc7aa0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 13:05:21.257013 133087290353472 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 13:05:21.257052 133087290353472 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x790a16dc7aa0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 13:05:21.257098 133087290353472 checkpoint_manager.py:702] [process=5][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x78ed7c13f6b0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7903b4f029f0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x78ed281978f0>}, handler_registry=None
I0423 13:05:21.257310 133087290353472 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x78ed7c13f6b0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 13:05:21.257357 133087290353472 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7903b4f029f0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 13:05:21.257386 133087290353472 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x78ed281978f0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 13:05:21.257412 133087290353472 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x78ed28196390>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 13:05:21.257439 133087290353472 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x78ed7c13f6b0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x78ed7c13f6b0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7903b4f029f0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7903b4f029f0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x78ed281978f0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x78ed281978f0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x78ed28196390>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x78ed28196390>}).
I0423 13:05:21.257642 133087290353472 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
I0423 13:05:21.257704 133087290353472 async_checkpointer.py:177] [process=5][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x78ed76f07420> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0423 13:05:24.073213 133087290353472 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints
I0423 13:05:24.491952 133087290353472 checkpoint_manager.py:921] [process=5][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x78ed28196840>
I0423 13:05:24.492314 133087290353472 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0423 13:05:24.914184 133087290353472 peft_trainer.py:594] Compiled train_step cache size: 0
I0423 13:05:24.918195 133087290353472 metric_logger.py:301] number parameters: 0.000 billion
I0423 13:05:24.920594 132938047907584 grain_pool.py:367] Grain pool will use 1 processes.
I0423 13:05:24.946539 132938047907584 grain_pool.py:440] Grain pool will start child processes.
Per train step:
Total TFLOPs: 0.00
split as 54.29% learnable weight flops and 45.71% attention flops
I0423 13:05:24.951696 132938047907584 grain_pool.py:448] Grain pool started all child processes.
2026-04-23 13:05:28.944532: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-23 13:05:28.990292: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-23 13:05:30.163483: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
2026-04-23 13:05:34.355178: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 13:05:42.836500 133087290353472 checkpoint_manager.py:1983] [process=5][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0423 13:05:42.839325 133087290353472 checkpoint_manager.py:1501] [process=5] Saving checkpoint at step 1
I0423 13:05:42.842432 133087290353472 async_checkpointer.py:452] [process=5] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1.
I0423 13:05:43.417838 133087290353472 signaling_client.py:364] Using JaxDistributedSignalingClient
I0423 13:05:43.418862 133087290353472 jax_array_handlers.py:347] Scheduling D2H of 46 prioritized jax.Array.
I0423 13:05:43.418921 133087290353472 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 13:05:44.083886 133087290353472 base_pytree_checkpoint_handler.py:153] [process=5][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.667130s
I0423 13:05:44.085983 133087290353472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/blocking_gbytes_per_sec: 27.203 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 938 milliseconds) (per-host)
I0423 13:05:44.086064 133087290353472 base_pytree_checkpoint_handler.py:732] [process=5][thread=MainThread] Initiated Pytree async_save. Time taken: 0.938677s (batch_requests_ready=0.263607s, total_serialization_initiated=0.674946s, others=0.000124s)
I0423 13:05:44.086999 133087290353472 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0423 13:05:44.087068 133087290353472 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 13:05:44.098926 133087290353472 base_pytree_checkpoint_handler.py:153] [process=5][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.012719s
I0423 13:05:44.099076 133087290353472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/blocking_gbytes_per_sec: 80.765 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 948 milliseconds) (per-host)
I0423 13:05:44.099142 133087290353472 base_pytree_checkpoint_handler.py:732] [process=5][thread=MainThread] Initiated Pytree async_save. Time taken: 0.948450s (batch_requests_ready=0.933410s, total_serialization_initiated=0.014934s, others=0.000107s)
I0423 13:05:44.099294 133087290353472 composite_checkpoint_handler.py:715] [process=5][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.956047s (all_items=0.000024s, per_item={'model_params': '0.00001884', 'optimizer_state': '0.00000477'}, temp_paths=0.956023)
I0423 13:05:44.100374 132934874953472 async_checkpointer.py:79] [process=5][thread=async_save] Background save thread started.
I0423 13:05:44.100572 133087290353472 async_checkpointer.py:561] Finished blocking save. Time taken: 1.261176s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1.
I0423 13:05:44.124311 133087290353472 checkpoint_manager.py:1549] [process=5][thread=MainThread][step=1] Starting CheckpointManager Save Finalize thread=save_finalize
I0423 13:05:44.124597 132935395038976 async_checkpointer.py:265] [process=5][thread=save_finalize] Waiting for background save thread=async_save.
I0423 13:05:44.124780 133087290353472 standard_logger.py:34] {'step': 1, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776949542.836466, 'wait_for_prev_duration_secs': 0.0001430511474609375, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776949542.8393655, 'checkpointer_blocking_duration_secs': 1.261326551437378, 'get_old_steps_start_time': 1776949544.1007133, 'get_old_steps_duration_secs': 7.843971252441406e-05, 'checkpoint_manager_blocking_start_time': 1776949542.6500492, 'checkpoint_manager_blocking_duration_secs': 1.4746966361999512}
I0423 13:05:44.268045 133087290353472 peft_trainer.py:474] Train step 1 training loss: 5.871947 - training perplexity: 354.939331
I0423 13:05:44.268313 133087290353472 max_utils.py:750]
Memstats: After params initialized:
I0423 13:05:44.268376 133087290353472 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_18(process=5,(2,4,0,0))
I0423 13:05:44.268412 133087290353472 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_19(process=5,(3,4,0,0))
I0423 13:05:44.268442 133087290353472 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_22(process=5,(2,5,0,0))
I0423 13:05:44.268467 133087290353472 max_utils.py:756] Using (GB) 0.01 / 31.25 (0.032000%) on TPU_23(process=5,(3,5,0,0))
I0423 13:05:44.399544 133087290353472 metric_logger.py:196] completed step: 1, seconds: 19.350, TFLOP/s/device: 0.000, Tokens/s/device: 52.919, total_weights: 21054, loss: 5.872, lm_loss: 0.000, perplexity: 0.000
I0423 13:05:44.408441 133087290353472 peft_trainer.py:474] Train step 2 training loss: 5.539500 - training perplexity: 254.550751
I0423 13:05:44.409253 133087290353472 metric_logger.py:196] completed step: 2, seconds: 0.140, TFLOP/s/device: 0.002, Tokens/s/device: 7307.116, total_weights: 21455, loss: 5.540, lm_loss: 0.000, perplexity: 0.000
I0423 13:05:44.481837 133087290353472 peft_trainer.py:474] Train step 3 training loss: 5.440901 - training perplexity: 230.649963
I0423 13:05:44.482746 133087290353472 metric_logger.py:196] completed step: 3, seconds: 0.073, TFLOP/s/device: 0.003, Tokens/s/device: 13940.505, total_weights: 22025, loss: 5.441, lm_loss: 0.000, perplexity: 0.000
I0423 13:05:44.507392 133087290353472 peft_trainer.py:474] Train step 4 training loss: 5.596125 - training perplexity: 269.380554
I0423 13:05:44.508203 133087290353472 metric_logger.py:196] completed step: 4, seconds: 0.025, TFLOP/s/device: 0.009, Tokens/s/device: 40179.871, total_weights: 23787, loss: 5.596, lm_loss: 0.000, perplexity: 0.000
I0423 13:05:44.525371 133087290353472 peft_trainer.py:733] Train loop finished in: 19.6061 seconds
I0423 13:05:44.526796 133087290353472 peft_trainer.py:474] Train step 5 training loss: 5.670155 - training perplexity: 290.079498
I0423 13:05:44.527517 133087290353472 metric_logger.py:196] completed step: 5, seconds: 0.019, TFLOP/s/device: 0.011, Tokens/s/device: 52847.137, total_weights: 20141, loss: 5.670, lm_loss: 0.000, perplexity: 0.000
I0423 13:05:46.661324 2920 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0423 13:05:49.024945 132935378253568 array_metadata_store.py:203] [process=5][thread=array_type_handler] Wrote 46 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1/model_params/array_metadatas/process_5
I0423 13:05:49.026191 132934874953472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/gbytes_per_sec: 4.343 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 5 seconds) (per-host)
I0423 13:05:49.112101 132935361468160 array_metadata_store.py:203] [process=5][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/1/optimizer_state/array_metadatas/process_5
I0423 13:05:49.113256 132934874953472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/gbytes_per_sec: 12.846 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 5 seconds) (per-host)
I0423 13:05:49.113366 132934874953472 async_checkpointer.py:90] [process=5][thread=async_save] 4 Handler Commit operations completed. Time taken: 5.012860s.
I0423 13:05:53.523924 133087290353472 checkpoint_manager.py:1994] [process=5][thread=MainThread][step=1][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0423 13:05:58.321393 132934874953472 async_checkpointer.py:144] [process=5][thread=async_save] Background save thread done. Time taken: 14.220871s.
I0423 13:05:58.321775 132935395038976 async_checkpointer.py:273] [process=5][thread=save_finalize] Done with waiting for background save thread=async_save.
I0423 13:05:58.321894 132935395038976 async_checkpointer.py:283] [process=5][thread=save_finalize] No errors found in background save thread=async_save.
I0423 13:05:58.321949 132935395038976 checkpoint_manager.py:2103] [process=5][thread=save_finalize][step=1] CheckpointManager Save Finalize is syncing with other hosts...
I0423 13:05:58.323485 132935395038976 checkpoint_manager.py:2112] [process=5][thread=save_finalize][step=1] CheckpointManager Save Finalize is done on all hosts.
I0423 13:05:58.323677 133087290353472 checkpoint_manager.py:2006] [process=5][thread=MainThread][step=1][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=1.
W0423 13:05:58.323807 133087290353472 checkpoint_manager.py:1441] Waiting for previous save to complete took 4.799901 seconds. If this number is high, consider checkpointing less frequently.
I0423 13:05:58.325374 133087290353472 checkpoint_manager.py:1501] [process=5] Saving checkpoint at step 5
I0423 13:05:58.329092 133087290353472 async_checkpointer.py:452] [process=5] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5.
I0423 13:05:58.878398 133087290353472 jax_array_handlers.py:347] Scheduling D2H of 46 prioritized jax.Array.
I0423 13:05:58.878502 133087290353472 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 13:05:58.884139 133087290353472 base_pytree_checkpoint_handler.py:153] [process=5][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.007582s
I0423 13:05:58.885909 133087290353472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/blocking_gbytes_per_sec: 92.078 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 277 milliseconds) (per-host)
I0423 13:05:58.885970 133087290353472 base_pytree_checkpoint_handler.py:732] [process=5][thread=MainThread] Initiated Pytree async_save. Time taken: 0.277378s (batch_requests_ready=0.264320s, total_serialization_initiated=0.012955s, others=0.000103s)
I0423 13:05:58.886772 133087290353472 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0423 13:05:58.886826 133087290353472 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0423 13:05:58.898743 133087290353472 base_pytree_checkpoint_handler.py:153] [process=5][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.012657s
I0423 13:05:58.898853 133087290353472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/blocking_gbytes_per_sec: 266.266 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 287 milliseconds) (per-host)
I0423 13:05:58.898894 133087290353472 base_pytree_checkpoint_handler.py:732] [process=5][thread=MainThread] Initiated Pytree async_save. Time taken: 0.287718s (batch_requests_ready=0.273083s, total_serialization_initiated=0.014571s, others=0.000063s)
I0423 13:05:58.899006 133087290353472 composite_checkpoint_handler.py:715] [process=5][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.294258s (all_items=0.000014s, per_item={'model_params': '0.00001049', 'optimizer_state': '0.00000310'}, temp_paths=0.294245)
I0423 13:05:58.899968 132934874953472 async_checkpointer.py:79] [process=5][thread=async_save] Background save thread started.
I0423 13:05:58.900125 133087290353472 async_checkpointer.py:561] Finished blocking save. Time taken: 0.574681s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5.
I0423 13:05:59.309911 133087290353472 checkpoint_manager.py:1549] [process=5][thread=MainThread][step=5] Starting CheckpointManager Save Finalize thread=save_finalize
I0423 13:05:59.310286 132935395038976 async_checkpointer.py:265] [process=5][thread=save_finalize] Waiting for background save thread=async_save.
I0423 13:05:59.310461 133087290353472 standard_logger.py:34] {'step': 5, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776949553.5238833, 'wait_for_prev_duration_secs': 4.799901247024536, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776949558.3254123, 'checkpointer_blocking_duration_secs': 0.5748212337493896, 'get_old_steps_start_time': 1776949558.9002562, 'get_old_steps_duration_secs': 8.296966552734375e-05, 'checkpoint_manager_blocking_start_time': 1776949544.5317876, 'checkpoint_manager_blocking_duration_secs': 14.778639793395996}
I0423 13:05:59.310695 133087290353472 checkpoint_manager.py:1994] [process=5][thread=MainThread][step=5][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0423 13:06:03.720292 132935378253568 array_metadata_store.py:203] [process=5][thread=array_type_handler] Wrote 46 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5/model_params/array_metadatas/process_5
I0423 13:06:03.721545 132934874953472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/gbytes_per_sec: 4.994 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 5 seconds) (per-host)
I0423 13:06:04.029778 132934329689856 array_metadata_store.py:203] [process=5][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260423_124550/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260423_124550_01_sft_smoke/checkpoints/5/optimizer_state/array_metadatas/process_5
I0423 13:06:04.031026 132934874953472 base_pytree_checkpoint_handler.py:128] [process=5] /jax/checkpoint/write/gbytes_per_sec: 14.132 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 5 seconds) (per-host)
I0423 13:06:04.031141 132934874953472 async_checkpointer.py:90] [process=5][thread=async_save] 4 Handler Commit operations completed. Time taken: 5.131061s.
I0423 13:06:12.781555 132934874953472 async_checkpointer.py:144] [process=5][thread=async_save] Background save thread done. Time taken: 13.881461s.
I0423 13:06:12.781901 132935395038976 async_checkpointer.py:273] [process=5][thread=save_finalize] Done with waiting for background save thread=async_save.
I0423 13:06:12.782092 132935395038976 async_checkpointer.py:283] [process=5][thread=save_finalize] No errors found in background save thread=async_save.
I0423 13:06:12.782148 132935395038976 checkpoint_manager.py:2103] [process=5][thread=save_finalize][step=5] CheckpointManager Save Finalize is syncing with other hosts...
I0423 13:06:12.783565 132935395038976 checkpoint_manager.py:2112] [process=5][thread=save_finalize][step=5] CheckpointManager Save Finalize is done on all hosts.
I0423 13:06:12.783739 133087290353472 checkpoint_manager.py:2006] [process=5][thread=MainThread][step=5][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=5.
I0423 13:06:12.783888 133087290353472 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133087290353472 count=1 at 0x78eec4151580>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x78ee241e0bc0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x78ed2819c680>, _write_futures=[])
I0423 13:06:13.016732 132938047907584 grain_pool.py:542] Grain pool is exiting.
I0423 13:06:13.016831 132938047907584 grain_pool.py:547] Shutting down multiprocessing system.
I0423 13:06:15.120122 132938047907584 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
XPK End: Thu Apr 23 13:06:25 UTC 2026
EXIT_CODE=0