XPK Start: Mon Apr 20 19:26:38 UTC 2026
2026-04-20 19:26:55.156008: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0420 19:26:58.749278 133330057627456 max_utils.py:273] Attempting to initialize the jax distributed system...
I0420 19:27:07.789191 133330057627456 distributed.py:149] Starting JAX distributed service on [::]:8482
I0420 19:27:07.791523 133330057627456 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-d3wxz-slice-job-0-0.mt-07-distill-smoke-d3wxz:8482
I0420 19:27:08.614038 133330057627456 max_utils.py:284] Jax distributed system initialized!
I0420 19:27:15.037862 133330057627456 max_utils.py:244] Jax distributed system is already initialized.
I0420 19:27:15.512870 133330057627456 max_utils.py:244] Jax distributed system is already initialized.
I0420 19:27:15.514056 133330057627456 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0420 19:27:15.514120 133330057627456 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0420 19:27:19.404189 133330057627456 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
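The warning above is emitted by optax: when polynomial_schedule receives a non-positive transition_steps, it degenerates into a constant schedule at init_value. A minimal illustration with hypothetical values (not the values used in this run):

import optax

# transition_steps=0 triggers the warning and yields a constant schedule.
schedule = optax.polynomial_schedule(
    init_value=1e-3, end_value=0.0, power=1.0, transition_steps=0)
print(schedule(0), schedule(1000))  # both equal init_value (1e-3)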
I0420 19:27:19.407227 133330057627456 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0420 19:27:19.407341 133330057627456 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0420 19:27:19.407416 133330057627456 train_distill.py:586] Loading Student from ...
I0420 19:27:19.407447 133330057627456 train_distill.py:169] --- Student Configuration ---
I0420 19:27:19.407471 133330057627456 train_distill.py:170] Model Name: gpt3-52k
I0420 19:27:19.407493 133330057627456 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0420 19:27:19.407513 133330057627456 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0420 19:27:19.407531 133330057627456 train_distill.py:175] Vocab Size: 32000
I0420 19:27:19.407549 133330057627456 train_distill.py:176] Checkpoint:
I0420 19:27:19.407567 133330057627456 train_distill.py:460] Initializing model: gpt3-52k...
I0420 19:27:20.803969 133330057627456 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0420 19:27:20.804096 133330057627456 train_distill.py:169] --- Teacher Configuration ---
I0420 19:27:20.804127 133330057627456 train_distill.py:170] Model Name: gpt3-52k
I0420 19:27:20.804151 133330057627456 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0420 19:27:20.804173 133330057627456 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0420 19:27:20.804191 133330057627456 train_distill.py:175] Vocab Size: 32000
I0420 19:27:20.804212 133330057627456 train_distill.py:176] Checkpoint: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0420 19:27:20.804233 133330057627456 train_distill.py:460] Initializing model: gpt3-52k...
I0420 19:27:21.842482 133330057627456 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:27:21.842904 133330057627456 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79429e4f8b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:27:21.842967 133330057627456 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0420 19:27:22.403795 133330057627456 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA
I0420 19:27:22.936728 2153 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0420 19:27:24.089306 133330057627456 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
W0420 19:27:26.629416 133330057627456 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0420 19:27:26.629842 133330057627456 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/rngs/aqt/count, params/params/decoder/layers/rngs/aqt/key, params/params/decoder/layers/rngs/dropout/count, params/params/decoder/layers/rngs/dropout/key, params/params/decoder/layers/rngs/params/count, params/params/decoder/layers/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key
I0420 19:27:26.687209 133330057627456 checkpointer.py:318] Finished restoring checkpoint in 2.97 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items.
I0420 19:27:27.046765 133330057627456 metrics_logger.py:64] WandbBackend skipped: 'wandb' library not installed.
I0420 19:27:27.310062 133330057627456 train_distill.py:626] Initializing Data Iterators via MaxText pipeline...
I0420 19:27:27.372755 133330057627456 config.py:112] TensorFlow version 2.20.0 available.
I0420 19:27:27.373269 133330057627456 config.py:125] JAX version 0.8.3 available.
E0420 19:27:29.918342 133330057627456 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0420 19:27:29.918570 133330057627456 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0420 19:27:29.921609 133330057627456 train_distill.py:405] Input Pipeline Checkpointing: DISABLED
I0420 19:27:29.921674 133330057627456 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False)
I0420 19:27:29.921739 133330057627456 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:27:29.921813 133330057627456 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79429e4f8b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:27:29.921861 133330057627456 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:27:29.921895 133330057627456 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79429e4f8b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:27:29.921938 133330057627456 checkpoint_manager.py:702] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a29f0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2a80>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a2b10>}, handler_registry=None
I0420 19:27:29.922142 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a29f0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:27:29.922187 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2a80>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:27:29.922214 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a2b10>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:27:29.922239 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a2f00>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:27:29.922267 133330057627456 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a29f0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a29f0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2a80>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2a80>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a2b10>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a2b10>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a2f00>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a2f00>}).
I0420 19:27:29.922685 133330057627456 async_checkpointer.py:177] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x792ae89a3ba0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0420 19:27:31.983934 133330057627456 checkpoint_manager.py:558] Created directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints
I0420 19:27:32.410804 133330057627456 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints
I0420 19:27:32.426350 133330057627456 checkpoint_manager.py:921] [process=0][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7938911a2b40>
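For reference, the CheckpointManager logged above (orbax-checkpoint 0.11.28, save_interval_steps=2000, async writes enabled) can be reproduced in outline as below. This is a hypothetical sketch, not the MaxText call site; the bucket path is illustrative and the item names are taken from the log.

import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(
    save_interval_steps=2000,         # matches save_interval_steps in the log
    enable_async_checkpointing=True,  # async checkpoint writes, as logged
    create=True,
)
manager = ocp.CheckpointManager(
    "gs://my-bucket/run/checkpoints",  # illustrative path
    options=options,
    item_names=("model_params", "optimizer_state", "custom_metadata"),
)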
I0420 19:27:32.426472 133330057627456 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:27:32.426537 133330057627456 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79429e4f8b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:27:32.426585 133330057627456 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0420 19:27:32.426645 133330057627456 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79429e4f8b00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0420 19:27:32.426698 133330057627456 checkpoint_manager.py:1983] [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
I0420 19:27:32.426772 133330057627456 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133330057627456 count=1 at 0x792aea67abc0>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7938911a2d20>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7938911a2d50>, _write_futures=[])
I0420 19:27:32.427166 133330057627456 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133330057627456 count=1 at 0x792aea67abc0>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7938911a2d20>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7938911a2d50>, _write_futures=[])
I0420 19:27:32.427197 133330057627456 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133330057627456 count=1 at 0x792aea67abc0>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7938911a2d20>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7938911a2d50>, _write_futures=[])
I0420 19:27:32.427242 133330057627456 checkpoint_manager.py:702] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2b70>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a3bc0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a7e60>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7938911a78c0>}, handler_registry=None
I0420 19:27:32.427375 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2b70>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:27:32.427415 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a3bc0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0420 19:27:32.427441 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a7e60>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:27:32.427468 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7938911a78c0>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`.
I0420 19:27:32.427491 133330057627456 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a79e0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0420 19:27:32.427517 133330057627456 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2b70>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a2b70>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a3bc0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7938911a3bc0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a7e60>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a7e60>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7938911a78c0>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7938911a78c0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a79e0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7938911a79e0>}).
I0420 19:27:32.427593 133330057627456 async_checkpointer.py:177] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x792ae89a3ce0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0420 19:27:32.810872 133330057627456 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints
I0420 19:27:32.813680 133330057627456 checkpoint_manager.py:921] [process=0][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260420_190413/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260420_190413_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7938911a3860>
I0420 19:27:32.813844 133330057627456 train_distill.py:673] Starting Distillation Training...
I0420 19:27:32.813933 133330057627456 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
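The training mesh above lays the 32 devices out as data=4 x fsdp=8, with every other axis of size 1, matching the shape logged earlier by maxtext_utils.py. A minimal sketch of building an equivalent mesh on a 32-device slice (not MaxText's actual mesh helper):

import numpy as np
import jax
from jax.sharding import Mesh

axis_names = ("diloco", "data", "stage", "fsdp", "fsdp_transpose", "sequence",
              "context", "context_autoregressive", "tensor", "tensor_transpose",
              "tensor_sequence", "expert", "autoregressive")
mesh_shape = (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)  # product == 32 devices

devices = np.asarray(jax.devices()).reshape(mesh_shape)  # requires 32 devices
mesh = Mesh(devices, axis_names)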
I0420 19:27:33.172985 133330057627456 peft_trainer.py:600] Compiled train_step cache size: 0
Training: 0%| | 0/5 [00:00<?, ?step/s]I0420 19:27:33.174842 133184964511488 grain_pool.py:367] Grain pool will use 1 processes.
I0420 19:27:33.201719 133184964511488 grain_pool.py:440] Grain pool will start child processes.
I0420 19:27:33.206917 133184964511488 grain_pool.py:448] Grain pool started all child processes.
2026-04-20 19:27:39.209721: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0420 19:27:42.485465 133330057627456 utils.py:86] Train loop finished in: 9.3119 seconds
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
app.run(main)
File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
_run_main(main, args)
File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
trainer.train(train_iter, eval_iter)
File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
train_example = sharding_utils.shard_input(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
return jax.tree.map(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^
File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
lambda x: jax.make_array_from_process_local_data(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
out = [_array_from_process_local_data(data, s, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
return make_array_from_callback(global_shape, sharding, cb)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
per_device_values = api.device_put(per_device_values, devices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
out_flat = dispatch._batched_device_put_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
return _device_put_sharding_impl(x, aval, device, copy)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0)}
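The traceback shows the crash inside tunix's shard_input: jax.make_array_from_process_local_data ultimately calls jax.device_put, and device_put requires a value that is fully addressable from the current process. The batch it received already carried a sharding spanning all 32 TPU devices across 8 processes, i.e. it was already a global jax.Array rather than process-local host data, so the check fails. A hypothetical sketch of the failing shape follows; it only reproduces on a multi-host run, and names like global_batch are illustrative, not taken from the MaxText or tunix code.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Sharding that spans every device on every host.
mesh = Mesh(jax.devices(), ("data",))
sharding = NamedSharding(mesh, P("data"))

# A batch that is already a globally sharded jax.Array, as a multi-host input
# pipeline may hand back; each process supplies only its local rows.
rows_per_process = 32 // jax.process_count()
global_batch = jax.make_array_from_process_local_data(
    sharding, jnp.zeros((rows_per_process, 8)), (32, 8))

# Re-sharding the global array as if it were process-local host data fails on
# multi-host runs, because from any one process it is not fully addressable:
jax.device_put(global_batch, jax.local_devices()[0])
# ValueError: device_put's first argument must be a fully addressable array ...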
I0420 19:27:42.837149 133184964511488 grain_pool.py:542] Grain pool is exiting.
I0420 19:27:42.837267 133184964511488 grain_pool.py:547] Shutting down multiprocessing system.
I0420 19:27:44.291117 133184964511488 grain_pool.py:547] Shutting down multiprocessing system.
Training: 0%| | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
Exception ignored in: <function GCSRecordWriter.__del__ at 0x794297ab91c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/site-packages/tensorboardX/record_writer.py", line 134, in __del__
File "/usr/local/lib/python3.12/site-packages/tensorboardX/record_writer.py", line 158, in close
File "/usr/local/lib/python3.12/site-packages/tensorboardX/record_writer.py", line 149, in flush
File "/usr/local/lib/python3.12/copy.py", line 87, in copy
ImportError: sys.meta_path is None, Python is likely shutting down
XPK End: Mon Apr 20 19:27:54 UTC 2026
EXIT_CODE=1