feat/nnx-set-defaults-true
XPK Start: Thu Apr 23 16:03:49 UTC 2026 2026-04-23 16:04:06.050393: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303) I0423 16:04:09.812159 133516577158976 max_utils.py:273] Attempting to initialize the jax distributed system... INFO:2026-04-23 16:04:18,850:jax._src.distributed:149: Starting JAX distributed service on [::]:8482 I0423 16:04:18.850823 133516577158976 distributed.py:149] Starting JAX distributed service on [::]:8482 INFO:2026-04-23 16:04:18,853:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-bqca4-slice-job-0-0.mt-07-distill-smoke-bqca4:8482 I0423 16:04:18.853057 133516577158976 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-bqca4-slice-job-0-0.mt-07-distill-smoke-bqca4:8482 I0423 16:04:20.517042 133516577158976 max_utils.py:284] Jax distributed system initialized! I0423 16:04:26.860422 133516577158976 max_utils.py:244] Jax distributed system is already initialized. I0423 16:04:27.346207 133516577158976 max_utils.py:244] Jax distributed system is already initialized. I0423 16:04:27.347392 133516577158976 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf I0423 16:04:27.347441 133516577158976 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf I0423 16:04:31.279744 133516577158976 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`. I0423 16:04:31.282782 133516577158976 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1) I0423 16:04:31.282908 133516577158976 train_distill.py:582] Applying logical axis rules for model initialization and training... I0423 16:04:31.282994 133516577158976 train_distill.py:586] Loading Student from ... 
I0423 16:04:31.283024 133516577158976 train_distill.py:169] --- Student Configuration --- I0423 16:04:31.283047 133516577158976 train_distill.py:170] Model Name: gpt3-52k I0423 16:04:31.283069 133516577158976 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim I0423 16:04:31.283088 133516577158976 train_distill.py:174] Attention Heads: 2 Query, 2 KV I0423 16:04:31.283106 133516577158976 train_distill.py:175] Vocab Size: 32000 I0423 16:04:31.283123 133516577158976 train_distill.py:176] Checkpoint: I0423 16:04:31.283142 133516577158976 train_distill.py:460] Initializing model: gpt3-52k... I0423 16:04:32.690222 133516577158976 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items... I0423 16:04:32.690339 133516577158976 train_distill.py:169] --- Teacher Configuration --- I0423 16:04:32.690369 133516577158976 train_distill.py:170] Model Name: gpt3-52k I0423 16:04:32.690392 133516577158976 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim I0423 16:04:32.690414 133516577158976 train_distill.py:174] Attention Heads: 2 Query, 2 KV I0423 16:04:32.690435 133516577158976 train_distill.py:175] Vocab Size: 32000 I0423 16:04:32.690456 133516577158976 train_distill.py:176] Checkpoint: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items I0423 16:04:32.690475 133516577158976 train_distill.py:460] Initializing model: gpt3-52k... 
I0423 16:04:33.722031 133516577158976 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:04:33.722459 133516577158976 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796e08a3c230>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:04:33.722524 133516577158976 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28 W0423 16:04:34.237634 133516577158976 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA I0423 16:04:34.776797 2133 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com I0423 16:04:35.611383 133516577158976 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items. W0423 16:04:38.184391 133516577158976 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on. 
I0423 16:04:38.184824 133516577158976 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/rngs/aqt/count, params/params/decoder/layers/rngs/aqt/key, params/params/decoder/layers/rngs/dropout/count, params/params/decoder/layers/rngs/dropout/key, params/params/decoder/layers/rngs/params/count, params/params/decoder/layers/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key I0423 16:04:38.518395 133516577158976 checkpointer.py:318] Finished restoring checkpoint in 3.29 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items. I0423 16:04:39.208953 133516577158976 train_distill.py:626] Initializing Data Iterators via MaxText pipeline... 
I0423 16:04:39.274759 133516577158976 config.py:112] TensorFlow version 2.20.0 available. I0423 16:04:39.275288 133516577158976 config.py:125] JAX version 0.8.3 available. E0423 16:04:41.704974 133516577158976 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead. I0423 16:04:41.705195 133516577158976 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform. I0423 16:04:41.708260 133516577158976 train_distill.py:405] Input Pipeline Checkpointing: DISABLED I0423 16:04:41.708324 133516577158976 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False) I0423 16:04:41.708387 133516577158976 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:04:41.708470 133516577158976 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796e08a3c230>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:04:41.708514 133516577158976 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:04:41.708546 133516577158976 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796e08a3c230>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:04:41.708589 133516577158976 checkpoint_manager.py:702] [process=3][thread=MainThread] CheckpointManager init: checkpointers=None, 
item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124dd0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124e00>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c124d70>}, handler_registry=None I0423 16:04:41.708784 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124dd0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:04:41.708825 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124e00>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:04:41.708852 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c124d70>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. 
I0423 16:04:41.708876 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c16d310>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. I0423 16:04:41.708904 133516577158976 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124dd0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124dd0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124e00>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124e00>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c124d70>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c124d70>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): 
<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c16d310>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c16d310>}). I0423 16:04:41.709340 133516577158976 async_checkpointer.py:177] [process=3][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7956463f9580> timeout: 600 secs and primary_host=0 for async checkpoint writes I0423 16:04:44.137038 133516577158976 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints I0423 16:04:44.155342 133516577158976 checkpoint_manager.py:921] [process=3][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), 
root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x79557c124da0> I0423 16:04:44.155474 133516577158976 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:04:44.155543 133516577158976 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796e08a3c230>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:04:44.155578 133516577158976 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:04:44.155609 133516577158976 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x796e08a3c230>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:04:44.155646 133516577158976 checkpoint_manager.py:1983] [process=3][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning. 
I0423 16:04:44.155698 133516577158976 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133516577158976 count=1 at 0x7969a2599040>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x79557c1251c0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x79557c125280>, _write_futures=[]) I0423 16:04:44.156049 133516577158976 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133516577158976 count=1 at 0x7969a2599040>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x79557c1251c0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x79557c125280>, _write_futures=[]) I0423 16:04:44.156076 133516577158976 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=133516577158976 count=1 at 0x7969a2599040>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x79557c1251c0>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x79557c125280>, _write_futures=[]) I0423 16:04:44.156105 133516577158976 checkpoint_manager.py:702] [process=3][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124aa0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c125820>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c125c10>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79557c126a50>}, handler_registry=None I0423 16:04:44.156204 133516577158976 composite_checkpoint_handler.py:237] Deferred 
registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124aa0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:04:44.156235 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c125820>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:04:44.156256 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c125c10>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. I0423 16:04:44.156283 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79557c126a50>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`. I0423 16:04:44.156306 133516577158976 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". 
Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c127440>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. I0423 16:04:44.156331 133516577158976 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124aa0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c124aa0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c125820>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79557c125820>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c125c10>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c125c10>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79557c126a50>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): 
<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x79557c126a50>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c127440>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79557c127440>}). I0423 16:04:44.156398 133516577158976 async_checkpointer.py:177] [process=3][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7956463f96c0> timeout: 600 secs and primary_host=0 for async checkpoint writes I0423 16:04:44.836634 133516577158976 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints I0423 16:04:44.839170 133516577158976 checkpoint_manager.py:921] [process=3][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, 
prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_linen_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x79550435e840> I0423 16:04:44.839332 133516577158976 train_distill.py:673] Starting Distillation Training... I0423 16:04:44.839421 133516577158976 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)) I0423 16:04:45.194334 133516577158976 peft_trainer.py:600] Compiled train_step cache size: 0 Training: 0%| | 0/5 [00:00<?, ?step/s]I0423 16:04:45.196158 133373117925120 grain_pool.py:367] Grain pool will use 1 processes. I0423 16:04:45.222948 133373117925120 grain_pool.py:440] Grain pool will start child processes. I0423 16:04:45.228162 133373117925120 grain_pool.py:448] Grain pool started all child processes. 
2026-04-23 16:04:51.294641: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303) I0423 16:04:54.786012 133516577158976 utils.py:86] Train loop finished in: 9.5911 seconds Traceback (most recent call last): File "<frozen runpy>", line 198, in _run_module_as_main File "<frozen runpy>", line 88, in _run_code File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module> app.run(main) File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run _run_main(main, args) File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main sys.exit(main(argv)) ^^^^^^^^^^ File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir) File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill trainer.train(train_iter, eval_iter) File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train train_example = sharding_utils.shard_input( ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input return jax.tree.map( ^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr> return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) ^^^^^^ File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 
59, in <lambda> lambda x: jax.make_array_from_process_local_data( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data out = [_array_from_process_local_data(data, s, shape) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data return make_array_from_callback(global_shape, sharding, cb) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback per_device_values = api.device_put(per_device_values, devices) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put out_flat = dispatch._batched_device_put_impl( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl return _device_put_sharding_impl(x, aval, device, copy) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl raise ValueError( ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), 
TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0)} I0423 16:04:55.132720 133373117925120 grain_pool.py:542] Grain pool is exiting. I0423 16:04:55.132824 133373117925120 grain_pool.py:547] Shutting down multiprocessing system. 
I0423 16:04:56.578321 133373117925120 grain_pool.py:547] Shutting down multiprocessing system. Training: 0%| | 0/5 [00:13<?, ?step/s] /usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown warnings.warn('resource_tracker: There appear to be %d ' XPK End: Thu Apr 23 16:05:04 UTC 2026 EXIT_CODE=1
XPK Start: Thu Apr 23 16:14:20 UTC 2026 2026-04-23 16:14:37.055280: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303) I0423 16:14:40.620702 135432814733120 max_utils.py:273] Attempting to initialize the jax distributed system... INFO:2026-04-23 16:14:49,660:jax._src.distributed:149: Starting JAX distributed service on [::]:8482 I0423 16:14:49.660730 135432814733120 distributed.py:149] Starting JAX distributed service on [::]:8482 INFO:2026-04-23 16:14:49,662:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-k2i3x-slice-job-0-0.mt-07-distill-smoke-k2i3x:8482 I0423 16:14:49.662996 135432814733120 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-k2i3x-slice-job-0-0.mt-07-distill-smoke-k2i3x:8482 I0423 16:14:50.849530 135432814733120 max_utils.py:284] Jax distributed system initialized! I0423 16:14:57.132709 135432814733120 max_utils.py:244] Jax distributed system is already initialized. I0423 16:14:57.622898 135432814733120 max_utils.py:244] Jax distributed system is already initialized. I0423 16:14:57.624088 135432814733120 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf I0423 16:14:57.624151 135432814733120 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf I0423 16:15:01.626055 135432814733120 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`. I0423 16:15:01.629134 135432814733120 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1) I0423 16:15:01.629257 135432814733120 train_distill.py:582] Applying logical axis rules for model initialization and training... I0423 16:15:01.629334 135432814733120 train_distill.py:586] Loading Student from ... 
I0423 16:15:01.629364 135432814733120 train_distill.py:169] --- Student Configuration ---
I0423 16:15:01.629387 135432814733120 train_distill.py:170] Model Name: gpt3-52k
I0423 16:15:01.629409 135432814733120 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0423 16:15:01.629428 135432814733120 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0423 16:15:01.629446 135432814733120 train_distill.py:175] Vocab Size: 32000
I0423 16:15:01.629464 135432814733120 train_distill.py:176] Checkpoint:
I0423 16:15:01.629483 135432814733120 train_distill.py:460] Initializing model: gpt3-52k...
I0423 16:15:02.904985 135432814733120 train_distill.py:600] Loading Teacher from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items...
I0423 16:15:02.905114 135432814733120 train_distill.py:169] --- Teacher Configuration ---
I0423 16:15:02.905144 135432814733120 train_distill.py:170] Model Name: gpt3-52k
I0423 16:15:02.905173 135432814733120 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0423 16:15:02.905200 135432814733120 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0423 16:15:02.905219 135432814733120 train_distill.py:175] Vocab Size: 32000
I0423 16:15:02.905242 135432814733120 train_distill.py:176] Checkpoint: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
I0423 16:15:02.905260 135432814733120 train_distill.py:460] Initializing model: gpt3-52k...
I0423 16:15:03.871424 135432814733120 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:15:03.871859 135432814733120 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b2c3131c5f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:15:03.871918 135432814733120 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28 W0423 16:15:04.415654 135432814733120 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items/_CHECKPOINT_METADATA I0423 16:15:04.940065 2138 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com I0423 16:15:06.095719 135432814733120 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items. W0423 16:15:08.278707 135432814733120 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on. 
I0423 16:15:08.279221 135432814733120 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/rngs/aqt/count, params/params/decoder/layers/rngs/aqt/key, params/params/decoder/layers/rngs/dropout/count, params/params/decoder/layers/rngs/dropout/key, params/params/decoder/layers/rngs/params/count, params/params/decoder/layers/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key I0423 16:15:08.593834 135432814733120 checkpointer.py:318] Finished restoring checkpoint in 2.87 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items. I0423 16:15:09.305508 135432814733120 train_distill.py:626] Initializing Data Iterators via MaxText pipeline... 
I0423 16:15:09.369294 135432814733120 config.py:112] TensorFlow version 2.20.0 available. I0423 16:15:09.369881 135432814733120 config.py:125] JAX version 0.8.3 available. E0423 16:15:11.409394 135432814733120 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead. I0423 16:15:11.409618 135432814733120 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform. I0423 16:15:11.412721 135432814733120 train_distill.py:405] Input Pipeline Checkpointing: DISABLED I0423 16:15:11.412789 135432814733120 train_distill.py:409] Reason: Iterator 'MultiHostDataLoadIterator' is not recognized as Grain (dataset_type='DatasetType.HF', has_save=False) I0423 16:15:11.412864 135432814733120 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:15:11.412972 135432814733120 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b2c3131c5f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:15:11.413037 135432814733120 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:15:11.413080 135432814733120 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b2c3131c5f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:15:11.413165 135432814733120 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, 
item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb350>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb290>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472ebb200>}, handler_registry=None I0423 16:15:11.413429 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb350>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:15:11.413485 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb290>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:15:11.413519 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472ebb200>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. 
I0423 16:15:11.413561 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b32cf0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. I0423 16:15:11.413608 135432814733120 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb350>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb350>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb290>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb290>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472ebb200>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472ebb200>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): 
<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b32cf0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b32cf0>}). I0423 16:15:11.414242 135432814733120 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7b1472bd6200> timeout: 600 secs and primary_host=0 for async checkpoint writes I0423 16:15:13.862703 135432814733120 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints I0423 16:15:13.951024 135432814733120 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), 
root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7b1472ebb1d0> I0423 16:15:13.951224 135432814733120 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:15:13.951294 135432814733120 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b2c3131c5f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:15:13.951331 135432814733120 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0423 16:15:13.951362 135432814733120 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b2c3131c5f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0423 16:15:13.951407 135432814733120 checkpoint_manager.py:1983] [process=6][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning. 
I0423 16:15:13.951474 135432814733120 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=135432814733120 count=1 at 0x7b1472ba3540>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7b1472b32b40>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7b1472b32b70>, _write_futures=[]) I0423 16:15:13.951890 135432814733120 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=135432814733120 count=1 at 0x7b1472ba3540>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7b1472b32b40>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7b1472b32b70>, _write_futures=[]) I0423 16:15:13.951925 135432814733120 checkpoint.py:459] Closing _NonBlockingMetadataStore(enable_write=True, _write_lock=<locked _thread.RLock object owner=135432814733120 count=1 at 0x7b1472ba3540>, _store_impl=<orbax.checkpoint._src.metadata.checkpoint._MetadataStoreImpl object at 0x7b1472b32b40>, _single_thread_executor=<concurrent.futures.thread.ThreadPoolExecutor object at 0x7b1472b32b70>, _write_futures=[]) I0423 16:15:13.951977 135432814733120 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb1a0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472b33b30>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b31c70>, 'iter': <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7b1472b31af0>}, handler_registry=None I0423 16:15:13.952144 135432814733120 composite_checkpoint_handler.py:237] Deferred 
registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb1a0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:15:13.952208 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472b33b30>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0423 16:15:13.952252 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b31c70>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. I0423 16:15:13.952299 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "iter". Adding handler `<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7b1472b31af0>` for item "iter" and save args `<class 'maxtext.common.checkpointing.GrainCheckpointSave'>` and restore args `<class 'maxtext.common.checkpointing.GrainCheckpointRestore'>` to `_handler_registry`. I0423 16:15:13.952338 135432814733120 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". 
Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b31850>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. I0423 16:15:13.952379 135432814733120 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb1a0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472ebb1a0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472b33b30>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b1472b33b30>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b31c70>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b31c70>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointSave'>): <maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7b1472b31af0>, ('iter', <class 'maxtext.common.checkpointing.GrainCheckpointRestore'>): 
<maxtext.common.checkpointing.GrainCheckpointHandler object at 0x7b1472b31af0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b31850>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b1472b31850>}). I0423 16:15:13.952489 135432814733120 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7b1472bd6340> timeout: 600 secs and primary_host=0 for async checkpoint writes I0423 16:15:14.655926 135432814733120 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints I0423 16:15:15.093207 135432814733120 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, 
prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_distill_nnx_xpk_feat_nnx_set_defaults_true_20260423_155251_07_distill_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7b1472b332c0>
I0423 16:15:15.093440 135432814733120 train_distill.py:673] Starting Distillation Training...
I0423 16:15:15.093549 135432814733120 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0423 16:15:15.223920 135432814733120 peft_trainer.py:600] Compiled train_step cache size: 0
Training: 0%| | 0/5 [00:00<?, ?step/s]
I0423 16:15:15.225855 135289079047936 grain_pool.py:367] Grain pool will use 1 processes.
I0423 16:15:15.252404 135289079047936 grain_pool.py:440] Grain pool will start child processes.
I0423 16:15:15.257601 135289079047936 grain_pool.py:448] Grain pool started all child processes.
2026-04-23 16:15:21.324429: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 16:15:24.700391 135432814733120 utils.py:86] Train loop finished in: 9.4757 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 675, in train_distill
    trainer.train(train_iter, eval_iter)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)}
I0423 16:15:25.042683 135289079047936 grain_pool.py:542] Grain pool is exiting.
I0423 16:15:25.042787 135289079047936 grain_pool.py:547] Shutting down multiprocessing system.
I0423 16:15:26.507450 135289079047936 grain_pool.py:547] Shutting down multiprocessing system.
Training: 0%| | 0/5 [00:13<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Thu Apr 23 16:15:35 UTC 2026
EXIT_CODE=1