MaxView

Log Summary

XPK Start: Tue Apr 21 01:02:24 UTC 2026
2026-04-21 01:02:41.597610: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0421 01:02:45.171041 134370678888256 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-21 01:02:54,209:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0421 01:02:54.209831 134370678888256 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-21 01:02:54,212:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-znjmb-slice-job-0-0.mt-07-distill-smoke-znjmb:8482
I0421 01:02:54.212256 134370678888256 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-znjmb-slice-job-0-0.mt-07-distill-smoke-znjmb:8482
I0421 01:02:55.480396 134370678888256 max_utils.py:284] Jax distributed system initialized!
I0421 01:03:02.104563 134370678888256 max_utils.py:244] Jax distributed system is already initialized.
I0421 01:03:02.579613 134370678888256 max_utils.py:244] Jax distributed system is already initialized.
I0421 01:03:02.580936 134370678888256 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0421 01:03:02.580990 134370678888256 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0421 01:03:06.489233 134370678888256 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0421 01:03:06.492291 134370678888256 maxtext_utils.py:1398] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0421 01:03:06.492415 134370678888256 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0421 01:03:06.492487 134370678888256 train_distill.py:586] Loading Student from ...
I0421 01:03:06.492516 134370678888256 train_distill.py:169] --- Student Configuration ---
I0421 01:03:06.492538 134370678888256 train_distill.py:170]   Model Name:      gpt3-52k
I0421 01:03:06.492560 134370678888256 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0421 01:03:06.492580 134370678888256 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0421 01:03:06.492598 134370678888256 train_distill.py:175]   Vocab Size:      32000
I0421 01:03:06.492614 134370678888256 train_distill.py:176]   Checkpoint:      
I0421 01:03:06.492633 134370678888256 train_distill.py:460] Initializing model: gpt3-52k...
I0421 01:03:07.910212 134370678888256 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0421 01:03:07.910337 134370678888256 train_distill.py:169] --- Teacher Configuration ---
I0421 01:03:07.910367 134370678888256 train_distill.py:170]   Model Name:      gpt3-52k
I0421 01:03:07.910392 134370678888256 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0421 01:03:07.910413 134370678888256 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0421 01:03:07.910431 134370678888256 train_distill.py:175]   Vocab Size:      32000
I0421 01:03:07.910451 134370678888256 train_distill.py:176]   Checkpoint:      gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0421 01:03:07.910471 134370678888256 train_distill.py:460] Initializing model: gpt3-52k...
I0421 01:03:08.977286 134370678888256 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:03:08.977783 134370678888256 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a34e50f41d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:03:08.977858 134370678888256 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0421 01:03:10.787415 134370678888256 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA
I0421 01:03:11.736437    2123 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0421 01:03:12.857438 134370678888256 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
W0421 01:03:15.138873 134370678888256 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0421 01:03:15.139260 134370678888256 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/to_nnx__rngs/dropout/count, params/params/decoder/to_nnx__rngs/dropout/key, params/params/decoder/to_nnx__rngs/params/count, params/params/decoder/to_nnx__rngs/params/key
I0421 01:03:16.566707 134370678888256 checkpointer.py:318] Finished restoring checkpoint in 4.08 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
I0421 01:03:17.256286 134370678888256 train_distill.py:626] Initializing Data Iterators via MaxText pipeline...
I0421 01:03:17.319814 134370678888256 config.py:112] TensorFlow version 2.20.0 available.
I0421 01:03:17.320316 134370678888256 config.py:125] JAX version 0.8.3 available.
E0421 01:03:19.458063 134370678888256 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0421 01:03:19.458323 134370678888256 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0421 01:03:19.462674 134370678888256 train_distill.py:405] Input Pipeline Checkpointing: DISABLED
I0421 01:03:19.462745 134370678888256 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False)
I0421 01:03:19.462822 134370678888256 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:03:19.462915 134370678888256 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a34e50f41d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:03:19.462967 134370678888256 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:03:19.463019 134370678888256 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a34e50f41d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:03:19.463082 134370678888256 checkpoint_manager.py:702] [process=7][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf530>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf4a0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d338bf410>}, handler_registry=None
I0421 01:03:19.463300 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf530>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:03:19.463356 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf4a0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:03:19.463402 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d338bf410>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:03:19.463444 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1ef8a340e0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:03:19.463488 134370678888256 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf530>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf530>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf4a0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf4a0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d338bf410>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d338bf410>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1ef8a340e0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1ef8a340e0>}).
I0421 01:03:19.463920 134370678888256 async_checkpointer.py:177] [process=7][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7a1d336bfce0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0421 01:03:21.814073 134370678888256 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_linen_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints
I0421 01:03:22.219938 134370678888256 checkpoint_manager.py:921] [process=7][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_linen_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7a1ef9ae7a70>
I0421 01:03:22.220119 134370678888256 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:03:22.220214 134370678888256 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a34e50f41d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:03:22.220264 134370678888256 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 01:03:22.220319 134370678888256 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a34e50f41d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 01:03:22.220374 134370678888256 checkpoint_manager.py:1983] [process=7][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0421 01:03:22.220446 134370678888256 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134370678888256 count=1 at 0x7a1ef9b16540>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a1d338bf200>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a1d338bf1d0>, _write_futures=[])
I0421 01:03:22.220808 134370678888256 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134370678888256 count=1 at 0x7a1ef9b16540>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a1d338bf200>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a1d338bf1d0>, _write_futures=[])
I0421 01:03:22.220852 134370678888256 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=134370678888256 count=1 at 0x7a1ef9b16540>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7a1d338bf200>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7a1d338bf1d0>, _write_futures=[])
I0421 01:03:22.220903 134370678888256 checkpoint_manager.py:702] [process=7][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf3e0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338be4b0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d3342a570>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7a1d3342a330>}, handler_registry=None
I0421 01:03:22.221034 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf3e0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:03:22.221080 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338be4b0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0421 01:03:22.221119 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d3342a570>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:03:22.221164 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7a1d3342a330>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`.
I0421 01:03:22.221200 134370678888256 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d3342a900>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0421 01:03:22.221246 134370678888256 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf3e0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338bf3e0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338be4b0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7a1d338be4b0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d3342a570>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d3342a570>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7a1d3342a330>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7a1d3342a330>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d3342a900>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7a1d3342a900>}).
I0421 01:03:22.221345 134370678888256 async_checkpointer.py:177] [process=7][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7a1d336bfe20> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0421 01:03:22.617565 134370678888256 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_linen_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints
I0421 01:03:23.042988 134370678888256 checkpoint_manager.py:921] [process=7][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_test_pipeline_scan_nnx_20260421_005336/pt_distill_linen_xpk_test_pipeline_scan_nnx_20260421_005336_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7a1ef9ae6e70>
I0421 01:03:23.043222 134370678888256 train_distill.py:673] Starting Distillation Training...
I0421 01:03:23.043321 134370678888256 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0421 01:03:23.390647 134370678888256 peft_trainer.py:600] Compiled train_step cache size: 0

Training:   0%|          | 0/5 [00:00<?, ?step/s]
I0421 01:03:23.392200 134227878188800 grain_pool.py:367] Grain pool will use 1 processes.
I0421 01:03:23.418585 134227878188800 grain_pool.py:440] Grain pool will start child processes.
I0421 01:03:23.423595 134227878188800 grain_pool.py:448] Grain pool started all child processes.
2026-04-21 01:03:29.485725: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0421 01:03:32.640183 134370678888256 utils.py:86] Train loop finished in: 9.2490 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
    trainer.train(train_iter, eval_iter)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0)}
I0421 01:03:32.982362 134227878188800 grain_pool.py:542] Grain pool is exiting.
I0421 01:03:32.982461 134227878188800 grain_pool.py:547] Shutting down multiprocessing system.
I0421 01:03:34.422555 134227878188800 grain_pool.py:547] Shutting down multiprocessing system.

Training:   0%|          | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Tue Apr 21 01:03:44 UTC 2026
EXIT_CODE=1