XPK Start: Fri Apr 24 15:19:32 UTC 2026
2026-04-24 15:19:49.390541: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0424 15:19:52.962687 137557288490816 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-24 15:20:02,002:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0424 15:20:02.002367 137557288490816 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-24 15:20:02,004:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-wch4j-slice-job-0-0.mt-07-distill-smoke-wch4j:8482
I0424 15:20:02.004570 137557288490816 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-wch4j-slice-job-0-0.mt-07-distill-smoke-wch4j:8482
I0424 15:20:03.076986 137557288490816 max_utils.py:284] Jax distributed system initialized!
I0424 15:20:09.308848 137557288490816 max_utils.py:244] Jax distributed system is already initialized.
I0424 15:20:09.792050 137557288490816 max_utils.py:244] Jax distributed system is already initialized.
I0424 15:20:09.793503 137557288490816 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0424 15:20:09.793575 137557288490816 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0424 15:20:13.619421 137557288490816 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0424 15:20:13.622502 137557288490816 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 15:20:13.622635 137557288490816 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0424 15:20:13.622709 137557288490816 train_distill.py:586] Loading Student from ...
I0424 15:20:13.622740 137557288490816 train_distill.py:169] --- Student Configuration ---
I0424 15:20:13.622762 137557288490816 train_distill.py:170] Model Name: gpt3-52k
I0424 15:20:13.622786 137557288490816 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0424 15:20:13.622806 137557288490816 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0424 15:20:13.622823 137557288490816 train_distill.py:175] Vocab Size: 32000
I0424 15:20:13.622840 137557288490816 train_distill.py:176] Checkpoint:
I0424 15:20:13.622859 137557288490816 train_distill.py:460] Initializing model: gpt3-52k...
I0424 15:20:14.899807 137557288490816 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0424 15:20:14.899921 137557288490816 train_distill.py:169] --- Teacher Configuration ---
I0424 15:20:14.899950 137557288490816 train_distill.py:170] Model Name: gpt3-52k
I0424 15:20:14.899973 137557288490816 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0424 15:20:14.899994 137557288490816 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0424 15:20:14.900015 137557288490816 train_distill.py:175] Vocab Size: 32000
I0424 15:20:14.900032 137557288490816 train_distill.py:176] Checkpoint: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0424 15:20:14.900053 137557288490816 train_distill.py:460] Initializing model: gpt3-52k...
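
Note on the mesh logged above by maxtext_utils.py: the 32 devices are split as data=4 x fsdp=8, with every other axis trivial (size 1). A minimal sketch of how such a mesh could be constructed in JAX; the axis names are taken from the Mesh printed further down in this log, and this is an illustration rather than the actual MaxText helper code:

import numpy as np
import jax
from jax.sharding import Mesh

# Axis names as printed by peft_trainer.py later in this log.
axis_names = ("diloco", "data", "stage", "fsdp", "fsdp_transpose", "sequence",
              "context", "context_autoregressive", "tensor", "tensor_transpose",
              "tensor_sequence", "expert", "autoregressive")
# Shape from the maxtext_utils.py line above; product must equal 32 devices.
mesh_shape = (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)

# Requires len(jax.devices()) == 32, matching Num_devices in the log.
device_grid = np.asarray(jax.devices()).reshape(mesh_shape)
mesh = Mesh(device_grid, axis_names)
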
I0424 15:20:15.869744 137557288490816 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 15:20:15.870175 137557288490816 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7d1ad8e6f260>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 15:20:15.870236 137557288490816 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0424 15:20:16.398228 137557288490816 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA
I0424 15:20:16.951308 2139 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0424 15:20:18.129154 137557288490816 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
W0424 15:20:20.280083 137557288490816 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0424 15:20:20.280547 137557288490816 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/rngs/aqt/count, params/params/decoder/layers/rngs/aqt/key, params/params/decoder/layers/rngs/dropout/count, params/params/decoder/layers/rngs/dropout/key, params/params/decoder/layers/rngs/params/count, params/params/decoder/layers/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key
I0424 15:20:21.100839 137557288490816 checkpointer.py:318] Finished restoring checkpoint in 3.37 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
I0424 15:20:21.802968 137557288490816 train_distill.py:626] Initializing Data Iterators via MaxText pipeline...
I0424 15:20:21.867631 137557288490816 config.py:112] TensorFlow version 2.20.0 available.
I0424 15:20:21.868157 137557288490816 config.py:125] JAX version 0.8.3 available.
E0424 15:20:23.837157 137557288490816 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0424 15:20:23.837388 137557288490816 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0424 15:20:23.840375 137557288490816 train_distill.py:405] Input Pipeline Checkpointing: DISABLED
I0424 15:20:23.840440 137557288490816 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False)
I0424 15:20:23.840504 137557288490816 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 15:20:23.840581 137557288490816 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7d1ad8e6f260>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 15:20:23.840623 137557288490816 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 15:20:23.840655 137557288490816 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7d1ad8e6f260>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 15:20:23.840698 137557288490816 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef003b30>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef001670>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02ef000e30>}, handler_registry=None
I0424 15:20:23.840893 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef003b30>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 15:20:23.840934 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef001670>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 15:20:23.840961 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02ef000e30>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 15:20:23.840984 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec73710>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 15:20:23.841011 137557288490816 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef003b30>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef003b30>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef001670>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef001670>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02ef000e30>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02ef000e30>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec73710>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec73710>}).
I0424 15:20:23.842839 137557288490816 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7d02eec2dee0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0424 15:20:26.319814 137557288490816 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260424_145753/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260424_145753_07_distill_smoke/checkpoints
I0424 15:20:26.322489 137557288490816 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260424_145753/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260424_145753_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7d02ef0008c0>
I0424 15:20:26.322597 137557288490816 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 15:20:26.322661 137557288490816 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7d1ad8e6f260>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 15:20:26.322697 137557288490816 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 15:20:26.322728 137557288490816 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7d1ad8e6f260>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 15:20:26.322763 137557288490816 checkpoint_manager.py:1983] [process=6][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0424 15:20:26.322814 137557288490816 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=137557288490816 count=1 at 0x7d02eec0b6c0>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7d02eec73560>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7d02eec73590>, _write_futures=[])
I0424 15:20:26.323157 137557288490816 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=137557288490816 count=1 at 0x7d02eec0b6c0>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7d02eec73560>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7d02eec73590>, _write_futures=[])
I0424 15:20:26.323183 137557288490816 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=137557288490816 count=1 at 0x7d02eec0b6c0>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7d02eec73560>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7d02eec73590>, _write_futures=[])
I0424 15:20:26.323213 137557288490816 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef0005f0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02eec72390>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec723c0>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7d02eec71c40>}, handler_registry=None
I0424 15:20:26.323313 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef0005f0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 15:20:26.323350 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02eec72390>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 15:20:26.323375 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec723c0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 15:20:26.323402 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7d02eec71c40>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`.
I0424 15:20:26.323425 137557288490816 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec732f0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 15:20:26.323449 137557288490816 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef0005f0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02ef0005f0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02eec72390>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7d02eec72390>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec723c0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec723c0>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7d02eec71c40>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7d02eec71c40>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec732f0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7d02eec732f0>}).
I0424 15:20:26.323516 137557288490816 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7d02eec2e020> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0424 15:20:26.709237 137557288490816 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260424_145753/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260424_145753_07_distill_smoke/checkpoints
I0424 15:20:26.717228 137557288490816 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260424_145753/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260424_145753_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7d02eec73c80>
I0424 15:20:26.717404 137557288490816 train_distill.py:673] Starting Distillation Training...
I0424 15:20:26.717522 137557288490816 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0424 15:20:26.845584 137557288490816 peft_trainer.py:600] Compiled train_step cache size: 0
Training: 0%| | 0/5 [00:00<?, ?step/s]
I0424 15:20:26.847494 137413032662784 grain_pool.py:367] Grain pool will use 1 processes.
I0424 15:20:26.874532 137413032662784 grain_pool.py:440] Grain pool will start child processes.
I0424 15:20:26.880037 137413032662784 grain_pool.py:448] Grain pool started all child processes.
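
Note on the sharding step that fails next: per the traceback below, tunix's sharding_utils.shard_input maps jax.make_array_from_process_local_data over every leaf of the host-local batch to assemble global arrays across the eight host processes. A standalone sketch of that pattern; the PartitionSpec here is an assumption for illustration (tunix derives its sharding from the trainer's mesh and config):

import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def shard_input_sketch(batch, mesh):
  # Assumed layout: batch dimension split across the 'data' and 'fsdp' axes
  # of the mesh logged above (4 x 8 = 32 devices).
  sharding = NamedSharding(mesh, P(("data", "fsdp")))
  # Each leaf must be this process's local slice (e.g. a NumPy array);
  # JAX stitches the per-process slices into one global jax.Array.
  return jax.tree.map(
      lambda x: jax.make_array_from_process_local_data(sharding, np.asarray(x)),
      batch,
  )
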
2026-04-24 15:20:32.941925: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0424 15:20:36.286113 137557288490816 utils.py:86] Train loop finished in: 9.4399 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
    trainer.train(train_iter, eval_iter)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                              ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0)}
I0424 15:20:36.634377 137413032662784 grain_pool.py:542] Grain pool is exiting.
I0424 15:20:36.634479 137413032662784 grain_pool.py:547] Shutting down multiprocessing system.
I0424 15:20:38.095810 137413032662784 grain_pool.py:547] Shutting down multiprocessing system.
Training: 0%| | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Fri Apr 24 15:20:48 UTC 2026
EXIT_CODE=1
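
Post-mortem note: device_put requires a fully addressable value, i.e. one whose shards all live on the calling process's devices, while the failing value spans all eight processes. One plausible reading (an assumption, not confirmed by this log) is that a leaf reaching shard_input was already a global jax.Array rather than process-local NumPy data, so re-wrapping it via make_array_from_process_local_data forced an illegal cross-process device_put. A hypothetical guard illustrating the distinction; lift_leaf and to_global are illustrative names, not tunix or MaxText code:

import jax

def lift_leaf(x, to_global):
  # A jax.Array whose shards live on other processes is not fully addressable
  # here; re-wrapping it triggers exactly the ValueError seen above.
  if isinstance(x, jax.Array) and not x.is_fully_addressable:
    return x  # already a global array: leave it alone
  return to_global(x)  # host-local data: lift to a global jax.Array
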