MaxView

Case: 01_sft_smoke

Metrics: Linen vs NNX  ·  feat/nnx-post-train-fixes

Metric        Linen 7f06c99ac    NNX 7f06c99ac    Diff (NNX − Linen)
Parameters    0.000 billion      0.000 billion
Final loss    5.7660
TFLOP/s       0.014
Tok/s         66992.8
Avg s/step    3.931
Memory %      0.03
JAX           0.8.3              0.8.3

Diff = NNX value − Linen value. Green = NNX improved. Red = NNX regressed.
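The Diff column convention above can be sketched in a few lines. This is a minimal, hypothetical illustration, not MaxView's actual implementation: the metric names and the higher-is-better mapping are assumptions chosen to match the metrics in the table.

```python
# Illustrative sketch: diff = NNX value - Linen value, with "improved" vs
# "regressed" depending on whether higher is better for that metric.
# The mapping below is an assumption based on the table's metrics.
HIGHER_IS_BETTER = {
    "TFLOP/s": True,       # more throughput is better
    "Tok/s": True,         # more tokens/s is better
    "Final loss": False,   # lower loss is better
    "Avg s/step": False,   # faster steps are better
    "Memory %": False,     # lower memory use is better
}

def diff_and_verdict(metric, linen, nnx):
    """Return (diff, verdict) where diff = NNX value - Linen value."""
    diff = nnx - linen
    if diff == 0:
        return diff, "unchanged"
    improved = (diff > 0) == HIGHER_IS_BETTER[metric]
    return diff, "improved" if improved else "regressed"
```

For example, a lower NNX final loss yields a negative diff that still counts as "improved", while a lower NNX Tok/s yields a negative diff that counts as "regressed".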

Linen  ·  7f06c99ac  ·  feat_nnx_post_train_fixes_20260422_123915  ·  full log
XPK Start: Wed Apr 22 12:43:19 UTC 2026
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
2026-04-22 12:44:17.098003: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0422 12:44:17.344536 134531539752768 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-22 12:44:26,387:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-v69ex-slice-job-0-0.mt-01-sft-smoke-v69ex:8482
I0422 12:44:26.387798 134531539752768 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-v69ex-slice-job-0-0.mt-01-sft-smoke-v69ex:8482
I0422 12:44:49.340689 134531539752768 max_utils.py:284] Jax distributed system initialized!
I0422 12:44:55.275766 134531539752768 max_utils.py:800] System Information: Jax Version: 0.8.3
I0422 12:44:55.275880 134531539752768 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0422 12:44:55.275920 134531539752768 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0422 12:44:55.279273 134531539752768 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:44:55.474796 134531539752768 maxtext_utils.py:1771] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:44:56.634816 134531539752768 config.py:112] TensorFlow version 2.20.0 available.
I0422 12:44:56.635317 134531539752768 config.py:125] JAX version 0.8.3 available.
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0422 12:45:02.057409 134531539752768 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0422 12:45:02.057625 134531539752768 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0422 12:45:02.445387 134531539752768 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0422 12:45:02.445851 134531539752768 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a5a5ad280e0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0422 12:45:02.445901 134531539752768 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0422 12:45:02.445940 134531539752768 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a5a5ad280e0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0422 12:45:02.445988 134531539752768 checkpoint_manager.py:702] [process=1][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d8811ee40>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d880808c0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a3ca8051b80>}, handler_registry=None
I0422 12:45:02.446202 134531539752768 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d8811ee40>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0422 12:45:02.446245 134531539752768 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d880808c0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0422 12:45:02.446273 134531539752768 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a3ca8051b80>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0422 12:45:02.446300 134531539752768 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a3dec128140>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0422 12:45:02.446330 134531539752768 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d8811ee40>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d8811ee40>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d880808c0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a3d880808c0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a3ca8051b80>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a3ca8051b80>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a3dec128140>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a3dec128140>}).
I0422 12:45:02.446610 134531539752768 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
I0422 12:45:02.446666 134531539752768 async_checkpointer.py:177] [process=1][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7a3ca80cf060> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0422 12:45:09.342662 134531539752768 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints
I0422 12:45:09.772027 134531539752768 checkpoint_manager.py:921] [process=1][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7a3ca8052d80>
I0422 12:45:09.772422 134531539752768 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0422 12:45:10.185117 134531539752768 peft_trainer.py:594] Compiled train_step cache size: 0
I0422 12:45:10.187134 134531539752768 metric_logger.py:301] number parameters: 0.000 billion
I0422 12:45:10.189594 134366441240320 grain_pool.py:367] Grain pool will use 1 processes.
I0422 12:45:10.215684 134366441240320 grain_pool.py:440] Grain pool will start child processes.
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
I0422 12:45:10.220927 134366441240320 grain_pool.py:448] Grain pool started all child processes.
2026-04-22 12:45:14.226535: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-22 12:45:14.271161: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-22 12:45:15.447251: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
`rope_scaling`'s factor field must be a float >= 1, got 40
`rope_scaling`'s beta_fast field must be a float, got 32
`rope_scaling`'s beta_slow field must be a float, got 1
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'rope_theta'}
2026-04-22 12:45:19.622577: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0422 12:45:28.158396 134531539752768 checkpoint_manager.py:1983] [process=1][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0422 12:45:28.160267 134531539752768 checkpoint_manager.py:1501] [process=1] Saving checkpoint at step 1
I0422 12:45:28.163488 134531539752768 async_checkpointer.py:452] [process=1] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/1.
I0422 12:45:28.698688 134531539752768 signaling_client.py:364] Using JaxDistributedSignalingClient
I0422 12:45:28.699661 134531539752768 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0422 12:45:28.699728 134531539752768 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0422 12:45:29.364965 134531539752768 base_pytree_checkpoint_handler.py:153] [process=1][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.666374s
I0422 12:45:29.366486 134531539752768 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/blocking_gbytes_per_sec: 82.942 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 923 milliseconds) (per-host)
I0422 12:45:29.366550 134531539752768 base_pytree_checkpoint_handler.py:732] [process=1][thread=MainThread] Initiated Pytree async_save. Time taken: 0.923566s (batch_requests_ready=0.250985s, total_serialization_initiated=0.672470s, others=0.000112s)
I0422 12:45:29.367607 134531539752768 jax_array_handlers.py:347] Scheduling D2H of 22 prioritized jax.Array.
I0422 12:45:29.367660 134531539752768 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0422 12:45:29.372174 134531539752768 base_pytree_checkpoint_handler.py:153] [process=1][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.005488s
I0422 12:45:29.372283 134531539752768 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/blocking_gbytes_per_sec: 27.413 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 931 milliseconds) (per-host)
I0422 12:45:29.372329 134531539752768 base_pytree_checkpoint_handler.py:732] [process=1][thread=MainThread] Initiated Pytree async_save. Time taken: 0.931432s (batch_requests_ready=0.924258s, total_serialization_initiated=0.007104s, others=0.000071s)
I0422 12:45:29.372433 134531539752768 composite_checkpoint_handler.py:715] [process=1][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.935319s (all_items=0.000023s, per_item={'model_params': '0.00001884', 'optimizer_state': '0.00000453'}, temp_paths=0.935296)
I0422 12:45:29.373450 134360367871744 async_checkpointer.py:79] [process=1][thread=async_save] Background save thread started.
I0422 12:45:29.373589 134531539752768 async_checkpointer.py:561] Finished blocking save. Time taken: 1.213248s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/1.
I0422 12:45:29.460603 134531539752768 checkpoint_manager.py:1549] [process=1][thread=MainThread][step=1] Starting CheckpointManager Save Finalize thread=save_finalize
I0422 12:45:29.460999 134360493696768 async_checkpointer.py:265] [process=1][thread=save_finalize] Waiting for background save thread=async_save.
I0422 12:45:29.461169 134531539752768 standard_logger.py:34] {'step': 1, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776861928.1583424, 'wait_for_prev_duration_secs': 0.00015807151794433594, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776861928.1603081, 'checkpointer_blocking_duration_secs': 1.2134618759155273, 'get_old_steps_start_time': 1776861929.3737996, 'get_old_steps_duration_secs': 8.20159912109375e-05, 'checkpoint_manager_blocking_start_time': 1776861927.957667, 'checkpoint_manager_blocking_duration_secs': 1.503464698791504}
I0422 12:45:29.606145 134531539752768 peft_trainer.py:474] Train step 1 training loss: 5.894745  - training perplexity: 363.124207
I0422 12:45:29.606425 134531539752768 max_utils.py:750] 
Memstats: After params initialized:
I0422 12:45:29.606489 134531539752768 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_2(process=1,(2,0,0,0))
I0422 12:45:29.606525 134531539752768 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_3(process=1,(3,0,0,0))
I0422 12:45:29.606554 134531539752768 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_6(process=1,(2,1,0,0))
I0422 12:45:29.606579 134531539752768 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_7(process=1,(3,1,0,0))
I0422 12:45:29.737132 134531539752768 metric_logger.py:196] completed step: 1, seconds: 19.419, TFLOP/s/device: 0.000, Tokens/s/device: 52.731, total_weights: 21054, loss: 5.895, lm_loss: 0.000, perplexity: 0.000
I0422 12:45:29.746229 134531539752768 peft_trainer.py:474] Train step 2 training loss: 5.603691  - training perplexity: 271.426270
I0422 12:45:29.747006 134531539752768 metric_logger.py:196] completed step: 2, seconds: 0.140, TFLOP/s/device: 0.002, Tokens/s/device: 7322.443, total_weights: 21455, loss: 5.604, lm_loss: 0.000, perplexity: 0.000
I0422 12:45:29.807019 134531539752768 peft_trainer.py:474] Train step 3 training loss: 5.511911  - training perplexity: 247.623871
I0422 12:45:29.807858 134531539752768 metric_logger.py:196] completed step: 3, seconds: 0.061, TFLOP/s/device: 0.004, Tokens/s/device: 16845.442, total_weights: 22025, loss: 5.512, lm_loss: 0.000, perplexity: 0.000
I0422 12:45:29.826100 134531539752768 peft_trainer.py:474] Train step 4 training loss: 5.686435  - training perplexity: 294.840698
I0422 12:45:29.826838 134531539752768 metric_logger.py:196] completed step: 4, seconds: 0.019, TFLOP/s/device: 0.012, Tokens/s/device: 53798.667, total_weights: 23787, loss: 5.686, lm_loss: 0.000, perplexity: 0.000
I0422 12:45:29.840449 134531539752768 peft_trainer.py:733] Train loop finished in: 19.6522 seconds
I0422 12:45:29.841397 134531539752768 peft_trainer.py:474] Train step 5 training loss: 5.766005  - training perplexity: 319.259766
I0422 12:45:29.842134 134531539752768 metric_logger.py:196] completed step: 5, seconds: 0.015, TFLOP/s/device: 0.014, Tokens/s/device: 66992.812, total_weights: 20141, loss: 5.766, lm_loss: 0.000, perplexity: 0.000
I0422 12:45:32.862983    2727 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0422 12:45:35.149612 134360393049856 array_metadata_store.py:203] [process=1][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/1/optimizer_state/array_metadatas/process_1
I0422 12:45:35.299094 134360376264448 array_metadata_store.py:203] [process=1][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/1/model_params/array_metadatas/process_1
I0422 12:45:35.300283 134360367871744 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/gbytes_per_sec: 3.722 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 6 seconds) (per-host)
I0422 12:45:35.300463 134360367871744 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/gbytes_per_sec: 11.169 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 6 seconds) (per-host)
I0422 12:45:35.300501 134360367871744 async_checkpointer.py:90] [process=1][thread=async_save] 4 Handler Commit operations completed. Time taken: 5.926830s.
I0422 12:45:38.522752 134531539752768 checkpoint_manager.py:1994] [process=1][thread=MainThread][step=1][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0422 12:45:46.600028 134360367871744 async_checkpointer.py:144] [process=1][thread=async_save] Background save thread done. Time taken: 17.226341s.
I0422 12:45:46.600268 134360493696768 async_checkpointer.py:273] [process=1][thread=save_finalize] Done with waiting for background save thread=async_save.
I0422 12:45:46.600334 134360493696768 async_checkpointer.py:283] [process=1][thread=save_finalize] No errors found in background save thread=async_save.
I0422 12:45:46.600434 134360493696768 checkpoint_manager.py:2103] [process=1][thread=save_finalize][step=1] CheckpointManager Save Finalize is syncing with other hosts...
I0422 12:45:46.602261 134360493696768 checkpoint_manager.py:2112] [process=1][thread=save_finalize][step=1] CheckpointManager Save Finalize is done on all hosts.
I0422 12:45:46.602442 134531539752768 checkpoint_manager.py:2006] [process=1][thread=MainThread][step=1][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=1.
W0422 12:45:46.602575 134531539752768 checkpoint_manager.py:1441] Waiting for previous save to complete took 8.079841 seconds. If this number is high, consider checkpointing less frequently.
I0422 12:45:46.604175 134531539752768 checkpoint_manager.py:1501] [process=1] Saving checkpoint at step 5
I0422 12:45:46.607546 134531539752768 async_checkpointer.py:452] [process=1] Started async saving checkpoint to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/5.
I0422 12:45:47.567579 134531539752768 jax_array_handlers.py:347] Scheduling D2H of 52 prioritized jax.Array.
I0422 12:45:47.567680 134531539752768 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0422 12:45:47.581964 134531539752768 base_pytree_checkpoint_handler.py:153] [process=1][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.015485s
I0422 12:45:47.583323 134531539752768 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/blocking_gbytes_per_sec: 288.094 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 265 milliseconds) (per-host)
I0422 12:45:47.583394 134531539752768 base_pytree_checkpoint_handler.py:732] [process=1][thread=MainThread] Initiated Pytree async_save. Time taken: 0.265968s (batch_requests_ready=0.247061s, total_serialization_initiated=0.018800s, others=0.000107s)
I0422 12:45:47.584493 134531539752768 jax_array_handlers.py:347] Scheduling D2H of 22 prioritized jax.Array.
I0422 12:45:47.584545 134531539752768 replica_slices.py:410] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0422 12:45:47.589677 134531539752768 base_pytree_checkpoint_handler.py:153] [process=1][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.006160s
I0422 12:45:47.589783 134531539752768 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/blocking_gbytes_per_sec: 93.252 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 273 milliseconds) (per-host)
I0422 12:45:47.589826 134531539752768 base_pytree_checkpoint_handler.py:732] [process=1][thread=MainThread] Initiated Pytree async_save. Time taken: 0.273849s (batch_requests_ready=0.266111s, total_serialization_initiated=0.007673s, others=0.000065s)
I0422 12:45:47.589915 134531539752768 composite_checkpoint_handler.py:715] [process=1][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.277980s (all_items=0.000014s, per_item={'model_params': '0.00001144', 'optimizer_state': '0.00000238'}, temp_paths=0.277966)
I0422 12:45:47.590856 134360493696768 async_checkpointer.py:79] [process=1][thread=async_save] Background save thread started.
I0422 12:45:47.591012 134531539752768 async_checkpointer.py:561] Finished blocking save. Time taken: 0.986768s. Continuing background save to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/5.
I0422 12:45:47.931273 134531539752768 checkpoint_manager.py:1549] [process=1][thread=MainThread][step=5] Starting CheckpointManager Save Finalize thread=save_finalize
I0422 12:45:47.931725 134360367871744 async_checkpointer.py:265] [process=1][thread=save_finalize] Waiting for background save thread=async_save.
I0422 12:45:47.931907 134531539752768 standard_logger.py:34] {'step': 5, 'event_type': 'save', 'directory': 'gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776861938.5227103, 'wait_for_prev_duration_secs': 8.079841136932373, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776861946.604214, 'checkpointer_blocking_duration_secs': 0.9869015216827393, 'get_old_steps_start_time': 1776861947.5911376, 'get_old_steps_duration_secs': 8.058547973632812e-05, 'checkpoint_manager_blocking_start_time': 1776861929.8456693, 'checkpoint_manager_blocking_duration_secs': 18.086202383041382}
I0422 12:45:47.932092 134531539752768 checkpoint_manager.py:1994] [process=1][thread=MainThread][step=5][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0422 12:45:52.531240 134360393049856 array_metadata_store.py:203] [process=1][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/5/optimizer_state/array_metadatas/process_1
I0422 12:45:53.030768 134359831000832 array_metadata_store.py:203] [process=1][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_linen_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints/5/model_params/array_metadatas/process_1
I0422 12:45:53.031953 134360493696768 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/gbytes_per_sec: 4.467 KiB/s (total gbytes: 25.5 KiB) (time elapsed: 5 seconds) (per-host)
I0422 12:45:53.032100 134360493696768 base_pytree_checkpoint_handler.py:128] [process=1] /jax/checkpoint/write/gbytes_per_sec: 13.403 KiB/s (total gbytes: 76.6 KiB) (time elapsed: 5 seconds) (per-host)
I0422 12:45:53.032138 134360493696768 async_checkpointer.py:90] [process=1][thread=async_save] 4 Handler Commit operations completed. Time taken: 5.441172s.
I0422 12:46:03.920583 134360493696768 async_checkpointer.py:144] [process=1][thread=async_save] Background save thread done. Time taken: 16.329598s.
I0422 12:46:03.920892 134360367871744 async_checkpointer.py:273] [process=1][thread=save_finalize] Done with waiting for background save thread=async_save.
I0422 12:46:03.921009 134360367871744 async_checkpointer.py:283] [process=1][thread=save_finalize] No errors found in background save thread=async_save.
I0422 12:46:03.921053 134360367871744 checkpoint_manager.py:2103] [process=1][thread=save_finalize][step=5] CheckpointManager Save Finalize is syncing with other hosts...
I0422 12:46:03.922494 134360367871744 checkpoint_manager.py:2112] [process=1][thread=save_finalize][step=5] CheckpointManager Save Finalize is done on all hosts.
I0422 12:46:03.922679 134531539752768 checkpoint_manager.py:2006] [process=1][thread=MainThread][step=5][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=5.
I0422 12:46:03.922832 134531539752768 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134531539752768 count=1 at 0x7a3dc8077600>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a3ca80504a0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a3ca8060050>, _write_futures=[])
I0422 12:46:03.923264 134531539752768 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134531539752768 count=1 at 0x7a3dc8077600>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a3ca80504a0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a3ca8060050>, _write_futures=[])
I0422 12:46:03.923294 134531539752768 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134531539752768 count=1 at 0x7a3dc8077600>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a3ca80504a0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a3ca8060050>, _write_futures=[])
I0422 12:46:04.336791 134366441240320 grain_pool.py:542] Grain pool is exiting.
I0422 12:46:04.336890 134366441240320 grain_pool.py:547] Shutting down multiprocessing system.
I0422 12:46:06.425156 134366441240320 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Wed Apr 22 12:46:15 UTC 2026
EXIT_CODE=0
NNX  ·  7f06c99ac  ·  feat_nnx_post_train_fixes_20260422_123915  ·  full log
XPK Start: Wed Apr 22 12:57:23 UTC 2026
2026-04-22 12:57:51.780689: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0422 12:57:51.998446 133422147196736 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-22 12:58:01,039:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0422 12:58:01.039395 133422147196736 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-22 12:58:01,041:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-031k7-slice-job-0-0.mt-01-sft-smoke-031k7:8482
I0422 12:58:01.041793 133422147196736 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-031k7-slice-job-0-0.mt-01-sft-smoke-031k7:8482
I0422 12:58:02.788225 133422147196736 max_utils.py:284] Jax distributed system initialized!
I0422 12:58:09.040838 133422147196736 max_utils.py:800] System Information: Jax Version: 0.8.3
I0422 12:58:09.040946 133422147196736 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0422 12:58:09.040986 133422147196736 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0422 12:58:09.044381 133422147196736 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:58:09.138288 133422147196736 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:58:09.238423 133422147196736 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:58:10.349864 133422147196736 config.py:112] TensorFlow version 2.20.0 available.
I0422 12:58:10.350386 133422147196736 config.py:125] JAX version 0.8.3 available.
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0422 12:58:15.889910 133422147196736 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0422 12:58:15.890145 133422147196736 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0422 12:58:16.286330 133422147196736 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0422 12:58:16.286839 133422147196736 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79580df1c200>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0422 12:58:16.286889 133422147196736 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0422 12:58:16.286928 133422147196736 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79580df1c200>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0422 12:58:16.286975 133422147196736 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>}, handler_registry=None
I0422 12:58:16.287200 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287246 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287275 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287301 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64028e00>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287329 133422147196736 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64028e00>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64028e00>}).
I0422 12:58:16.287533 133422147196736 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
I0422 12:58:16.287588 133422147196736 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x793a6418d300> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0422 12:58:18.633499 133422147196736 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints
I0422 12:58:19.052539 133422147196736 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x793a64223980>
I0422 12:58:19.052906 133422147196736 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
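The mesh logged above uses 32 devices with `data=4` and `fsdp=8` and every other axis trivial; the axis sizes must multiply to the device count reported by `maxtext_utils.py` (`Num_devices: 32, shape (1, 4, 1, 8, ...)`). A minimal sketch of that arithmetic, with axis names copied from the log line (not an actual MaxText API call):

```python
import math

# Axis names and sizes as they appear in the "Training with mesh" log line.
axis_names = ("diloco", "data", "stage", "fsdp", "fsdp_transpose",
              "sequence", "context", "context_autoregressive", "tensor",
              "tensor_transpose", "tensor_sequence", "expert", "autoregressive")
mesh_shape = (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)

# The product of all mesh axis sizes must equal the number of devices.
num_devices = math.prod(mesh_shape)
assert num_devices == 32  # data=4 * fsdp=8; all other axes are 1

print(dict(zip(axis_names, mesh_shape)))
```

In JAX, the same constraint holds when building a `jax.sharding.Mesh`: the device array passed in must have exactly this shape, so a 4x8 data/FSDP layout consumes all 32 chips with no parallelism on the remaining axes.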
I0422 12:58:19.480984 133422147196736 peft_trainer.py:594] Compiled train_step cache size: 0
I0422 12:58:19.485069 133422147196736 metric_logger.py:301] number parameters: 0.000 billion
I0422 12:58:19.487508 133268733159168 grain_pool.py:367] Grain pool will use 1 processes.
I0422 12:58:19.514011 133268733159168 grain_pool.py:440] Grain pool will start child processes.
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
I0422 12:58:19.519266 133268733159168 grain_pool.py:448] Grain pool started all child processes.
2026-04-22 12:58:23.535922: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-22 12:58:23.581019: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-22 12:58:24.750117: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-22 12:58:28.617521: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 281, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 277, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 254, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 240, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 692, in train
    train_loss, aux, grad_norm = train_step(train_example)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: not enough values to unpack (expected 3, got 2)
I0422 12:58:37.634279 133268733159168 grain_pool.py:542] Grain pool is exiting.
I0422 12:58:37.634390 133268733159168 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 27 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 20 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Wed Apr 22 12:58:44 UTC 2026
EXIT_CODE=1
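The NNX run aborts at `peft_trainer.py:692` with `ValueError: not enough values to unpack (expected 3, got 2)`: the trainer unpacks three values (`train_loss, aux, grad_norm`) from `train_step`, but the step function on this code path evidently returns only two. A hedged sketch of the failure mode, with hypothetical function bodies (the real tunix return values are not shown in the log):

```python
# Illustrative only: reproduces the unpack-arity mismatch from the traceback.
# The actual tunix train_step signatures may differ.

def train_step_two_values(example):
    # Hypothetical NNX-path step returning (loss, aux) without a grad norm.
    return 0.0, {}

def train_step_three_values(example):
    # Hypothetical step matching what peft_trainer.py expects:
    # (loss, aux, grad_norm).
    return 0.0, {}, 1.0

# This is the crash seen in the log: three targets, two returned values.
try:
    train_loss, aux, grad_norm = train_step_two_values(None)
except ValueError as e:
    print(e)  # ValueError: not enough values to unpack (expected 3, got 2)

# Succeeds once the step's return arity matches the unpacking site.
train_loss, aux, grad_norm = train_step_three_values(None)
```

The fix therefore lives on whichever side is wrong for this branch: either the NNX `train_step` should also return a gradient norm, or the caller at `peft_trainer.py:692` should unpack only the values the NNX path produces.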