MaxView

Case: 07_distill_smoke

Metrics: Linen vs NNX  ·  feat/nnx-set-defaults-true

Metric | Linen (73213e044) | NNX (73213e044) | Diff (NNX − Linen)

Diff = NNX value − Linen value. Green = NNX improved. Red = NNX regressed.
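The sign convention above can be sketched as follows. This assumes each metric carries a flag saying whether a larger value is better (for example, lower loss is better and higher throughput is better). The dashboard does not say how it decides direction, so `compare_metric` and `higher_is_better` are hypothetical names, not part of MaxView.

```python
# Hypothetical helper: mirrors "Diff = NNX value - Linen value" and the
# green/red coloring, given an assumed per-metric direction flag.
COLOR = {"improved": "green", "regressed": "red", "unchanged": "none"}


def compare_metric(linen: float, nnx: float, higher_is_better: bool) -> tuple[float, str]:
    """Return (diff, verdict), where diff = NNX - Linen."""
    diff = nnx - linen
    if diff == 0:
        return diff, "unchanged"
    # A positive diff is an improvement only when larger values are better.
    improved = diff > 0 if higher_is_better else diff < 0
    return diff, "improved" if improved else "regressed"


if __name__ == "__main__":
    # Loss dropped from 2.5 to 2.4: a negative diff, shown in green.
    diff, verdict = compare_metric(2.5, 2.4, higher_is_better=False)
    print(f"{diff:+.2f}", verdict, COLOR[verdict])
```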

Linen  ·  73213e044  ·  feat_nnx_set_defaults_true_20260425_150421
XPK Start: Sat Apr 25 15:12:34 UTC 2026
2026-04-25 15:12:51.208139: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0425 15:12:54.776250 133523214628672 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-25 15:13:03,815:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0425 15:13:03.815100 133523214628672 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-25 15:13:03,817:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-lvnta-slice-job-0-0.mt-07-distill-smoke-lvnta:8482
I0425 15:13:03.817366 133523214628672 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-lvnta-slice-job-0-0.mt-07-distill-smoke-lvnta:8482
I0425 15:13:05.294958 133523214628672 max_utils.py:284] Jax distributed system initialized!
I0425 15:13:11.585475 133523214628672 max_utils.py:244] Jax distributed system is already initialized.
I0425 15:13:12.069175 133523214628672 max_utils.py:244] Jax distributed system is already initialized.
I0425 15:13:12.070368 133523214628672 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0425 15:13:12.070417 133523214628672 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0425 15:13:15.977295 133523214628672 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0425 15:13:15.980357 133523214628672 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0425 15:13:15.980502 133523214628672 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0425 15:13:15.980578 133523214628672 train_distill.py:586] Loading Student from ...
I0425 15:13:15.980612 133523214628672 train_distill.py:169] --- Student Configuration ---
I0425 15:13:15.980635 133523214628672 train_distill.py:170]   Model Name:      gpt3-52k
I0425 15:13:15.980657 133523214628672 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0425 15:13:15.980677 133523214628672 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0425 15:13:15.980695 133523214628672 train_distill.py:175]   Vocab Size:      32000
I0425 15:13:15.980713 133523214628672 train_distill.py:176]   Checkpoint:      
I0425 15:13:15.980730 133523214628672 train_distill.py:460] Initializing model: gpt3-52k...
I0425 15:13:17.374107 133523214628672 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0425 15:13:17.374222 133523214628672 train_distill.py:169] --- Teacher Configuration ---
I0425 15:13:17.374251 133523214628672 train_distill.py:170]   Model Name:      gpt3-52k
I0425 15:13:17.374275 133523214628672 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0425 15:13:17.374301 133523214628672 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0425 15:13:17.374319 133523214628672 train_distill.py:175]   Vocab Size:      32000
I0425 15:13:17.374336 133523214628672 train_distill.py:176]   Checkpoint:      gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0425 15:13:17.374355 133523214628672 train_distill.py:460] Initializing model: gpt3-52k...
I0425 15:13:18.412016 133523214628672 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:13:18.412462 133523214628672 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796f973c3b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:13:18.412523 133523214628672 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0425 15:13:18.937724 133523214628672 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA
I0425 15:13:19.483842    2147 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0425 15:13:20.649509 133523214628672 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
W0425 15:13:23.018229 133523214628672 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0425 15:13:23.018727 133523214628672 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/rngs/aqt/count, params/params/decoder/layers/rngs/aqt/key, params/params/decoder/layers/rngs/dropout/count, params/params/decoder/layers/rngs/dropout/key, params/params/decoder/layers/rngs/params/count, params/params/decoder/layers/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key
I0425 15:13:23.147746 133523214628672 checkpointer.py:318] Finished restoring checkpoint in 2.89 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
I0425 15:13:23.840438 133523214628672 train_distill.py:626] Initializing Data Iterators via MaxText pipeline...
I0425 15:13:23.905853 133523214628672 config.py:112] TensorFlow version 2.20.0 available.
I0425 15:13:23.906347 133523214628672 config.py:125] JAX version 0.8.3 available.
E0425 15:13:26.398077 133523214628672 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0425 15:13:26.398304 133523214628672 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0425 15:13:26.401355 133523214628672 train_distill.py:405] Input Pipeline Checkpointing: DISABLED
I0425 15:13:26.401417 133523214628672 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False)
I0425 15:13:26.401503 133523214628672 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:13:26.401581 133523214628672 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796f973c3b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:13:26.401623 133523214628672 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:13:26.401660 133523214628672 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796f973c3b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:13:26.401703 133523214628672 checkpoint_manager.py:702] [process=3][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d37e0bf0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d376d9a0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7f020>}, handler_registry=None
I0425 15:13:26.401893 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d37e0bf0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:13:26.401935 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d376d9a0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:13:26.401961 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7f020>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:13:26.401985 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x795690516780>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:13:26.402012 133523214628672 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d37e0bf0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d37e0bf0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d376d9a0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7957d376d9a0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7f020>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7f020>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x795690516780>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x795690516780>}).
I0425 15:13:26.402437 133523214628672 async_checkpointer.py:177] [process=3][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7957d1c29580> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0425 15:13:28.728317 133523214628672 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints
I0425 15:13:28.742144 133523214628672 checkpoint_manager.py:921] [process=3][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x796f91d7ef60>
I0425 15:13:28.742269 133523214628672 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:13:28.742337 133523214628672 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796f973c3b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:13:28.742374 133523214628672 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:13:28.742407 133523214628672 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796f973c3b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:13:28.742462 133523214628672 checkpoint_manager.py:1983] [process=3][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0425 15:13:28.742515 133523214628672 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133523214628672 count=1 at 0x7957d37b0480>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x796f91d7e870>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7957d3177cb0>, _write_futures=[])
I0425 15:13:28.742850 133523214628672 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133523214628672 count=1 at 0x7957d37b0480>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x796f91d7e870>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7957d3177cb0>, _write_futures=[])
I0425 15:13:28.742878 133523214628672 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133523214628672 count=1 at 0x7957d37b0480>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x796f91d7e870>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7957d3177cb0>, _write_futures=[])
I0425 15:13:28.742915 133523214628672 checkpoint_manager.py:702] [process=3][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e990>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e0f0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7da00>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x796f91d7e420>}, handler_registry=None
I0425 15:13:28.743012 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e990>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:13:28.743045 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e0f0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:13:28.743068 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7da00>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:13:28.743094 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x796f91d7e420>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`.
I0425 15:13:28.743117 133523214628672 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7dc70>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:13:28.743141 133523214628672 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e990>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e990>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e0f0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x796f91d7e0f0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7da00>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7da00>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x796f91d7e420>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x796f91d7e420>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7dc70>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x796f91d7dc70>}).
I0425 15:13:28.743210 133523214628672 async_checkpointer.py:177] [process=3][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7957d1c296c0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0425 15:13:29.116262 133523214628672 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints
I0425 15:13:29.128212 133523214628672 checkpoint_manager.py:921] [process=3][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7957d2a38410>
I0425 15:13:29.128384 133523214628672 train_distill.py:673] Starting Distillation Training...
I0425 15:13:29.128490 133523214628672 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0425 15:13:29.490185 133523214628672 peft_trainer.py:600] Compiled train_step cache size: 0

Training:   0%|          | 0/5 [00:00<?, ?step/s]
I0425 15:13:29.492058 133378833626880 grain_pool.py:367] Grain pool will use 1 processes.
I0425 15:13:29.518354 133378833626880 grain_pool.py:440] Grain pool will start child processes.
I0425 15:13:29.523716 133378833626880 grain_pool.py:448] Grain pool started all child processes.
2026-04-25 15:13:35.563854: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0425 15:13:38.829664 133523214628672 utils.py:86] Train loop finished in: 9.3389 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
    trainer.train(train_iter, eval_iter)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0)}
I0425 15:13:39.169388 133378833626880 grain_pool.py:542] Grain pool is exiting.
I0425 15:13:39.169514 133378833626880 grain_pool.py:547] Shutting down multiprocessing system.
I0425 15:13:40.632550 133378833626880 grain_pool.py:547] Shutting down multiprocessing system.

Training:   0%|          | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Sat Apr 25 15:13:49 UTC 2026
EXIT_CODE=1
NNX  ·  73213e044  ·  feat_nnx_set_defaults_true_20260425_150421
XPK Start: Sat Apr 25 15:22:57 UTC 2026
2026-04-25 15:23:14.362310: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0425 15:23:17.949251 133362793142080 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-25 15:23:26,989:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0425 15:23:26.989190 133362793142080 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-25 15:23:26,991:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-svdg7-slice-job-0-0.mt-07-distill-smoke-svdg7:8482
I0425 15:23:26.991553 133362793142080 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-svdg7-slice-job-0-0.mt-07-distill-smoke-svdg7:8482
I0425 15:23:27.710994 133362793142080 max_utils.py:284] Jax distributed system initialized!
I0425 15:23:33.986195 133362793142080 max_utils.py:244] Jax distributed system is already initialized.
I0425 15:23:34.466631 133362793142080 max_utils.py:244] Jax distributed system is already initialized.
I0425 15:23:34.467814 133362793142080 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0425 15:23:34.467864 133362793142080 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0425 15:23:38.466324 133362793142080 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0425 15:23:38.469383 133362793142080 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0425 15:23:38.469506 133362793142080 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0425 15:23:38.469578 133362793142080 train_distill.py:586] Loading Student from ...
I0425 15:23:38.469607 133362793142080 train_distill.py:169] --- Student Configuration ---
I0425 15:23:38.469628 133362793142080 train_distill.py:170]   Model Name:      gpt3-52k
I0425 15:23:38.469650 133362793142080 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0425 15:23:38.469669 133362793142080 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0425 15:23:38.469687 133362793142080 train_distill.py:175]   Vocab Size:      32000
I0425 15:23:38.469704 133362793142080 train_distill.py:176]   Checkpoint:      
I0425 15:23:38.469723 133362793142080 train_distill.py:460] Initializing model: gpt3-52k...
I0425 15:23:39.755678 133362793142080 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0425 15:23:39.755789 133362793142080 train_distill.py:169] --- Teacher Configuration ---
I0425 15:23:39.755819 133362793142080 train_distill.py:170]   Model Name:      gpt3-52k
I0425 15:23:39.755845 133362793142080 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0425 15:23:39.755867 133362793142080 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0425 15:23:39.755888 133362793142080 train_distill.py:175]   Vocab Size:      32000
I0425 15:23:39.755911 133362793142080 train_distill.py:176]   Checkpoint:      gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0425 15:23:39.755930 133362793142080 train_distill.py:460] Initializing model: gpt3-52k...
I0425 15:23:40.717142 133362793142080 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:23:40.717573 133362793142080 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794a3a676060>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:23:40.717633 133362793142080 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0425 15:23:41.250195 133362793142080 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA
I0425 15:23:41.791632    2138 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0425 15:23:42.944052 133362793142080 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
W0425 15:23:45.099823 133362793142080 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0425 15:23:45.100281 133362793142080 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/rngs/aqt/count, params/params/decoder/layers/rngs/aqt/key, params/params/decoder/layers/rngs/dropout/count, params/params/decoder/layers/rngs/dropout/key, params/params/decoder/layers/rngs/params/count, params/params/decoder/layers/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key
I0425 15:23:46.042373 133362793142080 checkpointer.py:318] Finished restoring checkpoint in 3.48 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
I0425 15:23:46.742864 133362793142080 train_distill.py:626] Initializing Data Iterators via MaxText pipeline...
I0425 15:23:46.807441 133362793142080 config.py:112] TensorFlow version 2.20.0 available.
I0425 15:23:46.807937 133362793142080 config.py:125] JAX version 0.8.3 available.
E0425 15:23:48.836707 133362793142080 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0425 15:23:48.836965 133362793142080 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0425 15:23:48.840331 133362793142080 train_distill.py:405] Input Pipeline Checkpointing: DISABLED
I0425 15:23:48.840397 133362793142080 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False)
I0425 15:23:48.840482 133362793142080 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:23:48.840578 133362793142080 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794a3a676060>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:23:48.840641 133362793142080 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:23:48.840687 133362793142080 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794a3a676060>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:23:48.840748 133362793142080 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b770>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b6b0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79400ad28b00>}, handler_registry=None
I0425 15:23:48.840986 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b770>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:23:48.841036 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b6b0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:23:48.841064 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79400ad28b00>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:23:48.841089 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793369d94950>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:23:48.841137 133362793142080 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b770>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b770>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b6b0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad2b6b0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79400ad28b00>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79400ad28b00>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793369d94950>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793369d94950>}).
I0425 15:23:48.843389 133362793142080 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x79400aca1ee0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0425 15:23:51.298186 133362793142080 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints
I0425 15:23:52.154459 133362793142080 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x79400ad290d0>
I0425 15:23:52.154638 133362793142080 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:23:52.154709 133362793142080 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794a3a676060>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:23:52.154747 133362793142080 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 15:23:52.154779 133362793142080 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794a3a676060>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 15:23:52.154816 133362793142080 checkpoint_manager.py:1983] [process=6][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0425 15:23:52.154869 133362793142080 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133362793142080 count=1 at 0x793369b06400>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x79336852f050>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x79336852f080>, _write_futures=[])
I0425 15:23:52.155215 133362793142080 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133362793142080 count=1 at 0x793369b06400>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x79336852f050>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x79336852f080>, _write_futures=[])
I0425 15:23:52.155241 133362793142080 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133362793142080 count=1 at 0x793369b06400>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x79336852f050>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x79336852f080>, _write_futures=[])
I0425 15:23:52.155271 133362793142080 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad291c0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79336852ddc0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79336852df40>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79336852dfa0>}, handler_registry=None
I0425 15:23:52.155371 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad291c0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:23:52.155405 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79336852ddc0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 15:23:52.155429 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79336852df40>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:23:52.155456 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79336852dfa0>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`.
I0425 15:23:52.155479 133362793142080 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79336852d8e0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 15:23:52.155508 133362793142080 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad291c0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79400ad291c0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79336852ddc0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79336852ddc0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79336852df40>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79336852df40>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79336852dfa0>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79336852dfa0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79336852d8e0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79336852d8e0>}).
I0425 15:23:52.155578 133362793142080 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x79400aca2020> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0425 15:23:52.546021 133362793142080 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints
I0425 15:23:52.548172 133362793142080 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260425_150421/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260425_150421_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x79309c1f7860>
I0425 15:23:52.548333 133362793142080 train_distill.py:673] Starting Distillation Training...
I0425 15:23:52.548422 133362793142080 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0425 15:23:52.677599 133362793142080 peft_trainer.py:600] Compiled train_step cache size: 0

Training:   0%|          | 0/5 [00:00<?, ?step/s]
I0425 15:23:52.679485 133218468620032 grain_pool.py:367] Grain pool will use 1 processes.
I0425 15:23:52.706627 133218468620032 grain_pool.py:440] Grain pool will start child processes.
I0425 15:23:52.711973 133218468620032 grain_pool.py:448] Grain pool started all child processes.
2026-04-25 15:23:58.850988: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0425 15:24:02.152671 133362793142080 utils.py:86] Train loop finished in: 9.4745 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
    trainer.train(train_iter, eval_iter)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0)}
I0425 15:24:02.498421 133218468620032 grain_pool.py:542] Grain pool is exiting.
I0425 15:24:02.498523 133218468620032 grain_pool.py:547] Shutting down multiprocessing system.
I0425 15:24:03.962410 133218468620032 grain_pool.py:547] Shutting down multiprocessing system.

Training:   0%|          | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Sat Apr 25 15:24:13 UTC 2026
EXIT_CODE=1
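Both runs end with the same ValueError: `jax.make_array_from_process_local_data`, called from tunix's `sharding_utils.shard_input`, receives a value whose devices span all 8 processes (32 TPU devices), so `device_put` rejects it as not fully addressable. One plausible reading, which this log does not confirm, is that the MaxText `MultiHostDataLoadIterator` already yields global `jax.Array`s and tunix then re-feeds them as if they were host-local data. The sketch below assumes that reading. `shard_input_guarded` is a hypothetical helper, not the tunix API; it passes already-global arrays through and only assembles host-local leaves.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_input_guarded(batch, sharding):
  """Hypothetical shard_input variant that tolerates already-global arrays.

  Assumption: some leaves may already be global jax.Arrays produced by a
  multi-host data iterator. Those are returned unchanged; every other leaf
  is treated as host-local data and assembled into a global array.
  """

  def _shard_leaf(x):
    if isinstance(x, jax.Array) and not x.is_fully_addressable:
      # Spans devices owned by other processes: passing it to
      # make_array_from_process_local_data is what raises the ValueError.
      return x
    # Host-local data: build a global array laid out according to `sharding`.
    return jax.make_array_from_process_local_data(sharding, np.asarray(x))

  return jax.tree.map(_shard_leaf, batch)


# Single-host usage: shard a toy batch along a 1-D "data" mesh axis.
mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, P("data"))
batch = {"inputs": np.arange(jax.device_count() * 2 * 4).reshape(-1, 4)}
sharded = shard_input_guarded(batch, sharding)
```

On a single host every array is fully addressable, so the guard only changes behavior in a multi-process setting like the 8-process slice above.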