MaxView

Log Summary

XPK Start: Mon Apr 20 19:45:05 UTC 2026
2026-04-20 19:45:22.142225: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0420 19:45:25.823624 136380168128320 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-20 19:45:34,866:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0420 19:45:34.866165 136380168128320 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-20 19:45:34,868:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-cgg8k-slice-job-0-0.mt-07-distill-smoke-cgg8k:8482
I0420 19:45:34.868508 136380168128320 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-cgg8k-slice-job-0-0.mt-07-distill-smoke-cgg8k:8482
I0420 19:45:35.987337 136380168128320 max_utils.py:284] Jax distributed system initialized!
I0420 19:45:42.285028 136380168128320 max_utils.py:244] Jax distributed system is already initialized.
I0420 19:45:42.761775 136380168128320 max_utils.py:244] Jax distributed system is already initialized.
I0420 19:45:42.763440 136380168128320 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0420 19:45:42.763506 136380168128320 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0420 19:45:46.805620 136380168128320 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
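The warning above appears to come from Optax's polynomial schedule. When `transition_steps` is zero or negative, the learning rate never decays and stays at `init_value` for the whole run. The sketch below is an illustrative pure-Python re-implementation of that documented behaviour, not Optax's source. The name `poly_schedule` is hypothetical.

```python
def poly_schedule(init_value, end_value, power, transition_steps, transition_begin=0):
    """Return a step -> value function (illustrative re-implementation)."""
    if transition_steps <= 0:
        # Degenerate case reported in the log: a constant schedule at init_value.
        return lambda count: init_value

    def schedule(count):
        # Clip progress into [0, transition_steps] after the begin offset.
        progress = min(max(count - transition_begin, 0), transition_steps)
        frac = 1 - progress / transition_steps
        return (init_value - end_value) * frac ** power + end_value

    return schedule


constant = poly_schedule(3e-4, 0.0, 1.0, transition_steps=0)
decay = poly_schedule(1.0, 0.0, 1.0, transition_steps=10)
print(constant(4), decay(5))
```

For a 5-step smoke test a constant learning rate is usually harmless. In a real run the same warning would mean the configured decay is silently disabled.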
I0420 19:45:46.808645 136380168128320 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0420 19:45:46.808781 136380168128320 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0420 19:45:46.808853 136380168128320 train_distill.py:586] Loading Student from ...
I0420 19:45:46.808881 136380168128320 train_distill.py:169] --- Student Configuration ---
I0420 19:45:46.808904 136380168128320 train_distill.py:170]   Model Name:      gpt3-52k
I0420 19:45:46.808927 136380168128320 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0420 19:45:46.808945 136380168128320 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0420 19:45:46.808963 136380168128320 train_distill.py:175]   Vocab Size:      32000
I0420 19:45:46.808981 136380168128320 train_distill.py:176]   Checkpoint:      
I0420 19:45:46.808998 136380168128320 train_distill.py:460] Initializing model: gpt3-52k...
I0420 19:45:48.080358 136380168128320 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0420 19:45:48.080489 136380168128320 train_distill.py:169] --- Teacher Configuration ---
I0420 19:45:48.080519 136380168128320 train_distill.py:170]   Model Name:      gpt3-52k
I0420 19:45:48.080545 136380168128320 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0420 19:45:48.080567 136380168128320 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0420 19:45:48.080586 136380168128320 train_distill.py:175]   Vocab Size:      32000
I0420 19:45:48.080605 136380168128320 train_distill.py:176]   Checkpoint:      gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0420 19:45:48.080626 136380168128320 train_distill.py:460] Initializing model: gpt3-52k...
I0420 19:45:49.040537 136380168128320 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:45:49.040992 136380168128320 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c08c3f7c530>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:45:49.041054 136380168128320 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0420 19:45:49.567520 136380168128320 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA
I0420 19:45:50.120752    2131 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0420 19:45:51.256231 136380168128320 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
W0420 19:45:53.305360 136380168128320 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0420 19:45:53.305807 136380168128320 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/rngs/aqt/count, params/params/decoder/layers/rngs/aqt/key, params/params/decoder/layers/rngs/dropout/count, params/params/decoder/layers/rngs/dropout/key, params/params/decoder/layers/rngs/params/count, params/params/decoder/layers/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key
I0420 19:45:53.719532 136380168128320 checkpointer.py:318] Finished restoring checkpoint in 2.84 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
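Every path that Orbax reports as "not loaded from the original tree" ends in `rngs/<stream>/count` or `rngs/<stream>/key`. These entries exist in the checkpoint but have no counterpart in the restore target, and they look like RNG-stream state rather than model weights. That reading is an interpretation, not something the log states. The check below scans the comma-separated list for any skipped path that is not RNG state. The helper `non_rng_keys` is hypothetical.

```python
def non_rng_keys(skipped: str) -> list[str]:
    """Return skipped checkpoint paths that do NOT look like RNG-stream state."""
    paths = [p.strip() for p in skipped.split(",") if p.strip()]
    suspicious = []
    for path in paths:
        parts = path.split("/")
        # RNG state is assumed to look like .../rngs/<stream>/(count|key).
        is_rng = len(parts) >= 3 and parts[-3] == "rngs" and parts[-1] in ("count", "key")
        if not is_rng:
            suspicious.append(path)
    return suspicious
```

Paste the key list from the log into `non_rng_keys`. An empty result means no weight tensors were dropped during the teacher restore.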
I0420 19:45:54.424439 136380168128320 train_distill.py:626] Initializing Data Iterators via MaxText pipeline...
I0420 19:45:54.489692 136380168128320 config.py:112] TensorFlow version 2.20.0 available.
I0420 19:45:54.490178 136380168128320 config.py:125] JAX version 0.8.3 available.
E0420 19:45:56.524330 136380168128320 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0420 19:45:56.524553 136380168128320 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0420 19:45:56.527637 136380168128320 train_distill.py:405] Input Pipeline Checkpointing: DISABLED
I0420 19:45:56.527716 136380168128320 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False)
I0420 19:45:56.527783 136380168128320 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:45:56.527884 136380168128320 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c08c3f7c530>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:45:56.527953 136380168128320 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:45:56.528011 136380168128320 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c08c3f7c530>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:45:56.528084 136380168128320 checkpoint_manager.py:702] [process=7][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf22c294a70>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106005010>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf106004fe0>}, handler_registry=None
I0420 19:45:56.528336 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf22c294a70>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:45:56.528389 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106005010>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:45:56.528420 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf106004fe0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:45:56.528460 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c2b770>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:45:56.528502 136380168128320 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf22c294a70>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf22c294a70>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106005010>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106005010>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf106004fe0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf106004fe0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c2b770>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c2b770>}).
I0420 19:45:56.530400 136380168128320 async_checkpointer.py:177] [process=7][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7bf105cd9ee0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0420 19:45:59.117197 136380168128320 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints
I0420 19:45:59.147717 136380168128320 checkpoint_manager.py:921] [process=7][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7bf106004e90>
I0420 19:45:59.147843 136380168128320 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:45:59.147909 136380168128320 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c08c3f7c530>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:45:59.147954 136380168128320 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:45:59.147998 136380168128320 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c08c3f7c530>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:45:59.148060 136380168128320 checkpoint_manager.py:1983] [process=7][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0420 19:45:59.148133 136380168128320 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=136380168128320 count=1 at 0x7bf105ca9580>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7bf105c2b5c0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7bf105c2b5f0>, _write_futures=[])
I0420 19:45:59.148504 136380168128320 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=136380168128320 count=1 at 0x7bf105ca9580>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7bf105c2b5c0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7bf105c2b5f0>, _write_futures=[])
I0420 19:45:59.148539 136380168128320 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=136380168128320 count=1 at 0x7bf105ca9580>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7bf105c2b5c0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7bf105c2b5f0>, _write_futures=[])
I0420 19:45:59.148586 136380168128320 checkpoint_manager.py:702] [process=7][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106004920>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf105c2a5a0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c2a510>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7bf105c29d90>}, handler_registry=None
I0420 19:45:59.148734 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106004920>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:45:59.148785 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf105c2a5a0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:45:59.148820 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c2a510>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:45:59.148850 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7bf105c29d90>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`.
I0420 19:45:59.148874 136380168128320 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c29a00>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:45:59.148898 136380168128320 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106004920>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf106004920>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf105c2a5a0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7bf105c2a5a0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c2a510>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c2a510>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7bf105c29d90>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7bf105c29d90>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c29a00>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bf105c29a00>}).
I0420 19:45:59.148974 136380168128320 async_checkpointer.py:177] [process=7][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7bf105cda020> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0420 19:45:59.541131 136380168128320 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints
I0420 19:45:59.581771 136380168128320 checkpoint_manager.py:921] [process=7][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7bef3059fbc0>
I0420 19:45:59.581968 136380168128320 train_distill.py:673] Starting Distillation Training...
I0420 19:45:59.582093 136380168128320 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0420 19:45:59.709984 136380168128320 peft_trainer.py:600] Compiled train_step cache size: 0
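The earlier line `Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)` lines up position by position with the 13 axis names in the mesh printed just above. Only `data` (4) and `fsdp` (8) are larger than 1, and their product must equal the 32 devices. The check below verifies that invariant. The helper name `describe_mesh` is hypothetical, and the axis order is copied from this log rather than from MaxText's config.

```python
import math

# Axis order as printed in the "Training with mesh" line of this log.
AXES = ("diloco", "data", "stage", "fsdp", "fsdp_transpose", "sequence",
        "context", "context_autoregressive", "tensor", "tensor_transpose",
        "tensor_sequence", "expert", "autoregressive")


def describe_mesh(shape, num_devices, axes=AXES):
    """Validate a mesh shape and return only the axes that actually split work."""
    if len(shape) != len(axes):
        raise ValueError(f"{len(shape)} sizes given for {len(axes)} axis names")
    used = math.prod(shape)
    if used != num_devices:
        raise ValueError(f"mesh uses {used} devices, expected {num_devices}")
    return {name: size for name, size in zip(axes, shape) if size > 1}


print(describe_mesh((1, 4, 1, 8) + (1,) * 9, 32))
```

With eight processes (process_index 0 to 7 in the traceback below), each host owns 4 of the 32 devices.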

Training:   0%|          | 0/5 [00:00<?, ?step/s]
I0420 19:45:59.711937 136236790425344 grain_pool.py:367] Grain pool will use 1 processes.
I0420 19:45:59.738581 136236790425344 grain_pool.py:440] Grain pool will start child processes.
I0420 19:45:59.743619 136236790425344 grain_pool.py:448] Grain pool started all child processes.
2026-04-20 19:46:05.815930: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0420 19:46:08.691818 136380168128320 utils.py:86] Train loop finished in: 8.9812 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
    trainer.train(train_iter, eval_iter)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0)}
I0420 19:46:09.039382 136236790425344 grain_pool.py:542] Grain pool is exiting.
I0420 19:46:09.039483 136236790425344 grain_pool.py:547] Shutting down multiprocessing system.
I0420 19:46:10.491646 136236790425344 grain_pool.py:547] Shutting down multiprocessing system.

Training:   0%|          | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Mon Apr 20 19:46:20 UTC 2026
EXIT_CODE=1
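In the traceback above, Tunix's `sharding_utils.shard_input` calls `jax.make_array_from_process_local_data` on every batch leaf. `device_put` then rejects a value whose devices span all 32 TPU devices across all eight processes. That pattern suggests the leaf was already a global, non-fully-addressable `jax.Array` rather than host-local data, so the batch was sharded twice. One plausible cause is the `MultiHostDataLoadIterator` reported earlier in the log, which may already assemble global arrays. This diagnosis is an assumption. The guard below is a minimal sketch under that assumption. `shard_input_once` is a hypothetical helper, not part of Tunix or MaxText.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_input_once(batch, sharding):
    """Build global arrays from host-local leaves; pass through leaves that are already global."""

    def _leaf(x):
        if isinstance(x, jax.Array) and not x.is_fully_addressable:
            # Already spans devices owned by other processes. Re-sharding it is
            # what raises "device_put's first argument must be a fully addressable array".
            return x
        # Host-local numpy data (or a fully addressable array): assemble the global array.
        return jax.make_array_from_process_local_data(sharding, np.asarray(x))

    return jax.tree.map(_leaf, batch)


if __name__ == "__main__":
    # Single-process demo: shard a local batch over all visible devices on "data".
    mesh = Mesh(np.array(jax.devices()), ("data",))
    sharding = NamedSharding(mesh, P("data"))
    local = {"inputs": np.zeros((jax.local_device_count() * 2, 8), np.int32)}
    out = shard_input_once(local, sharding)
    print(out["inputs"].shape, out["inputs"].sharding)
```

If the iterator is confirmed to yield global arrays, a cleaner fix is probably to skip the trainer's second sharding step entirely rather than guard each leaf. Which layer should own input sharding is a design decision for the distillation integration.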