MaxView

Case: 01_sft_smoke

Metrics: Linen vs NNX  ·  feat/nnx-post-train-fixes

Metric        Linen d8cde296b   NNX d8cde296b   Diff (NNX − Linen)
Parameters    0.000 billion     0.000 billion
Final loss    5.8450            5.8300          -0.015
TFLOP/s       0.000             0.000           0
Tok/s         212.7             210.5           -2.167
Avg s/step    4.721             4.333           -0.388
Memory %      0.03              0.03            0
JAX           0.8.3             0.8.3
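The Linen column is consistent with the per-step `metric_logger` records in the Linen log below: Final loss (5.8450), TFLOP/s (0.000) and Tok/s (212.7) match the last `completed step` record (step 5), and Avg s/step (4.721) matches the mean of the five per-step `seconds` values. The following is a minimal sketch that reproduces those numbers under that assumption; `STEP_RE`, `summarize_steps` and `LINEN_STEPS` are hypothetical names, not part of MaxText or of this dashboard.

```python
import re
from statistics import mean

# Pattern for the metric_logger "completed step" records in the log below.
# The field order is copied from this log; other MaxText versions may differ.
STEP_RE = re.compile(
    r"completed step: (?P<step>\d+), seconds: (?P<seconds>[\d.]+), "
    r"TFLOP/s/device: (?P<tflops>[\d.]+), Tokens/s/device: (?P<tok_s>[\d.]+), "
    r"total_weights: (?P<weights>\d+), loss: (?P<loss>[\d.]+)"
)


def summarize_steps(log_text: str) -> dict:
    """Hypothetical helper: derive dashboard-style values from a training log.

    Final loss, TFLOP/s and Tok/s come from the last completed step;
    Avg s/step is the mean wall time over all steps, first step included.
    """
    steps = [m.groupdict() for m in STEP_RE.finditer(log_text)]
    if not steps:
        raise ValueError("no 'completed step' records found")
    last = steps[-1]
    return {
        "steps": len(steps),
        "final_loss": float(last["loss"]),
        "tflop_s": float(last["tflops"]),
        "tok_s": float(last["tok_s"]),
        "avg_s_per_step": round(mean(float(s["seconds"]) for s in steps), 3),
    }


# The five step records from the Linen log, with log prefixes trimmed.
LINEN_STEPS = """
completed step: 1, seconds: 18.617, TFLOP/s/device: 0.000, Tokens/s/device: 55.005, total_weights: 6826, loss: 6.124
completed step: 2, seconds: 0.146, TFLOP/s/device: 0.002, Tokens/s/device: 7035.691, total_weights: 4636, loss: 6.130
completed step: 3, seconds: 0.015, TFLOP/s/device: 0.014, Tokens/s/device: 66797.783, total_weights: 5886, loss: 5.572
completed step: 4, seconds: 0.011, TFLOP/s/device: 0.020, Tokens/s/device: 94734.379, total_weights: 4990, loss: 5.824
completed step: 5, seconds: 4.814, TFLOP/s/device: 0.000, Tokens/s/device: 212.701, total_weights: 4264, loss: 5.845
"""

if __name__ == "__main__":
    print(summarize_steps(LINEN_STEPS))
    # {'steps': 5, 'final_loss': 5.845, 'tflop_s': 0.0, 'tok_s': 212.701, 'avg_s_per_step': 4.721}
```

Because the step 1 record (18.617 s) most likely includes compilation and step 5 (4.814 s) is also slow, a five-step mean and a last-step Tok/s say little about steady-state throughput in a smoke test this short.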

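Diff = NNX value − Linen value. Green = NNX improved; red = NNX regressed. Lower is better for Final loss and Avg s/step and higher is better for Tok/s, so a negative diff is an improvement for the first two and a regression for Tok/s.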
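The improved/regressed verdict depends on each metric's direction, which the colors encode but plain text does not. A minimal sketch of that rule, assuming the per-metric directions in `LOWER_IS_BETTER` (an assumption inferred from the metric names, not taken from the dashboard); `diff_status` is a hypothetical helper.

```python
# Which direction counts as "better" for each numeric metric. This mapping is
# an assumption inferred from the metric names; the dashboard's own coloring
# rules are not given in the text.
LOWER_IS_BETTER = {
    "Final loss": True,
    "Avg s/step": True,
    "Memory %": True,
    "Tok/s": False,
    "TFLOP/s": False,
}


def diff_status(metric: str, linen: float, nnx: float, tol: float = 1e-9) -> tuple:
    """Hypothetical helper: return (NNX - Linen, verdict) for one table row."""
    diff = nnx - linen
    if abs(diff) <= tol:
        return diff, "unchanged"
    # A negative diff is good only when lower values are better.
    improved = diff < 0 if LOWER_IS_BETTER[metric] else diff > 0
    return diff, "improved" if improved else "regressed"


# Rows copied from the metrics table (rounded values as displayed).
ROWS = [
    ("Final loss", 5.8450, 5.8300),
    ("TFLOP/s", 0.000, 0.000),
    ("Tok/s", 212.7, 210.5),
    ("Avg s/step", 4.721, 4.333),
    ("Memory %", 0.03, 0.03),
]

for metric, linen, nnx in ROWS:
    diff, verdict = diff_status(metric, linen, nnx)
    print(f"{metric:<11} {diff:+.3f}  {verdict}")
```

Applied to the rounded table values, Tok/s comes out as -2.200 rather than the table's -2.167, so the dashboard presumably subtracts unrounded values.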

Linen  ·  d8cde296b  ·  feat_nnx_post_train_fixes_20260416_210550
2026-04-16 21:05:58.072298: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0416 21:05:58.607383 131141424139392 max_utils.py:238] Skipping jax distributed system due to skip_jax_distributed_system=True flag.
I0416 21:06:27.700041 131141424139392 max_utils.py:800] System Information: Jax Version: 0.8.3
I0416 21:06:27.700183 131141424139392 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0416 21:06:27.700224 131141424139392 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0416 21:06:27.704583 131141424139392 maxtext_utils.py:1687] Num_devices: 8, shape (1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0416 21:06:27.896583 131141424139392 maxtext_utils.py:1687] Num_devices: 8, shape (1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0416 21:06:28.928112 131141424139392 max_utils.py:194] tensorboardX not available; using no-op SummaryWriter.
I0416 21:06:28.953595 131141424139392 config.py:112] TensorFlow version 2.20.0 available.
I0416 21:06:28.953993 131141424139392 config.py:125] JAX version 0.8.3 available.
E0416 21:06:31.870134 131141424139392 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0416 21:06:32.236766 131141424139392 pytree_checkpoint_handler.py:592] save_device_host_concurrent_bytes=None
I0416 21:06:32.237290 131141424139392 base_pytree_checkpoint_handler.py:441] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x77451c39c920>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0416 21:06:32.237338 131141424139392 pytree_checkpoint_handler.py:592] save_device_host_concurrent_bytes=None
I0416 21:06:32.237371 131141424139392 base_pytree_checkpoint_handler.py:441] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x77451c39c920>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0416 21:06:32.237428 131141424139392 checkpoint_manager.py:708] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e7f3d99d0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e81b025d0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x773e7d7ff8c0>}, handler_registry=None
I0416 21:06:32.237707 131141424139392 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e7f3d99d0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0416 21:06:32.237746 131141424139392 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e81b025d0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0416 21:06:32.237767 131141424139392 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x773e7d7ff8c0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0416 21:06:32.237785 131141424139392 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x773e7d7fc230>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0416 21:06:32.237809 131141424139392 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e7f3d99d0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e7f3d99d0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e81b025d0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x773e81b025d0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x773e7d7ff8c0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x773e7d7ff8c0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x773e7d7fc230>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x773e7d7fc230>}).
I0416 21:06:32.238992 131141424139392 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.36
I0416 21:06:32.239040 131141424139392 async_checkpointer.py:192] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>.<lambda> at 0x773e7d743d80> timeout: 1200 secs and primary_host=0 for async checkpoint writes
I0416 21:06:34.580789 131141424139392 checkpoint_manager.py:564] Created directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints
I0416 21:06:36.940611 131141424139392 checkpoint_manager.py:1812] Found 0 checkpoint steps in gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints
I0416 21:06:36.940864 131141424139392 checkpoint_manager.py:929] [process=0][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False, lightweight_initialize=False), root_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x773e7d7fc080>
I0416 21:06:38.989873 131141424139392 metrics_logger.py:64] WandbBackend skipped: 'wandb' library not installed.
I0416 21:06:38.990192 131141424139392 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0416 21:06:39.431640 131141424139392 peft_trainer.py:600] Compiled train_step cache size: 0
[DECOUPLED NO-OP] gcs_storage: using stubs.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] workload_monitor: using stub.
[DECOUPLED NO-OP] vertex_tensorboard: using stub.

Training:   0%|          | 0/5 [00:00<?, ?step/s]
I0416 21:06:39.433747 131141424139392 metric_logger.py:289] number parameters: 0.000 billion
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
2026-04-16 21:06:42.610804: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-16 21:06:42.649431: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-16 21:06:43.658852: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-16 21:06:46.061876: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0416 21:06:56.713685 131141424139392 checkpoint_manager.py:2009] [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0416 21:06:56.713886 131141424139392 checkpoint_manager.py:1512] [process=0] Saving checkpoint at step 1
I0416 21:06:56.713952 131141424139392 event_tracking.py:70] [process=0] [async] Started save checkpoint @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1.
I0416 21:06:56.816035 131141424139392 signaling_client.py:373] Using ThreadSafeKeyValueSignalingClient
I0416 21:06:56.834959 131141424139392 jax_array_handlers.py:360] Scheduling D2H of 22 prioritized jax.Array.
I0416 21:06:56.835082 131141424139392 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:06:56.906855 131031418013248 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1
I0416 21:06:57.625192 131031407527488 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params
I0416 21:06:57.631219 131031407527488 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state
I0416 21:06:57.694640 131141424139392 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.860302s
I0416 21:06:57.695139 131141424139392 jax_array_handlers.py:360] Scheduling D2H of 52 prioritized jax.Array.
I0416 21:06:57.695184 131141424139392 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:06:57.727409 131141424139392 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.032664s
I0416 21:06:57.727786 131141424139392 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 225.179 KiB/s (total gbytes: 205.0 KiB) (time elapsed: 0.9102644920349121 s) (per-host)
I0416 21:06:57.727900 131141424139392 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.910393s (batch_requests_ready=0.002069s, total_serialization_initiated=0.907947s, others=0.000377s)
I0416 21:06:57.728331 131141424139392 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 680.282 KiB/s (total gbytes: 614.8 KiB) (time elapsed: 0.9037835597991943 s) (per-host)
I0416 21:06:57.728412 131141424139392 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.903876s (batch_requests_ready=0.002553s, total_serialization_initiated=0.900864s, others=0.000459s)
I0416 21:06:57.728477 131141424139392 composite_checkpoint_handler.py:715] [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.911857s (all_items=0.000020s, per_item={'model_params': '0.00001621', 'optimizer_state': '0.00000405'}, temp_paths=0.911836)
I0416 21:06:57.729208 131141424139392 event_tracking.py:125] [process=0] [async] Finished blocking save in 1.02 seconds. Continuing save @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1.
I0416 21:06:57.729418 131031313155648 async_checkpointer.py:76] [process=0][thread=async_save] Background save thread started. Deadline for this save operation is 2026-04-16 21:26:57.729383
I0416 21:06:57.729664 131141424139392 checkpoint_manager.py:1560] [process=0][thread=MainThread][step=1] Starting CheckpointManager Save Finalize thread=save_finalize
I0416 21:06:57.730023 131141424139392 standard_logger.py:34] {'step': 1, 'event_type': 'save', 'directory': 'gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776373616.713666, 'wait_for_prev_duration_secs': 9.775161743164062e-05, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776373616.7139091, 'checkpointer_blocking_duration_secs': 1.0156042575836182, 'get_old_steps_start_time': 1776373617.7295296, 'get_old_steps_duration_secs': 9.703636169433594e-05, 'checkpoint_manager_blocking_start_time': 1776373616.7135963, 'checkpoint_manager_blocking_duration_secs': 1.0164036750793457}
I0416 21:06:57.730171 131141424139392 profiler.py:85] Starting JAX profiler at step 1.
I0416 21:06:57.858526 131031376070208 checkpoint.py:188] Wrote Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1776373617531409892, 'commit_timestamp_nsecs': None, 'custom_metadata': {}}, json={"item_handlers": null, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1776373617531409892, "commit_timestamp_nsecs": null, "custom_metadata": {}} to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/_CHECKPOINT_METADATA
I0416 21:06:57.859720 131031428499008 async_checkpointer.py:280] [process=0][thread=save_finalize] Waiting for background save thread=async_save.
I0416 21:06:58.049573 131141424139392 peft_trainer.py:485] Train step 1 training loss: 6.124308  - training perplexity: 456.828308

Training:   0%|          | 0/5 [00:18<?, ?step/s, _train_loss=6.12, _train_perplexity=457, _train_steps_per_sec=0.058]
Training:  20%|██        | 1/5 [00:18<01:14, 18.62s/step, _train_loss=6.12, _train_perplexity=457, _train_steps_per_sec=0.058]
I0416 21:06:58.050524 131141424139392 max_utils.py:750]
Memstats: After params initialized:
I0416 21:06:58.050608 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_0(process=0,(0,0,0,0))
I0416 21:06:58.050659 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_1(process=0,(1,0,0,0))
I0416 21:06:58.050703 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_2(process=0,(0,1,0,0))
I0416 21:06:58.050742 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_3(process=0,(1,1,0,0))
I0416 21:06:58.050779 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_4(process=0,(0,2,0,0))
I0416 21:06:58.050815 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_5(process=0,(1,2,0,0))
I0416 21:06:58.050850 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_6(process=0,(0,3,0,0))
I0416 21:06:58.050885 131141424139392 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_7(process=0,(1,3,0,0))
I0416 21:06:58.184929 131141424139392 metric_logger.py:185] completed step: 1, seconds: 18.617, TFLOP/s/device: 0.000, Tokens/s/device: 55.005, total_weights: 6826, loss: 6.124
I0416 21:06:58.195772 131141424139392 peft_trainer.py:485] Train step 2 training loss: 6.129846  - training perplexity: 459.365448

Training:  20%|██        | 1/5 [00:18<01:14, 18.62s/step, _train_loss=6.13, _train_perplexity=458, _train_steps_per_sec=0.404]
Training:  40%|████      | 2/5 [00:18<00:23,  7.75s/step, _train_loss=6.13, _train_perplexity=458, _train_steps_per_sec=0.404]
I0416 21:06:58.197339 131141424139392 metric_logger.py:185] completed step: 2, seconds: 0.146, TFLOP/s/device: 0.002, Tokens/s/device: 7035.691, total_weights: 4636, loss: 6.130
I0416 21:06:58.211188 131141424139392 peft_trainer.py:485] Train step 3 training loss: 5.572245  - training perplexity: 263.023956

Training:  40%|████      | 2/5 [00:18<00:23,  7.75s/step, _train_loss=5.94, _train_perplexity=381, _train_steps_per_sec=2.5]
I0416 21:06:58.212568 131141424139392 metric_logger.py:185] completed step: 3, seconds: 0.015, TFLOP/s/device: 0.014, Tokens/s/device: 66797.783, total_weights: 5886, loss: 5.572
I0416 21:06:58.222007 131141424139392 peft_trainer.py:485] Train step 4 training loss: 5.824473  - training perplexity: 338.482697

Training:  60%|██████    | 3/5 [00:18<00:15,  7.75s/step, _train_loss=5.91, _train_perplexity=370, _train_steps_per_sec=18.1]
I0416 21:06:58.223372 131141424139392 metric_logger.py:185] completed step: 4, seconds: 0.011, TFLOP/s/device: 0.020, Tokens/s/device: 94734.379, total_weights: 4990, loss: 5.824
I0416 21:06:58.223663 131141424139392 profiler.py:113] Stopping JAX profiler at step 5.
I0416 21:06:58.316122 4156963 google_auth_provider.cc:149] Using credentials at ~/.config/gcloud/application_default_credentials.json
I0416 21:06:58.316198 4156963 google_auth_provider.cc:156] Using OAuth2 AuthProvider
I0416 21:06:59.120499 131031365584448 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params/array_metadatas/process_0
I0416 21:06:59.126204 131031344612928 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state/array_metadatas/process_0
I0416 21:07:00.805595 131031334127168 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.493249s (commit=2.042352s, array_metadata_write=0.450896s)
I0416 21:07:00.806669 131031313155648 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 51.383 KiB/s (total gbytes: 205.0 KiB) (time elapsed: 3.9891133308410645 s) (per-host)
I0416 21:07:00.875975 131031323641408 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.555255s (commit=2.130029s, array_metadata_write=0.425225s)
I0416 21:07:00.877107 131031313155648 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 151.715 KiB/s (total gbytes: 614.8 KiB) (time elapsed: 4.052513837814331 s) (per-host)
I0416 21:07:00.877363 131031313155648 async_checkpointer.py:90] [process=0][thread=async_save] 4 Handler Commit operations completed. Time taken: 3.147656s.
I0416 21:07:01.076204 131031313155648 checkpoint.py:228] Read Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1776373617531409892, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} from gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/_CHECKPOINT_METADATA
I0416 21:07:01.259173 131031313155648 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:07:01.482818 131031376070208 checkpoint.py:247] Updated Metadata={'item_handlers': {'model_params': 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler', 'optimizer_state': 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler'}, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1776373617531409892, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/_CHECKPOINT_METADATA
I0416 21:07:01.723141 131031313155648 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.606190s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params
I0416 21:07:01.723968 131031313155648 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params
I0416 21:07:02.147941 131031313155648 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:07:02.599188 131031313155648 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.596655s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state
I0416 21:07:02.600133 131031313155648 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state
I0416 21:07:02.893655 131031313155648 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1
I0416 21:07:03.035638 131141424139392 utils.py:86] Train loop finished in: 23.6012 seconds
I0416 21:07:03.036417 131141424139392 peft_trainer.py:485] Train step 5 training loss: 5.844853  - training perplexity: 345.451904

Training:  80%|████████  | 4/5 [00:23<00:07,  7.75s/step, _train_loss=5.9, _train_perplexity=365, _train_steps_per_sec=32.8] 
Training: 100%|██████████| 5/5 [00:23<00:00,  3.36s/step, _train_loss=5.9, _train_perplexity=365, _train_steps_per_sec=32.8]
I0416 21:07:03.038115 131141424139392 metric_logger.py:185] completed step: 5, seconds: 4.814, TFLOP/s/device: 0.000, Tokens/s/device: 212.701, total_weights: 4264, loss: 5.845
I0416 21:07:03.040626 131141424139392 checkpoint_manager.py:2020] [process=0][thread=MainThread][step=1][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0416 21:07:03.635948 131031313155648 atomicity.py:847] [process=0][thread=async_save] Finished saving checkpoint (finalized tmp dir) to `gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1`.
I0416 21:07:03.636744 131031313155648 event_tracking.py:138] [process=0] [async] Finished save (blocking + background) in 6.92 seconds @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1
I0416 21:07:03.636819 131031313155648 async_checkpointer.py:160] [process=0][thread=async_save] Background save thread done. Time taken: 5.907113s.
I0416 21:07:03.636979 131031428499008 async_checkpointer.py:288] [process=0][thread=save_finalize] Done with waiting for background save thread=async_save.
I0416 21:07:03.637105 131031428499008 async_checkpointer.py:298] [process=0][thread=save_finalize] No errors found in background save thread=async_save.
I0416 21:07:03.637160 131031428499008 checkpoint_manager.py:2137] [process=0][thread=save_finalize][step=1] CheckpointManager Save Finalize is syncing with other hosts...
I0416 21:07:03.637205 131031428499008 checkpoint_manager.py:2146] [process=0][thread=save_finalize][step=1] CheckpointManager Save Finalize is done on all hosts.
I0416 21:07:03.637351 131141424139392 checkpoint_manager.py:2032] [process=0][thread=MainThread][step=1][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=1.
I0416 21:07:03.637683 131141424139392 checkpoint_manager.py:1512] [process=0] Saving checkpoint at step 5
I0416 21:07:03.637762 131141424139392 event_tracking.py:70] [process=0] [async] Started save checkpoint @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5.
I0416 21:07:03.740402 131141424139392 jax_array_handlers.py:360] Scheduling D2H of 22 prioritized jax.Array.
I0416 21:07:03.740522 131141424139392 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:07:03.752720 131141424139392 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.013043s
I0416 21:07:03.753026 131141424139392 jax_array_handlers.py:360] Scheduling D2H of 52 prioritized jax.Array.
I0416 21:07:03.753084 131141424139392 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:07:03.784501 131141424139392 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.031715s
I0416 21:07:03.784797 131141424139392 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 3.317 MiB/s (total gbytes: 205.0 KiB) (time elapsed: 0.06035447120666504 s) (per-host)
I0416 21:07:03.784916 131141424139392 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.060488s (batch_requests_ready=0.001481s, total_serialization_initiated=0.058696s, others=0.000311s)
I0416 21:07:03.785264 131141424139392 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 10.898 MiB/s (total gbytes: 614.8 KiB) (time elapsed: 0.05509686470031738 s) (per-host)
I0416 21:07:03.785342 131141424139392 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.055185s (batch_requests_ready=0.002453s, total_serialization_initiated=0.052355s, others=0.000377s)
I0416 21:07:03.785403 131141424139392 composite_checkpoint_handler.py:715] [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.061501s (all_items=0.000013s, per_item={'model_params': '0.00001025', 'optimizer_state': '0.00000262'}, temp_paths=0.061488)
I0416 21:07:03.786124 131141424139392 event_tracking.py:125] [process=0] [async] Finished blocking save in 0.15 seconds. Continuing save @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5.
I0416 21:07:03.786350 131031227172416 async_checkpointer.py:76] [process=0][thread=async_save] Background save thread started. Deadline for this save operation is 2026-04-16 21:27:03.786315
I0416 21:07:03.786589 131141424139392 checkpoint_manager.py:1560] [process=0][thread=MainThread][step=5] Starting CheckpointManager Save Finalize thread=save_finalize
I0416 21:07:03.786915 131031153772096 async_checkpointer.py:280] [process=0][thread=save_finalize] Waiting for background save thread=async_save.
I0416 21:07:03.787040 131141424139392 standard_logger.py:34] {'step': 5, 'event_type': 'save', 'directory': 'gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776373623.0406003, 'wait_for_prev_duration_secs': 0.59686279296875, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776373623.6377094, 'checkpointer_blocking_duration_secs': 0.14873147010803223, 'get_old_steps_start_time': 1776373623.7864592, 'get_old_steps_duration_secs': 9.417533874511719e-05, 'checkpoint_manager_blocking_start_time': 1776373623.0405579, 'checkpoint_manager_blocking_duration_secs': 0.7464585304260254}
I0416 21:07:03.787205 131141424139392 checkpoint_manager.py:2020] [process=0][thread=MainThread][step=5][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0416 21:07:03.816087 131031428499008 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5
I0416 21:07:04.514771 131031323641408 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params
I0416 21:07:04.518631 131031323641408 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state
I0416 21:07:05.790271 131031365584448 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 22 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params/array_metadatas/process_0
I0416 21:07:05.792644 131031334127168 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state/array_metadatas/process_0
I0416 21:07:07.334915 131031281698368 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.326957s (commit=1.907889s, array_metadata_write=0.419068s)
I0416 21:07:07.336055 131031227172416 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 56.755 KiB/s (total gbytes: 205.0 KiB) (time elapsed: 3.611562490463257 s) (per-host)
I0416 21:07:07.369688 131031237658176 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.362172s (commit=1.932085s, array_metadata_write=0.430088s)
I0416 21:07:07.676827 131031227172416 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 155.786 KiB/s (total gbytes: 614.8 KiB) (time elapsed: 3.9466214179992676 s) (per-host)
I0416 21:07:07.677260 131031227172416 async_checkpointer.py:90] [process=0][thread=async_save] 4 Handler Commit operations completed. Time taken: 3.890627s.
I0416 21:07:08.063488 131031227172416 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:07:08.507838 131031227172416 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.591783s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params
I0416 21:07:08.508700 131031227172416 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params
I0416 21:07:08.929417 131031227172416 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:07:09.405541 131031227172416 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.621231s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state
I0416 21:07:09.406440 131031227172416 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state
I0416 21:07:09.673276 131031227172416 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5
I0416 21:07:10.382983 131031227172416 atomicity.py:847] [process=0][thread=async_save] Finished saving checkpoint (finalized tmp dir) to `gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5`.
I0416 21:07:10.383759 131031227172416 event_tracking.py:138] [process=0] [async] Finished save (blocking + background) in 6.75 seconds @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_linen_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5
I0416 21:07:10.383828 131031227172416 async_checkpointer.py:160] [process=0][thread=async_save] Background save thread done. Time taken: 6.597198s.
I0416 21:07:10.383990 131031153772096 async_checkpointer.py:288] [process=0][thread=save_finalize] Done with waiting for background save thread=async_save.
I0416 21:07:10.384119 131031153772096 async_checkpointer.py:298] [process=0][thread=save_finalize] No errors found in background save thread=async_save.
I0416 21:07:10.384172 131031153772096 checkpoint_manager.py:2137] [process=0][thread=save_finalize][step=5] CheckpointManager Save Finalize is syncing with other hosts...
I0416 21:07:10.384212 131031153772096 checkpoint_manager.py:2146] [process=0][thread=save_finalize][step=5] CheckpointManager Save Finalize is done on all hosts.
I0416 21:07:10.384423 131141424139392 checkpoint_manager.py:2032] [process=0][thread=MainThread][step=5][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=5.
I0416 21:07:10.384559 131141424139392 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=131141424139392 count=1 at 0x773e82bb7440>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x773e7d779010>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x773e7d779a60>, _write_futures=[])
I0416 21:07:10.384980 131141424139392 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=131141424139392 count=1 at 0x773e82bb7440>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x773e7d779010>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x773e7d779a60>, _write_futures=[])
I0416 21:07:10.385017 131141424139392 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=131141424139392 count=1 at 0x773e82bb7440>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x773e7d779010>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x773e7d779a60>, _write_futures=[])

Training: 100%|██████████| 5/5 [00:32<00:00,  6.41s/step, _train_loss=5.9, _train_perplexity=365, _train_steps_per_sec=32.8]
[DECOUPLED NO-OP] gcs_storage: using stubs.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] workload_monitor: using stub.
[DECOUPLED NO-OP] vertex_tensorboard: using stub.
~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
NNX  ·  d8cde296b  ·  feat_nnx_post_train_fixes_20260416_210550  ·  full log
2026-04-16 21:11:20.865701: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0416 21:11:21.338263 125630730390656 max_utils.py:238] Skipping jax distributed system due to skip_jax_distributed_system=True flag.
I0416 21:12:18.422026 125630730390656 max_utils.py:800] System Information: Jax Version: 0.8.3
I0416 21:12:18.422156 125630730390656 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0416 21:12:18.422191 125630730390656 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0416 21:12:18.424954 125630730390656 maxtext_utils.py:1687] Num_devices: 8, shape (1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0416 21:12:18.506169 125630730390656 maxtext_utils.py:1687] Num_devices: 8, shape (1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0416 21:12:18.589121 125630730390656 maxtext_utils.py:1687] Num_devices: 8, shape (1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0416 21:12:19.550992 125630730390656 max_utils.py:194] tensorboardX not available; using no-op SummaryWriter.
I0416 21:12:19.572965 125630730390656 config.py:112] TensorFlow version 2.20.0 available.
I0416 21:12:19.573403 125630730390656 config.py:125] JAX version 0.8.3 available.
E0416 21:12:21.213800 125630730390656 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0416 21:12:21.557279 125630730390656 pytree_checkpoint_handler.py:592] save_device_host_concurrent_bytes=None
I0416 21:12:21.557686 125630730390656 base_pytree_checkpoint_handler.py:441] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x72420cdf0170>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0416 21:12:21.557728 125630730390656 pytree_checkpoint_handler.py:592] save_device_host_concurrent_bytes=None
I0416 21:12:21.557760 125630730390656 base_pytree_checkpoint_handler.py:441] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x72420cdf0170>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0416 21:12:21.557800 125630730390656 checkpoint_manager.py:708] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b70d01820>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b75992300>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x723b72502540>}, handler_registry=None
I0416 21:12:21.558089 125630730390656 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b70d01820>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0416 21:12:21.558130 125630730390656 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b75992300>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0416 21:12:21.558153 125630730390656 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x723b72502540>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0416 21:12:21.558172 125630730390656 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x723b70d78b90>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0416 21:12:21.558194 125630730390656 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b70d01820>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b70d01820>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b75992300>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x723b75992300>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x723b72502540>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x723b72502540>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x723b70d78b90>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x723b70d78b90>}).
I0416 21:12:21.558350 125630730390656 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.36
I0416 21:12:21.558386 125630730390656 async_checkpointer.py:192] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>.<lambda> at 0x723b70d9a980> timeout: 1200 secs and primary_host=0 for async checkpoint writes
I0416 21:12:23.530811 125630730390656 checkpoint_manager.py:564] Created directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints
I0416 21:12:25.884749 125630730390656 checkpoint_manager.py:1812] Found 0 checkpoint steps in gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints
I0416 21:12:25.885014 125630730390656 checkpoint_manager.py:929] [process=0][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False, lightweight_initialize=False), root_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x723b70d34500>
I0416 21:12:27.919018 125630730390656 metrics_logger.py:64] WandbBackend skipped: 'wandb' library not installed.
I0416 21:12:27.919302 125630730390656 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0416 21:12:28.343440 125630730390656 peft_trainer.py:600] Compiled train_step cache size: 0
[DECOUPLED NO-OP] gcs_storage: using stubs.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] workload_monitor: using stub.
[DECOUPLED NO-OP] vertex_tensorboard: using stub.

Training:   0%|          | 0/5 [00:00<?, ?step/s]
I0416 21:12:28.346850 125630730390656 metric_logger.py:289] number parameters: 0.000 billion
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
2026-04-16 21:12:31.523597: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-16 21:12:31.559899: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-16 21:12:32.586056: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-16 21:12:35.038875: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0416 21:12:43.622141 125630730390656 checkpoint_manager.py:2009] [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0416 21:12:43.622338 125630730390656 checkpoint_manager.py:1512] [process=0] Saving checkpoint at step 1
I0416 21:12:43.622401 125630730390656 event_tracking.py:70] [process=0] [async] Started save checkpoint @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1.
I0416 21:12:43.721835 125630730390656 signaling_client.py:373] Using ThreadSafeKeyValueSignalingClient
I0416 21:12:43.745375 125630730390656 jax_array_handlers.py:360] Scheduling D2H of 46 prioritized jax.Array.
I0416 21:12:43.745437 125630730390656 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:12:43.820411 125520742188608 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1
I0416 21:12:44.522585 125520731702848 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params
I0416 21:12:44.530431 125520731702848 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state
I0416 21:12:44.601477 125630730390656 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.857357s
I0416 21:12:44.601936 125630730390656 jax_array_handlers.py:360] Scheduling D2H of 52 prioritized jax.Array.
I0416 21:12:44.601984 125630730390656 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:12:44.632354 125630730390656 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.030776s
I0416 21:12:44.632738 125630730390656 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 225.446 KiB/s (total gbytes: 205.1 KiB) (time elapsed: 0.9098124504089355 s) (per-host)
I0416 21:12:44.632850 125630730390656 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.909937s (batch_requests_ready=0.002929s, total_serialization_initiated=0.906625s, others=0.000383s)
I0416 21:12:44.633179 125630730390656 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 685.026 KiB/s (total gbytes: 614.8 KiB) (time elapsed: 0.8975248336791992 s) (per-host)
I0416 21:12:44.633291 125630730390656 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.897647s (batch_requests_ready=0.002311s, total_serialization_initiated=0.894947s, others=0.000388s)
I0416 21:12:44.633352 125630730390656 composite_checkpoint_handler.py:715] [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.910980s (all_items=0.000020s, per_item={'model_params': '0.00001621', 'optimizer_state': '0.00000381'}, temp_paths=0.910960)
I0416 21:12:44.633954 125630730390656 event_tracking.py:125] [process=0] [async] Finished blocking save in 1.01 seconds. Continuing save @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1.
I0416 21:12:44.634176 125520622650944 async_checkpointer.py:76] [process=0][thread=async_save] Background save thread started. Deadline for this save operation is 2026-04-16 21:32:44.634141
I0416 21:12:44.634407 125630730390656 checkpoint_manager.py:1560] [process=0][thread=MainThread][step=1] Starting CheckpointManager Save Finalize thread=save_finalize
I0416 21:12:44.634724 125630730390656 standard_logger.py:34] {'step': 1, 'event_type': 'save', 'directory': 'gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776373963.622122, 'wait_for_prev_duration_secs': 9.250640869140625e-05, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776373963.6223605, 'checkpointer_blocking_duration_secs': 1.0119056701660156, 'get_old_steps_start_time': 1776373964.6342828, 'get_old_steps_duration_secs': 8.678436279296875e-05, 'checkpoint_manager_blocking_start_time': 1776373963.6220398, 'checkpoint_manager_blocking_duration_secs': 1.0126631259918213}
I0416 21:12:44.634864 125630730390656 profiler.py:85] Starting JAX profiler at step 1.
I0416 21:12:44.889782 125520696051264 checkpoint.py:188] Wrote Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1776373964428003108, 'commit_timestamp_nsecs': None, 'custom_metadata': {}}, json={"item_handlers": null, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1776373964428003108, "commit_timestamp_nsecs": null, "custom_metadata": {}} to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/_CHECKPOINT_METADATA
I0416 21:12:44.890212 125520752674368 async_checkpointer.py:280] [process=0][thread=save_finalize] Waiting for background save thread=async_save.
I0416 21:12:44.973328 125630730390656 peft_trainer.py:485] Train step 1 training loss: 6.004404  - training perplexity: 405.209259

Training:   0%|          | 0/5 [00:16<?, ?step/s, _train_loss=6, _train_perplexity=405, _train_steps_per_sec=0.065]
Training:  20%|██        | 1/5 [00:16<01:06, 16.63s/step, _train_loss=6, _train_perplexity=405, _train_steps_per_sec=0.065]
I0416 21:12:44.974439 125630730390656 max_utils.py:750] 
Memstats: After params initialized:
I0416 21:12:44.974524 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_0(process=0,(0,0,0,0))
I0416 21:12:44.974582 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_1(process=0,(1,0,0,0))
I0416 21:12:44.974632 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_2(process=0,(0,1,0,0))
I0416 21:12:44.974676 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_3(process=0,(1,1,0,0))
I0416 21:12:44.974719 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_4(process=0,(0,2,0,0))
I0416 21:12:44.974761 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_5(process=0,(1,2,0,0))
I0416 21:12:44.974807 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_6(process=0,(0,3,0,0))
I0416 21:12:44.974851 125630730390656 max_utils.py:756] 	Using (GB) 0.01 / 31.25 (0.032000%) on TPU_7(process=0,(1,3,0,0))
I0416 21:12:45.100788 125630730390656 metric_logger.py:185] completed step: 1, seconds: 16.628, TFLOP/s/device: 0.000, Tokens/s/device: 61.585, total_weights: 6826, loss: 6.004
I0416 21:12:45.113594 125630730390656 peft_trainer.py:485] Train step 2 training loss: 6.137612  - training perplexity: 462.946655

Training:  20%|██        | 1/5 [00:16<01:06, 16.63s/step, _train_loss=6.07, _train_perplexity=433, _train_steps_per_sec=0.403]
Training:  40%|████      | 2/5 [00:16<00:20,  6.93s/step, _train_loss=6.07, _train_perplexity=433, _train_steps_per_sec=0.403]
I0416 21:12:45.115060 125630730390656 metric_logger.py:185] completed step: 2, seconds: 0.140, TFLOP/s/device: 0.002, Tokens/s/device: 7336.279, total_weights: 4636, loss: 6.138
I0416 21:12:45.135096 125630730390656 peft_trainer.py:485] Train step 3 training loss: 5.638189  - training perplexity: 280.953430

Training:  40%|████      | 2/5 [00:16<00:20,  6.93s/step, _train_loss=5.93, _train_perplexity=375, _train_steps_per_sec=2.59]
I0416 21:12:45.136373 125630730390656 metric_logger.py:185] completed step: 3, seconds: 0.021, TFLOP/s/device: 0.010, Tokens/s/device: 48230.951, total_weights: 5886, loss: 5.638
I0416 21:12:45.147703 125630730390656 peft_trainer.py:485] Train step 4 training loss: 5.768358  - training perplexity: 320.011780

Training:  60%|██████    | 3/5 [00:16<00:13,  6.93s/step, _train_loss=5.89, _train_perplexity=360, _train_steps_per_sec=13.6]
I0416 21:12:45.148957 125630730390656 metric_logger.py:185] completed step: 4, seconds: 0.013, TFLOP/s/device: 0.017, Tokens/s/device: 80794.500, total_weights: 4990, loss: 5.768
I0416 21:12:45.149272 125630730390656 profiler.py:113] Stopping JAX profiler at step 5.
I0416 21:12:45.220222 4168901 google_auth_provider.cc:149] Using credentials at ~/.config/gcloud/application_default_credentials.json
I0416 21:12:45.220320 4168901 google_auth_provider.cc:156] Using OAuth2 AuthProvider
I0416 21:12:45.998951 125520654108224 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state/array_metadatas/process_0
I0416 21:12:45.999851 125520675079744 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 46 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params/array_metadatas/process_0
I0416 21:12:47.702682 125520633136704 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.480485s (commit=2.027834s, array_metadata_write=0.452651s)
I0416 21:12:47.715854 125520643622464 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.487516s (commit=2.049904s, array_metadata_write=0.437612s)
I0416 21:12:47.716807 125520622650944 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 51.357 KiB/s (total gbytes: 205.1 KiB) (time elapsed: 3.993842601776123 s) (per-host)
I0416 21:12:47.717123 125520622650944 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 154.422 KiB/s (total gbytes: 614.8 KiB) (time elapsed: 3.9814741611480713 s) (per-host)
I0416 21:12:47.717203 125520622650944 async_checkpointer.py:90] [process=0][thread=async_save] 4 Handler Commit operations completed. Time taken: 3.082754s.
I0416 21:12:47.899681 125520622650944 checkpoint.py:228] Read Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1776373964428003108, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} from gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/_CHECKPOINT_METADATA
I0416 21:12:48.063940 125520622650944 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:12:48.277314 125520696051264 checkpoint.py:247] Updated Metadata={'item_handlers': {'model_params': 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler', 'optimizer_state': 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler'}, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1776373964428003108, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/_CHECKPOINT_METADATA
I0416 21:12:48.492764 125520622650944 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.556041s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params
I0416 21:12:48.493509 125520622650944 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/model_params
I0416 21:12:48.904649 125520622650944 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:12:49.361347 125520622650944 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.584717s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state
I0416 21:12:49.362140 125520622650944 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1/optimizer_state
I0416 21:12:49.630800 125520622650944 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1
I0416 21:12:50.011156 125630730390656 utils.py:86] Train loop finished in: 21.6636 seconds
I0416 21:12:50.011788 125630730390656 peft_trainer.py:485] Train step 5 training loss: 5.830140  - training perplexity: 340.406403

Training:  80%|████████  | 4/5 [00:21<00:06,  6.93s/step, _train_loss=5.88, _train_perplexity=356, _train_steps_per_sec=26.6]
Training: 100%|██████████| 5/5 [00:21<00:00,  3.14s/step, _train_loss=5.88, _train_perplexity=356, _train_steps_per_sec=26.6]
I0416 21:12:50.013059 125630730390656 metric_logger.py:185] completed step: 5, seconds: 4.864, TFLOP/s/device: 0.000, Tokens/s/device: 210.534, total_weights: 4264, loss: 5.830
I0416 21:12:50.016418 125630730390656 checkpoint_manager.py:2020] [process=0][thread=MainThread][step=1][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0416 21:12:50.396630 125520622650944 atomicity.py:847] [process=0][thread=async_save] Finished saving checkpoint (finalized tmp dir) to `gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1`.
I0416 21:12:50.397328 125520622650944 event_tracking.py:138] [process=0] [async] Finished save (blocking + background) in 6.77 seconds @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/1
I0416 21:12:50.397404 125520622650944 async_checkpointer.py:160] [process=0][thread=async_save] Background save thread done. Time taken: 5.762956s.
I0416 21:12:50.397569 125520752674368 async_checkpointer.py:288] [process=0][thread=save_finalize] Done with waiting for background save thread=async_save.
I0416 21:12:50.397688 125520752674368 async_checkpointer.py:298] [process=0][thread=save_finalize] No errors found in background save thread=async_save.
I0416 21:12:50.397747 125520752674368 checkpoint_manager.py:2137] [process=0][thread=save_finalize][step=1] CheckpointManager Save Finalize is syncing with other hosts...
I0416 21:12:50.397792 125520752674368 checkpoint_manager.py:2146] [process=0][thread=save_finalize][step=1] CheckpointManager Save Finalize is done on all hosts.
I0416 21:12:50.397879 125630730390656 checkpoint_manager.py:2032] [process=0][thread=MainThread][step=1][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=1.
I0416 21:12:50.398112 125630730390656 checkpoint_manager.py:1512] [process=0] Saving checkpoint at step 5
I0416 21:12:50.398177 125630730390656 event_tracking.py:70] [process=0] [async] Started save checkpoint @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5.
I0416 21:12:50.495798 125630730390656 jax_array_handlers.py:360] Scheduling D2H of 46 prioritized jax.Array.
I0416 21:12:50.495931 125630730390656 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:12:50.509776 125630730390656 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.015279s
I0416 21:12:50.510094 125630730390656 jax_array_handlers.py:360] Scheduling D2H of 52 prioritized jax.Array.
I0416 21:12:50.510136 125630730390656 replica_slices.py:424] Transferring arrays to host memory with options: use_replica_parallel=True, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_pinned_host_transfer=False
I0416 21:12:50.540337 125630730390656 base_pytree_checkpoint_handler.py:154] [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.jax_array_handlers.ArrayHandler".serialize. Time taken: 0.030495s
I0416 21:12:50.540679 125630730390656 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 3.024 MiB/s (total gbytes: 205.1 KiB) (time elapsed: 0.06624436378479004 s) (per-host)
I0416 21:12:50.540791 125630730390656 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.066370s (batch_requests_ready=0.002296s, total_serialization_initiated=0.063724s, others=0.000350s)
I0416 21:12:50.541223 125630730390656 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/blocking_gbytes_per_sec: 10.814 MiB/s (total gbytes: 614.8 KiB) (time elapsed: 0.05552482604980469 s) (per-host)
I0416 21:12:50.541297 125630730390656 base_pytree_checkpoint_handler.py:768] [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.055608s (batch_requests_ready=0.002334s, total_serialization_initiated=0.052819s, others=0.000455s)
I0416 21:12:50.541356 125630730390656 composite_checkpoint_handler.py:715] [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.067413s (all_items=0.000013s, per_item={'model_params': '0.00001025', 'optimizer_state': '0.00000262'}, temp_paths=0.067400)
I0416 21:12:50.541937 125630730390656 event_tracking.py:125] [process=0] [async] Finished blocking save in 0.14 seconds. Continuing save @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5.
I0416 21:12:50.542116 125520486336064 async_checkpointer.py:76] [process=0][thread=async_save] Background save thread started. Deadline for this save operation is 2026-04-16 21:32:50.542083
I0416 21:12:50.542349 125630730390656 checkpoint_manager.py:1560] [process=0][thread=MainThread][step=5] Starting CheckpointManager Save Finalize thread=save_finalize
I0416 21:12:50.542663 125520475850304 async_checkpointer.py:280] [process=0][thread=save_finalize] Waiting for background save thread=async_save.
I0416 21:12:50.542773 125630730390656 standard_logger.py:34] {'step': 5, 'event_type': 'save', 'directory': 'gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': False, 'wait_for_prev_start_time': 1776373970.0163968, 'wait_for_prev_duration_secs': 0.3815345764160156, 'time_between_consecutive_saves_sec': None, 'checkpointer_blocking_start_time': 1776373970.3981366, 'checkpointer_blocking_duration_secs': 0.14407014846801758, 'get_old_steps_start_time': 1776373970.5422237, 'get_old_steps_duration_secs': 8.7738037109375e-05, 'checkpoint_manager_blocking_start_time': 1776373970.0163558, 'checkpoint_manager_blocking_duration_secs': 0.5263934135437012}
I0416 21:12:50.542941 125630730390656 checkpoint_manager.py:2020] [process=0][thread=MainThread][step=5][wait_until_finished] Waiting for Save Finalize thread (save_finalize) to complete.
I0416 21:12:50.557421 125520622650944 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5
I0416 21:12:51.208363 125520643622464 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state
I0416 21:12:51.214314 125520643622464 atomicity.py:140] Creating tmp directory gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params
I0416 21:12:52.446114 125520675079744 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 46 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params/array_metadatas/process_0
I0416 21:12:52.460892 125520633136704 array_metadata_store.py:203] [process=0][thread=array_type_handler] Wrote 52 array_metadata.ArrayMetadata to gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state/array_metadatas/process_0
I0416 21:12:54.014780 125520591193664 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.347946s (commit=1.925215s, array_metadata_write=0.422731s)
I0416 21:12:54.015827 125520559736384 base_pytree_checkpoint_handler.py:1282] [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 2.349145s (commit=1.928037s, array_metadata_write=0.421108s)
I0416 21:12:54.016997 125520486336064 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 57.900 KiB/s (total gbytes: 205.1 KiB) (time elapsed: 3.542527198791504 s) (per-host)
I0416 21:12:54.017487 125520486336064 base_pytree_checkpoint_handler.py:130] [process=0] /jax/orbax/write/gbytes_per_sec: 174.084 KiB/s (total gbytes: 614.8 KiB) (time elapsed: 3.5317907333374023 s) (per-host)
I0416 21:12:54.017590 125520486336064 async_checkpointer.py:90] [process=0][thread=async_save] 4 Handler Commit operations completed. Time taken: 3.475188s.
I0416 21:12:54.371819 125520486336064 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:12:54.770210 125520486336064 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.534747s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params
I0416 21:12:54.770857 125520486336064 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/model_params
I0416 21:12:55.187703 125520486336064 array_metadata_store.py:367] [process=0][thread=async_save] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
I0416 21:12:55.610383 125520486336064 base_pytree_checkpoint_handler.py:1406] [process=0][thread=async_save] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.558571s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state
I0416 21:12:55.611114 125520486336064 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5/optimizer_state
I0416 21:12:55.873681 125520486336064 atomicity.py:666] Finalizing gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5
I0416 21:12:56.615331 125520486336064 atomicity.py:847] [process=0][thread=async_save] Finished saving checkpoint (finalized tmp dir) to `gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5`.
I0416 21:12:56.615975 125520486336064 event_tracking.py:138] [process=0] [async] Finished save (blocking + background) in 6.22 seconds @ gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke/checkpoints/5
I0416 21:12:56.616064 125520486336064 async_checkpointer.py:160] [process=0][thread=async_save] Background save thread done. Time taken: 6.073657s.
I0416 21:12:56.616255 125520475850304 async_checkpointer.py:288] [process=0][thread=save_finalize] Done with waiting for background save thread=async_save.
I0416 21:12:56.616367 125520475850304 async_checkpointer.py:298] [process=0][thread=save_finalize] No errors found in background save thread=async_save.
I0416 21:12:56.616421 125520475850304 checkpoint_manager.py:2137] [process=0][thread=save_finalize][step=5] CheckpointManager Save Finalize is syncing with other hosts...
I0416 21:12:56.616463 125520475850304 checkpoint_manager.py:2146] [process=0][thread=save_finalize][step=5] CheckpointManager Save Finalize is done on all hosts.
I0416 21:12:56.616599 125630730390656 checkpoint_manager.py:2032] [process=0][thread=MainThread][step=5][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=5.
I0416 21:12:56.616852 125630730390656 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=125630730390656 count=1 at 0x723b75984900>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x723b6f163950>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x723b6f163980>, _write_futures=[])
I0416 21:12:56.617353 125630730390656 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=125630730390656 count=1 at 0x723b75984900>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x723b6f163950>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x723b6f163980>, _write_futures=[])
I0416 21:12:56.617381 125630730390656 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=125630730390656 count=1 at 0x723b75984900>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x723b6f163950>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x723b6f163980>, _write_futures=[])

Training: 100%|██████████| 5/5 [00:29<00:00,  5.89s/step, _train_loss=5.88, _train_perplexity=356, _train_steps_per_sec=26.6]
[DECOUPLED NO-OP] gcs_storage: using stubs.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] mldiagnostics: using stub.
[DECOUPLED NO-OP] workload_monitor: using stub.
[DECOUPLED NO-OP] vertex_tensorboard: using stub.
~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '