MaxView

Case: 01_sft_smoke

Metrics: Linen vs NNX  ·  feat/nnx-post-train-fixes

Metric        Linen 28b1e4a16    NNX 28b1e4a16    Diff (NNX − Linen)
Parameters    0.000 billion      0.000 billion
Final loss    5.7660
TFLOP/s       0.015
Tok/s         67432.1
Avg s/step    3.925
Memory %      0.03
JAX           0.8.3              0.9.2

Diff = NNX value − Linen value. Green = NNX improved. Red = NNX regressed.
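The diff/coloring convention above can be sketched as follows. This is an illustrative reconstruction, not MaxView's actual source: the metric names and the lower-is-better set are assumptions based on the table.

```python
# Sketch of the diff/verdict rule described above.
# Assumption: for loss, step time, and memory, lower is better;
# for throughput metrics, higher is better.
LOWER_IS_BETTER = {"Final loss", "Avg s/step", "Memory %"}

def diff_and_verdict(metric, linen, nnx):
    """Return (NNX - Linen, 'improved' / 'regressed' / 'unchanged')."""
    d = nnx - linen
    if d == 0:
        return d, "unchanged"
    improved = d < 0 if metric in LOWER_IS_BETTER else d > 0
    return d, "improved" if improved else "regressed"

print(diff_and_verdict("Final loss", 5.766, 5.700)[1])   # improved
print(diff_and_verdict("Tok/s", 67432.1, 60000.0)[1])    # regressed
```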

Linen  ·  28b1e4a16  ·  feat_nnx_post_train_fixes_20260424_120707  ·  full log
XPK Start: Fri Apr 24 12:07:35 UTC 2026
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
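The warnings above say the validator wants `factor`, `beta_fast`, and `beta_slow` as floats and does not accept `rope_theta` inside `rope_scaling` for `rope_type='yarn'`. A hypothetical cleanup of such a config dict could look like this; the set of allowed keys is an assumption for illustration, not taken from the library's validator.

```python
# Hypothetical cleanup of an HF-style `rope_scaling` dict so it passes the
# checks logged above: coerce numeric fields to float and drop keys the
# yarn validator does not recognize (e.g. 'rope_theta').
def clean_rope_scaling(cfg: dict) -> dict:
    # Assumed allow-list for rope_type='yarn'; adjust to the real schema.
    allowed = {"rope_type", "factor", "beta_fast", "beta_slow",
               "original_max_position_embeddings"}
    cleaned = {k: v for k, v in cfg.items() if k in allowed}
    for key in ("factor", "beta_fast", "beta_slow"):
        if key in cleaned:
            cleaned[key] = float(cleaned[key])  # 40 -> 40.0, 32 -> 32.0, 1 -> 1.0
    return cleaned

raw = {"rope_type": "yarn", "factor": 40, "beta_fast": 32,
       "beta_slow": 1, "rope_theta": 10000.0}
print(clean_rope_scaling(raw))
```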
2026-04-24 12:08:04.165333: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0424 12:08:04.411194 134645151012672 max_utils.py:273] Attempting to initialize the jax distributed system...
I0424 12:08:13.452770 134645151012672 distributed.py:149] Starting JAX distributed service on [::]:8482
I0424 12:08:13.455368 134645151012672 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-3abzd-slice-job-0-0.mt-01-sft-smoke-3abzd:8482
I0424 12:08:14.573010 134645151012672 max_utils.py:284] Jax distributed system initialized!
I0424 12:08:21.115375 134645151012672 max_utils.py:800] System Information: Jax Version: 0.8.3
I0424 12:08:21.115482 134645151012672 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0424 12:08:21.115526 134645151012672 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0424 12:08:21.118988 134645151012672 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 12:08:21.308490 134645151012672 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 12:08:22.505997 134645151012672 config.py:112] TensorFlow version 2.20.0 available.
I0424 12:08:22.506505 134645151012672 config.py:125] JAX version 0.8.3 available.
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of hosts loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
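The condition behind that warning is simple: with fewer dataset shards than hosts loading data, some hosts have no shard to read. A minimal sketch (function name and host count are illustrative, not from the maxtext source):

```python
# Sketch of the check behind the dataloading warning above: each loading
# host wants at least one dataset shard; with 3 shards and more hosts
# than that, some hosts idle and the pipeline bottlenecks.
def dataloading_is_inefficient(num_shards: int, num_loading_hosts: int) -> bool:
    return num_shards < num_loading_hosts

print(dataloading_is_inefficient(3, 8))  # True: this run has only 3 shards
print(dataloading_is_inefficient(8, 8))  # False: one shard per host is fine
```

The linked guide's usual remedies are re-sharding the dataset to at least one shard per loading host, or reducing the number of hosts that load data.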
E0424 12:08:27.978212 134645151012672 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0424 12:08:27.978440 134645151012672 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0424 12:08:28.363730 134645151012672 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 12:08:28.364178 134645151012672 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a74ce952ed0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 12:08:28.364226 134645151012672 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 12:08:28.364269 134645151012672 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a74ce952ed0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 12:08:28.364315 134645151012672 checkpoint_manager.py:702] [process=4][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1fa0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1160>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a57580a0b90>}, handler_registry=None
I0424 12:08:28.364528 134645151012672 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1fa0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 12:08:28.364573 134645151012672 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1160>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 12:08:28.364603 134645151012672 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a57580a0b90>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 12:08:28.364629 134645151012672 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a57580a1fd0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 12:08:28.364657 134645151012672 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1fa0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1fa0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1160>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a57580a1160>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a57580a0b90>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a57580a0b90>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a57580a1fd0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a57580a1fd0>}).
I0424 12:08:28.364917 134645151012672 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
I0424 12:08:28.364969 134645151012672 async_checkpointer.py:177] [process=4][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7a577812b060> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0424 12:08:31.220307 134645151012672 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints
I0424 12:08:31.241176 134645151012672 checkpoint_manager.py:921] [process=4][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7a57580a1190>
I0424 12:08:31.241458 134645151012672 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0424 12:08:31.652833 134645151012672 peft_trainer.py:594] Compiled train_step cache size: 0
I0424 12:08:31.654793 134645151012672 metric_logger.py:301] number parameters: 0.000 billion
I0424 12:08:31.657593 134492094179072 grain_pool.py:367] Grain pool will use 1 processes.
I0424 12:08:31.684176 134492094179072 grain_pool.py:440] Grain pool will start child processes.
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
I0424 12:08:31.689425 134492094179072 grain_pool.py:448] Grain pool started all child processes.
2026-04-24 12:08:35.672451: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-24 12:08:35.716756: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-24 12:08:36.882041: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
2026-04-24 12:08:41.102886: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0424 12:08:49.642750 134645151012672 checkpoint_manager.py:1983] [process=4][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0424 12:08:49.644853 134645151012672 checkpoint_manager.py:1501] [process=4] Saving checkpoint at step 1
I0424 12:08:49.647957 134645151012672 async_checkpointer.py:452] [process=4] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/1.
I0424 12:08:50.185297 134645151012672 signaling_client.py:364] Using JaxDistributedSignalingClient
I0424 12:08:50.186305 134645151012672 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0424 12:08:50.186364 134645151012672 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0424 12:08:50.890940 134645151012672 base_pytree_checkpoint_handler.py:153] [process=4][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.705742s
I0424 12:08:50.892462 134645151012672 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/blocking_gbytes_per_sec: 79.518 KiB/s (total gbytes: 76.7 KiB) (time elapsed: 964 milliseconds) (per-host)
I0424 12:08:50.892529 134645151012672 base_pytree_checkpoint_handler.py:732] [process=4][thread=MainThread] Initiated Pytree async_save. Time taken: 0.964516s (batch_requests_ready=0.252476s, total_serialization_initiated=0.711924s, others=0.000115s)
I0424 12:08:50.893594 134645151012672 jax_array_handlers.py:347] Scheduling D2H of 22 prioritized jax.Array.
I0424 12:08:50.893649 134645151012672 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0424 12:08:50.898445 134645151012672 base_pytree_checkpoint_handler.py:153] [process=4][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.005781s
I0424 12:08:50.898564 134645151012672 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/blocking_gbytes_per_sec: 26.283 KiB/s (total gbytes: 25.6 KiB) (time elapsed: 972 milliseconds) (per-host)
I0424 12:08:50.898612 134645151012672 base_pytree_checkpoint_handler.py:732] [process=4][thread=MainThread] Initiated Pytree async_save. Time taken: 0.972659s (batch_requests_ready=0.965175s, total_serialization_initiated=0.007407s, others=0.000077s)
I0424 12:08:50.898703 134645151012672 composite_checkpoint_handler.py:715] [process=4][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.977118s (all_items=0.000023s, per_item={'model_params': '0.00001860', 'optimizer_state': '0.00000429'}, temp_paths=0.977096)
I0424 12:08:50.899671 134488914917120 async_checkpointer.py:79] [process=4][thread=async_save] Background save thread started.
I0424 12:08:50.899839 134645151012672 async_checkpointer.py:561] Finished blocking save. Time taken: 1.254900s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/1.
I0424 12:08:50.901654 134645151012672 checkpoint_manager.py:1549] [process=4][thread=MainThread][step=1] Starting CheckpointManager Save Finalize thread=save_finalize
I0424 12:08:50.901921 134489443395328 async_checkpointer.py:265] [process=4][thread=save_finalize] Waiting for background save thread=async_save.
I0424 12:08:50.902050 134645151012672 standard_logger.py:34] {'step': 1, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1777032529.6427128, 'wait_for_prev_duration_secs': 0.000148773193359375, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1777032529.6449094, 'checkpointer_blocking_duration_secs': 1.255051612854004, 'get_old_steps_start_time': 1777032530.899983, 'get_old_steps_duration_secs': 8.034706115722656e-05, 'checkpoint_manager_blocking_start_time': 1777032529.5077105, 'checkpoint_manager_blocking_duration_secs': 1.3943052291870117}
I0424 12:08:51.037631 134645151012672 peft_trainer.py:474] Train step 1 training loss: 5.894745  - training perplexity: 363.124207
I0424 12:08:51.037911 134645151012672 max_utils.py:750] 
Memstats: After params initialized:
I0424 12:08:51.037976 134645151012672 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_16(process=4,(0,4,0,0))
I0424 12:08:51.038013 134645151012672 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_17(process=4,(1,4,0,0))
I0424 12:08:51.038043 134645151012672 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_20(process=4,(0,5,0,0))
I0424 12:08:51.038070 134645151012672 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_21(process=4,(1,5,0,0))
I0424 12:08:51.168150 134645151012672 metric_logger.py:196] completed step: 1, seconds: 19.383, TFLOP/s/device: 0.000, Tokens/s/device: 52.829, total_weights: 21054, loss: 5.895, lm_loss: 0.000, perplexity: 0.000
I0424 12:08:51.176715 134645151012672 peft_trainer.py:474] Train step 2 training loss: 5.603691  - training perplexity: 271.426270
I0424 12:08:51.177435 134645151012672 metric_logger.py:196] completed step: 2, seconds: 0.139, TFLOP/s/device: 0.002, Tokens/s/device: 7376.894, total_weights: 21455, loss: 5.604, lm_loss: 0.000, perplexity: 0.000
I0424 12:08:51.247349 134645151012672 peft_trainer.py:474] Train step 3 training loss: 5.511911  - training perplexity: 247.623871
I0424 12:08:51.248254 134645151012672 metric_logger.py:196] completed step: 3, seconds: 0.071, TFLOP/s/device: 0.003, Tokens/s/device: 14489.947, total_weights: 22025, loss: 5.512, lm_loss: 0.000, perplexity: 0.000
I0424 12:08:51.266252 134645151012672 peft_trainer.py:474] Train step 4 training loss: 5.686435  - training perplexity: 294.840698
I0424 12:08:51.266981 134645151012672 metric_logger.py:196] completed step: 4, seconds: 0.019, TFLOP/s/device: 0.012, Tokens/s/device: 54360.752, total_weights: 23787, loss: 5.686, lm_loss: 0.000, perplexity: 0.000
I0424 12:08:51.280470 134645151012672 peft_trainer.py:733] Train loop finished in: 19.6242 seconds
I0424 12:08:51.281443 134645151012672 peft_trainer.py:474] Train step 5 training loss: 5.766005  - training perplexity: 319.259766
I0424 12:08:51.282139 134645151012672 metric_logger.py:196] completed step: 5, seconds: 0.015, TFLOP/s/device: 0.015, Tokens/s/device: 67432.087, total_weights: 20141, loss: 5.766, lm_loss: 0.000, perplexity: 0.000
I0424 12:08:53.091215    2992 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0424 12:08:55.003154 134489418217216 array_metadata_store.py:203] [process=4][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/1/optimizer_state/array_metadatas/process_4
I0424 12:08:55.026992 134489401431808 array_metadata_store.py:203] [process=4][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/1/model_params/array_metadatas/process_4
I0424 12:08:55.028167 134488914917120 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/gbytes_per_sec: 5.010 KiB/s (total gbytes: 25.6 KiB) (time elapsed: 5 seconds) (per-host)
I0424 12:08:55.028325 134488914917120 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/gbytes_per_sec: 15.036 KiB/s (total gbytes: 76.7 KiB) (time elapsed: 5 seconds) (per-host)
I0424 12:08:55.028363 134488914917120 async_checkpointer.py:90] [process=4][thread=async_save] 4 Handler Commit operations completed. Time taken: 4.128578s.
I0424 12:09:00.439439 134645151012672 checkpoint_manager.py:1994] [process=4][thread=MainThread][step=1][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0424 12:09:06.716241 134488914917120 async_checkpointer.py:144] [process=4][thread=async_save] Background save thread done. Time taken: 15.816438s.
I0424 12:09:06.716572 134489443395328 async_checkpointer.py:273] [process=4][thread=save_finalize] Done with waiting for background save thread=async_save.
I0424 12:09:06.716691 134489443395328 async_checkpointer.py:283] [process=4][thread=save_finalize] No errors found in background save thread=async_save.
I0424 12:09:06.716741 134489443395328 checkpoint_manager.py:2103] [process=4][thread=save_finalize][step=1] CheckpointManager Save Finalize is syncing with other hosts...
I0424 12:09:06.718056 134489443395328 checkpoint_manager.py:2112] [process=4][thread=save_finalize][step=1] CheckpointManager Save Finalize is done on all hosts.
I0424 12:09:06.718240 134645151012672 checkpoint_manager.py:2006] [process=4][thread=MainThread][step=1][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=1.
W0424 12:09:06.718374 134645151012672 checkpoint_manager.py:1441] Waiting for previous save to complete took 6.278954 seconds. If this number is high, consider checkpointing less frequently.
I0424 12:09:06.720014 134645151012672 checkpoint_manager.py:1501] [process=4] Saving checkpoint at step 5
I0424 12:09:06.723476 134645151012672 async_checkpointer.py:452] [process=4] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/5.
I0424 12:09:07.278192 134645151012672 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0424 12:09:07.278287 134645151012672 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0424 12:09:07.290795 134645151012672 base_pytree_checkpoint_handler.py:153] [process=4][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.013440s
I0424 12:09:07.292155 134645151012672 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/blocking_gbytes_per_sec: 285.114 KiB/s (total gbytes: 76.7 KiB) (time elapsed: 268 milliseconds) (per-host)
I0424 12:09:07.292213 134645151012672 base_pytree_checkpoint_handler.py:732] [process=4][thread=MainThread] Initiated Pytree async_save. Time taken: 0.269065s (batch_requests_ready=0.252213s, total_serialization_initiated=0.016756s, others=0.000097s)
I0424 12:09:07.293319 134645151012672 jax_array_handlers.py:347] Scheduling D2H of 22 prioritized jax.Array.
I0424 12:09:07.293370 134645151012672 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0424 12:09:07.298066 134645151012672 base_pytree_checkpoint_handler.py:153] [process=4][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.005728s
I0424 12:09:07.298166 134645151012672 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/blocking_gbytes_per_sec: 92.476 KiB/s (total gbytes: 25.6 KiB) (time elapsed: 276 milliseconds) (per-host)
I0424 12:09:07.298208 134645151012672 base_pytree_checkpoint_handler.py:732] [process=4][thread=MainThread] Initiated Pytree async_save. Time taken: 0.276481s (batch_requests_ready=0.269201s, total_serialization_initiated=0.007218s, others=0.000062s)
I0424 12:09:07.298295 134645151012672 composite_checkpoint_handler.py:715] [process=4][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.280517s (all_items=0.000012s, per_item={'model_params': '0.00000978', 'optimizer_state': '0.00000238'}, temp_paths=0.280504)
I0424 12:09:07.299169 134489443395328 async_checkpointer.py:79] [process=4][thread=async_save] Background save thread started.
I0424 12:09:07.299266 134645151012672 async_checkpointer.py:561] Finished blocking save. Time taken: 0.579181s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/5.
I0424 12:09:07.685739 134645151012672 checkpoint_manager.py:1549] [process=4][thread=MainThread][step=5] Starting CheckpointManager Save Finalize thread=save_finalize
I0424 12:09:07.686113 134488914917120 async_checkpointer.py:265] [process=4][thread=save_finalize] Waiting for background save thread=async_save.
I0424 12:09:07.686282 134645151012672 standard_logger.py:34] {'step': 5, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1777032540.4393964, 'wait_for_prev_duration_secs': 6.278954267501831, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1777032546.7200549, 'checkpointer_blocking_duration_secs': 0.579322338104248, 'get_old_steps_start_time': 1777032547.2993982, 'get_old_steps_duration_secs': 8.0108642578125e-05, 'checkpoint_manager_blocking_start_time': 1777032531.2855985, 'checkpoint_manager_blocking_duration_secs': 16.400646924972534}
I0424 12:09:07.686480 134645151012672 checkpoint_manager.py:1994] [process=4][thread=MainThread][step=5][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0424 12:09:12.063593 134487841175296 array_metadata_store.py:203] [process=4][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/5/model_params/array_metadatas/process_4
I0424 12:09:12.064770 134489443395328 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/gbytes_per_sec: 5.069 KiB/s (total gbytes: 25.6 KiB) (time elapsed: 5 seconds) (per-host)
I0424 12:09:12.118028 134489418217216 array_metadata_store.py:203] [process=4][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints/5/optimizer_state/array_metadatas/process_4
I0424 12:09:12.119237 134489443395328 base_pytree_checkpoint_handler.py:128] [process=4] /jax/checkpoint/write/gbytes_per_sec: 15.048 KiB/s (total gbytes: 76.7 KiB) (time elapsed: 5 seconds) (per-host)
I0424 12:09:12.119340 134489443395328 async_checkpointer.py:90] [process=4][thread=async_save] 4 Handler Commit operations completed. Time taken: 4.820112s.
I0424 12:09:22.458468 134489443395328 async_checkpointer.py:144] [process=4][thread=async_save] Background save thread done. Time taken: 15.159226s.
I0424 12:09:22.458696 134488914917120 async_checkpointer.py:273] [process=4][thread=save_finalize] Done with waiting for background save thread=async_save.
I0424 12:09:22.458758 134488914917120 async_checkpointer.py:283] [process=4][thread=save_finalize] No errors found in background save thread=async_save.
I0424 12:09:22.458805 134488914917120 checkpoint_manager.py:2103] [process=4][thread=save_finalize][step=5] CheckpointManager Save Finalize is syncing with other hosts...
I0424 12:09:22.460398 134488914917120 checkpoint_manager.py:2112] [process=4][thread=save_finalize][step=5] CheckpointManager Save Finalize is done on all hosts.
I0424 12:09:22.460592 134645151012672 checkpoint_manager.py:2006] [process=4][thread=MainThread][step=5][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=5.
I0424 12:09:22.460748 134645151012672 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134645151012672 count=1 at 0x7a58b8102480>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a57580a5550>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a57580a5520>, _write_futures=[])
I0424 12:09:22.461216 134645151012672 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134645151012672 count=1 at 0x7a58b8102480>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a57580a5550>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a57580a5520>, _write_futures=[])
I0424 12:09:22.461249 134645151012672 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134645151012672 count=1 at 0x7a58b8102480>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a57580a5550>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a57580a5520>, _write_futures=[])
I0424 12:09:22.776576 134492094179072 grain_pool.py:542] Grain pool is exiting.
I0424 12:09:22.776677 134492094179072 grain_pool.py:547] Shutting down multiprocessing system.
I0424 12:09:24.864463 134492094179072 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Fri Apr 24 12:09:35 UTC 2026
EXIT_CODE=0
NNX  ·  28b1e4a16  ·  feat_nnx_post_train_fixes_20260424_120707  ·  full log
XPK Start: Fri Apr 24 12:18:46 UTC 2026
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config. 
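The NNX run's newer stack reports the same validation errors against `rope_parameters` and complains that the config still passes the legacy `rope_scaling` kwarg. A hypothetical migration shim for that message could look like the following; the function and the renaming logic are illustrative assumptions, not the library's actual standardization code.

```python
# Hypothetical shim for the message above: move the legacy `rope_scaling`
# kwarg to `rope_parameters`, coercing numeric fields to float on the way.
def migrate_rope_config(kwargs: dict) -> dict:
    kwargs = dict(kwargs)  # do not mutate the caller's dict
    legacy = kwargs.pop("rope_scaling", None)
    if legacy is not None and "rope_parameters" not in kwargs:
        params = dict(legacy)
        for key in ("factor", "beta_fast", "beta_slow"):
            if key in params:
                params[key] = float(params[key])  # 40 -> 40.0, etc.
        kwargs["rope_parameters"] = params
    return kwargs

cfg = migrate_rope_config({"rope_scaling": {"rope_type": "yarn", "factor": 40,
                                            "beta_fast": 32, "beta_slow": 1}})
print(cfg["rope_parameters"]["factor"])  # 40.0
```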
2026-04-24 12:19:17.617646: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0424 12:19:17.818554 135183711901504 max_utils.py:273] Attempting to initialize the jax distributed system...
I0424 12:19:26.858551 135183711901504 distributed.py:149] Starting JAX distributed service on [::]:8482
I0424 12:19:26.860944 135183711901504 distributed.py:172] Connecting to JAX distributed service on mt-01-sft-smoke-rcz2y-slice-job-0-0.mt-01-sft-smoke-rcz2y:8482
I0424 12:19:28.426906 135183711901504 max_utils.py:284] Jax distributed system initialized!
I0424 12:19:33.649544 135183711901504 max_utils.py:800] System Information: Jax Version: 0.9.2
I0424 12:19:33.649652 135183711901504 max_utils.py:801] System Information: Jaxlib Version: 0.9.2
I0424 12:19:33.649693 135183711901504 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Apr 6 2026 20:48:10 (1775533690) cl/895581894
I0424 12:19:33.653441 135183711901504 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 12:19:34.233051 135183711901504 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 12:19:34.340657 135183711901504 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 12:19:35.472669 135183711901504 config.py:112] TensorFlow version 2.20.0 available.
I0424 12:19:35.473235 135183711901504 config.py:125] JAX version 0.9.2 available.
I0424 12:19:35.915923 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/README.md "HTTP/1.1 307 Temporary Redirect"
I0424 12:19:35.923945 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0424 12:19:35.932128 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0424 12:19:36.040184 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0424 12:19:36.344403 135183711901504 _client.py:1025] HTTP Request: HEAD https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets/HuggingFaceH4/ultrachat_200k/HuggingFaceH4/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0424 12:19:36.460436 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/revision/8049631c405ae6576f93f445c6b8166f76f5505a "HTTP/1.1 200 OK"
I0424 12:19:36.581073 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/.huggingface.yaml "HTTP/1.1 404 Not Found"
I0424 12:19:36.736821 135183711901504 _client.py:1025] HTTP Request: GET https://datasets-server.huggingface.co/info?dataset=HuggingFaceH4/ultrachat_200k "HTTP/1.1 200 OK"
I0424 12:19:36.851757 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a/data?recursive=true&expand=false "HTTP/1.1 200 OK"
I0424 12:19:36.954523 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a?recursive=false&expand=false "HTTP/1.1 200 OK"
I0424 12:19:37.062849 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/dataset_infos.json "HTTP/1.1 404 Not Found"
I0424 12:19:37.242192 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0424 12:19:37.352548 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0424 12:19:37.466740 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0424 12:19:37.576438 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0424 12:19:37.754712 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Found"
I0424 12:19:37.862047 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main?recursive=true&expand=false "HTTP/1.1 200 OK"
I0424 12:19:37.976055 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model "HTTP/1.1 302 Found"
I0424 12:19:38.092259 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/xet-read-token/f5db02db724555f92da89c216ac04704f23d4590 "HTTP/1.1 200 OK"
I0424 12:19:38.733916 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0424 12:19:38.871599 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0424 12:19:39.254543 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/added_tokens.json "HTTP/1.1 404 Not Found"
I0424 12:19:39.360423 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0424 12:19:39.469512 135183711901504 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0424 12:19:39.606384 135183711901504 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/chat_template.jinja "HTTP/1.1 404 Not Found"
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0424 12:19:39.713494 135183711901504 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0424 12:19:39.713714 135183711901504 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0424 12:19:40.140875 135183711901504 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 12:19:40.141023 135183711901504 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7af232ac7fb0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 12:19:40.141072 135183711901504 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 12:19:40.141132 135183711901504 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7af232ac7fb0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 12:19:40.141194 135183711901504 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7ad4ec076ae0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7af0c34f4b00>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7ad4ec07be60>}, handler_registry=None
I0424 12:19:40.141422 135183711901504 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7ad4ec076ae0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 12:19:40.141468 135183711901504 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7af0c34f4b00>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 12:19:40.141496 135183711901504 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7ad4ec07be60>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 12:19:40.141522 135183711901504 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7ad4ec077c80>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 12:19:40.141551 135183711901504 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7ad4ec076ae0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7ad4ec076ae0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7af0c34f4b00>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7af0c34f4b00>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7ad4ec07be60>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7ad4ec07be60>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7ad4ec077c80>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7ad4ec077c80>}).
I0424 12:19:40.141937 135183711901504 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
I0424 12:19:40.141989 135183711901504 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7ad3681ed260> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0424 12:19:42.515621 135183711901504 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints
I0424 12:19:42.553857 135183711901504 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260424_120707/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260424_120707_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7ad4ec079a60>
I0424 12:19:42.554177 135183711901504 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0424 12:19:43.033161 135183711901504 peft_trainer.py:594] Compiled train_step cache size: 0
I0424 12:19:43.037165 135183711901504 metric_logger.py:301] number parameters: 0.000 billion
I0424 12:19:43.039406 135012640945920 grain_pool.py:367] Grain pool will use 1 processes.
I0424 12:19:43.089987 135012640945920 grain_pool.py:440] Grain pool will start child processes.
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
I0424 12:19:43.096185 135012640945920 grain_pool.py:448] Grain pool started all child processes.
2026-04-24 12:19:47.277356: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-24 12:19:47.323070: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-24 12:19:48.498411: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config. 
2026-04-24 12:19:54.108790: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 283, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 279, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 256, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 242, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 652, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 156, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 985, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1047, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 844, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2732, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 602, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 582, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 512, in _device_put_sharding_impl
    raise ValueError(
ValueError: When the second argument to `device_put` is a Device, the first argument must be a fully addressable array or a non-addressable array with a single device sharding. Got value with devices {TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0)}
I0424 12:19:59.434149 135012640945920 grain_pool.py:542] Grain pool is exiting.
I0424 12:19:59.434252 135012640945920 grain_pool.py:547] Shutting down multiprocessing system.
I0424 12:20:05.246124 135012640945920 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 20 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Fri Apr 24 12:20:16 UTC 2026
EXIT_CODE=1