MaxView

← Back to run

Log Summary

XPK Start: Tue Apr 21 01:13:10 UTC 2026
2026-04-21 01:13:27.230264: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0421 01:13:30.793302 133374975641408 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-21 01:13:39,831:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0421 01:13:39.831867 133374975641408 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-21 01:13:39,834:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-w667i-slice-job-0-0.mt-07-distill-smoke-w667i:8482
I0421 01:13:39.834049 133374975641408 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-w667i-slice-job-0-0.mt-07-distill-smoke-w667i:8482
I0421 01:13:41.116251 133374975641408 max_utils.py:284] Jax distributed system initialized!
I0421 01:13:47.550853 133374975641408 max_utils.py:244] Jax distributed system is already initialized.
I0421 01:13:48.018420 133374975641408 max_utils.py:244] Jax distributed system is already initialized.
I0421 01:13:48.019634 133374975641408 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0421 01:13:48.019683 133374975641408 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0421 01:13:52.057850 133374975641408 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0421 01:13:52.060760 133374975641408 maxtext_utils.py:1398] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0421 01:13:52.060876 133374975641408 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0421 01:13:52.060951 133374975641408 train_distill.py:586] Loading Student from ...
I0421 01:13:52.060982 133374975641408 train_distill.py:169] --- Student Configuration ---
I0421 01:13:52.061005 133374975641408 train_distill.py:170]   Model Name:      gpt3-52k
I0421 01:13:52.061027 133374975641408 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0421 01:13:52.061046 133374975641408 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0421 01:13:52.061064 133374975641408 train_distill.py:175]   Vocab Size:      32000
I0421 01:13:52.061084 133374975641408 train_distill.py:176]   Checkpoint:      
I0421 01:13:52.061102 133374975641408 train_distill.py:460] Initializing model: gpt3-52k...
I0421 01:13:53.304909 133374975641408 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0421 01:13:53.305024 133374975641408 train_distill.py:169] --- Teacher Configuration ---
I0421 01:13:53.305054 133374975641408 train_distill.py:170]   Model Name:      gpt3-52k
I0421 01:13:53.305080 133374975641408 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0421 01:13:53.305111 133374975641408 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0421 01:13:53.305130 133374975641408 train_distill.py:175]   Vocab Size:      32000
I0421 01:13:53.305147 133374975641408 train_distill.py:176]   Checkpoint:      gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0421 01:13:53.305166 133374975641408 train_distill.py:460] Initializing model: gpt3-52k...
I0421 01:13:54.449292 133374975641408 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:13:54.449767 133374975641408 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794d10698b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:13:54.449827 133374975641408 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0421 01:13:55.013754 133374975641408 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA
I0421 01:13:55.553961    2136 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0421 01:13:57.117045 133374975641408 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
W0421 01:13:59.729218 133374975641408 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0421 01:13:59.729697 133374975641408 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/to_nnx__rngs/dropout/count, params/params/decoder/to_nnx__rngs/dropout/key, params/params/decoder/to_nnx__rngs/params/count, params/params/decoder/to_nnx__rngs/params/key
I0421 01:13:59.783282 133374975641408 checkpointer.py:318] Finished restoring checkpoint in 3.47 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
I0421 01:14:00.470213 133374975641408 train_distill.py:626] Initializing Data Iterators via MaxText pipeline...
I0421 01:14:00.534726 133374975641408 config.py:112] TensorFlow version 2.20.0 available.
I0421 01:14:00.535324 133374975641408 config.py:125] JAX version 0.8.3 available.
E0421 01:14:02.563189 133374975641408 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0421 01:14:02.563428 133374975641408 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0421 01:14:02.567799 133374975641408 train_distill.py:405] Input Pipeline Checkpointing: DISABLED
I0421 01:14:02.567864 133374975641408 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False)
I0421 01:14:02.567927 133374975641408 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:14:02.568005 133374975641408 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794d10698b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:14:02.568045 133374975641408 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:14:02.568077 133374975641408 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794d10698b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:14:02.568126 133374975641408 checkpoint_manager.py:702] [process=1][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793545671af0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793547ae3b30>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793545087a40>}, handler_registry=None
I0421 01:14:02.568324 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793545671af0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:14:02.568365 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793547ae3b30>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:14:02.568392 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793545087a40>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:14:02.568431 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7943146e9e80>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:14:02.568459 133374975641408 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793545671af0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793545671af0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793547ae3b30>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793547ae3b30>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793545087a40>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793545087a40>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7943146e9e80>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7943146e9e80>}).
I0421 01:14:02.568879 133374975641408 async_checkpointer.py:177] [process=1][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x79353fe42f20> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0421 01:14:05.451980 133374975641408 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_nnx_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints
I0421 01:14:05.896166 133374975641408 checkpoint_manager.py:921] [process=1][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_nnx_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x793545087a10>
I0421 01:14:05.896336 133374975641408 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:14:05.896419 133374975641408 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794d10698b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:14:05.896457 133374975641408 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:14:05.896491 133374975641408 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x794d10698b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:14:05.896528 133374975641408 checkpoint_manager.py:1983] [process=1][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0421 01:14:05.896584 133374975641408 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133374975641408 count=1 at 0x79353fe1fe00>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x793545084ce0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7935450852b0>, _write_futures=[])
I0421 01:14:05.896911 133374975641408 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133374975641408 count=1 at 0x79353fe1fe00>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x793545084ce0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7935450852b0>, _write_futures=[])
I0421 01:14:05.896939 133374975641408 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133374975641408 count=1 at 0x79353fe1fe00>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x793545084ce0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7935450852b0>, _write_futures=[])
I0421 01:14:05.896969 133374975641408 checkpoint_manager.py:702] [process=1][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7935450879e0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79353fd03d10>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79353fd032c0>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79353fd029f0>}, handler_registry=None
I0421 01:14:05.897070 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7935450879e0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:14:05.897106 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79353fd03d10>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:14:05.897130 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79353fd032c0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:14:05.897157 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79353fd029f0>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`.
I0421 01:14:05.897180 133374975641408 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79353fd025a0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:14:05.897204 133374975641408 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7935450879e0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7935450879e0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79353fd03d10>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79353fd03d10>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79353fd032c0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79353fd032c0>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79353fd029f0>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79353fd029f0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79353fd025a0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79353fd025a0>}).
I0421 01:14:05.897275 133374975641408 async_checkpointer.py:177] [process=1][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x79353fe43060> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0421 01:14:06.688003 133374975641408 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_nnx_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints
I0421 01:14:06.695802 133374975641408 checkpoint_manager.py:921] [process=1][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_nnx_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x793545a08d70>
I0421 01:14:06.695969 133374975641408 train_distill.py:673] Starting Distillation Training...
I0421 01:14:06.696057 133374975641408 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0421 01:14:07.051615 133374975641408 peft_trainer.py:600] Compiled train_step cache size: 0

Training:   0%|          | 0/5 [00:00<?, ?step/s]I0421 01:14:07.053492 133231269627648 grain_pool.py:367] Grain pool will use 1 processes.
I0421 01:14:07.080302 133231269627648 grain_pool.py:440] Grain pool will start child processes.
I0421 01:14:07.085367 133231269627648 grain_pool.py:448] Grain pool started all child processes.
2026-04-21 01:14:13.113979: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0421 01:14:16.004560 133374975641408 utils.py:86] Train loop finished in: 8.9524 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
    trainer.train(train_iter, eval_iter)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0)}
I0421 01:14:16.354610 133231269627648 grain_pool.py:542] Grain pool is exiting.
I0421 01:14:16.354711 133231269627648 grain_pool.py:547] Shutting down multiprocessing system.
I0421 01:14:17.799537 133231269627648 grain_pool.py:547] Shutting down multiprocessing system.

Training:   0%|          | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Tue Apr 21 01:14:25 UTC 2026
EXIT_CODE=1