MaxView

Log Summary

XPK Start: Thu Apr 23 16:09:16 UTC 2026
2026-04-23 16:09:44.837755: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 16:09:45.061774 134150966765376 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-23 16:09:54,104:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0423 16:09:54.104011 134150966765376 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-23 16:09:54,106:jax._src.distributed:166: Connecting to JAX distributed service on mt-02-sft-nnx-ckpt-kknlj-slice-job-0-0.mt-02-sft-nnx-ckpt-kknlj:8482
I0423 16:09:54.106371 134150966765376 distributed.py:166] Connecting to JAX distributed service on mt-02-sft-nnx-ckpt-kknlj-slice-job-0-0.mt-02-sft-nnx-ckpt-kknlj:8482
I0423 16:09:55.871679 134150966765376 max_utils.py:284] Jax distributed system initialized!
I0423 16:10:01.989294 134150966765376 max_utils.py:800] System Information: Jax Version: 0.8.3
I0423 16:10:01.989401 134150966765376 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0423 16:10:01.989442 134150966765376 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0423 16:10:01.992839 134150966765376 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 16:10:02.087722 134150966765376 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 16:10:02.188205 134150966765376 maxtext_utils.py:1631] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 16:10:03.236536 134150966765376 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 16:10:03.236992 134150966765376 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a01bf0f2270>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 16:10:03.237052 134150966765376 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0423 16:10:03.812139 134150966765376 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items/_CHECKPOINT_METADATA
I0423 16:10:04.371950    1932 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0423 16:10:05.603973 134150966765376 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items.
W0423 16:10:06.537312 134150966765376 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0423 16:10:06.537716 134150966765376 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: decoder/decoder_norm/bias/value, decoder/decoder_norm/scale/value, decoder/dropout/rngs/aqt/count/value, decoder/dropout/rngs/aqt/key/value, decoder/dropout/rngs/dropout/count/value, decoder/dropout/rngs/dropout/key/value, decoder/dropout/rngs/params/count/value, decoder/dropout/rngs/params/key/value, decoder/layers/dropout/rngs/aqt/count/value, decoder/layers/dropout/rngs/aqt/key/value, decoder/layers/dropout/rngs/dropout/count/value, decoder/layers/dropout/rngs/dropout/key/value, decoder/layers/dropout/rngs/params/count/value, decoder/layers/dropout/rngs/params/key/value, decoder/layers/mlp/dropout/rngs/aqt/count/value, decoder/layers/mlp/dropout/rngs/aqt/key/value, decoder/layers/mlp/dropout/rngs/dropout/count/value, decoder/layers/mlp/dropout/rngs/dropout/key/value, decoder/layers/mlp/dropout/rngs/params/count/value, decoder/layers/mlp/dropout/rngs/params/key/value, decoder/layers/mlp/mlp_layer_norm/bias/value, decoder/layers/mlp/mlp_layer_norm/scale/value, decoder/layers/mlp/wi/bias/value, decoder/layers/mlp/wi/kernel/value, decoder/layers/mlp/wo/bias/value, decoder/layers/mlp/wo/kernel/value, decoder/layers/pre_self_attention_norm/bias/value, decoder/layers/pre_self_attention_norm/scale/value, decoder/layers/rngs/aqt/count/value, decoder/layers/rngs/aqt/key/value, decoder/layers/rngs/dropout/count/value, decoder/layers/rngs/dropout/key/value, decoder/layers/rngs/params/count/value, decoder/layers/rngs/params/key/value, decoder/layers/self_attention/out/bias/value, decoder/layers/self_attention/out/kernel/value, decoder/layers/self_attention/qkv_proj/bias/value, decoder/layers/self_attention/qkv_proj/kernel/value, decoder/position_embedder/embedding/value, decoder/rngs/aqt/count/value, decoder/rngs/aqt/key/value, decoder/rngs/dropout/count/value, decoder/rngs/dropout/key/value, 
decoder/rngs/params/count/value, decoder/rngs/params/key/value, token_embedder/embedding/value
I0423 16:10:07.035330 134150966765376 checkpointer.py:318] Finished restoring checkpoint in 1.80 seconds from gs://lance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items.
I0423 16:10:07.103669 134150966765376 config.py:112] TensorFlow version 2.20.0 available.
I0423 16:10:07.104200 134150966765376 config.py:125] JAX version 0.8.3 available.
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. Seegithub.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0423 16:10:12.610219 134150966765376 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0423 16:10:12.610439 134150966765376 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0423 16:10:13.001979 134150966765376 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 16:10:13.002135 134150966765376 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a01bf0f2270>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 16:10:13.002182 134150966765376 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 16:10:13.002218 134150966765376 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7a01bf0f2270>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 16:10:13.002261 134150966765376 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b366c60>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b8f69c0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79e95a52d580>}, handler_registry=None
I0423 16:10:13.002474 134150966765376 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b366c60>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 16:10:13.002518 134150966765376 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b8f69c0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 16:10:13.002551 134150966765376 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79e95a52d580>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 16:10:13.002578 134150966765376 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79e820569dc0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 16:10:13.002607 134150966765376 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b366c60>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b366c60>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b8f69c0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79e95b8f69c0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79e95a52d580>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79e95a52d580>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79e820569dc0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79e820569dc0>}).
I0423 16:10:13.004119 134150966765376 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x79e95a54aca0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0423 16:10:15.418348 134150966765376 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_sft_nnx_xpk_feat_nnx_set_defaults_true_20260423_155251_02_sft_nnx_ckpt/checkpoints
I0423 16:10:15.442225 134150966765376 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_set_defaults_true_20260423_155251/pt_sft_nnx_xpk_feat_nnx_set_defaults_true_20260423_155251_02_sft_nnx_ckpt/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x79e95a52f0b0>
I0423 16:10:15.442515 134150966765376 peft_trainer.py:590] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0423 16:10:15.859174 134150966765376 peft_trainer.py:600] Compiled train_step cache size: 0

Training:   0%|          | 0/5 [00:00<?, ?step/s]
I0423 16:10:15.863509 134150966765376 metric_logger.py:289] number parameters: 0.000 billion
I0423 16:10:15.865889 133998340728576 grain_pool.py:367] Grain pool will use 1 processes.
I0423 16:10:15.892387 133998340728576 grain_pool.py:440] Grain pool will start child processes.
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
I0423 16:10:15.897648 133998340728576 grain_pool.py:448] Grain pool started all child processes.
2026-04-23 16:10:19.836998: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-23 16:10:19.881623: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-23 16:10:20.872807: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-23 16:10:25.105339: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 16:10:31.026971 134150966765376 utils.py:86] Train loop finished in: 15.1624 seconds
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 216, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 212, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 189, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 175, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 659, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 986, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1048, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 845, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2729, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 558, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 545, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 487, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's first argument must be a fully addressable array, but got value with devices {TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=31, 
process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0)}
I0423 16:10:31.373426 133998340728576 grain_pool.py:542] Grain pool is exiting.
I0423 16:10:31.373527 133998340728576 grain_pool.py:547] Shutting down multiprocessing system.
I0423 16:10:37.250508 133998340728576 grain_pool.py:547] Shutting down multiprocessing system.

Training:   0%|          | 0/5 [00:25<?, ?step/s]
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 20 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Thu Apr 23 16:10:47 UTC 2026
EXIT_CODE=1