MaxView

Log Summary

XPK Start: Fri Apr 24 07:09:54 UTC 2026
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config. 
2026-04-24 07:10:27.170313: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0424 07:10:27.370136 132122617874240 max_utils.py:273] Attempting to initialize the jax distributed system...
I0424 07:10:36.407311 132122617874240 distributed.py:149] Starting JAX distributed service on [::]:8482
I0424 07:10:36.409664 132122617874240 distributed.py:172] Connecting to JAX distributed service on mt-02-sft-linen-ckpt-a4ux3-slice-job-0-0.mt-02-sft-linen-ckpt-a4ux3:8482
I0424 07:10:36.880113 132122617874240 max_utils.py:284] Jax distributed system initialized!
I0424 07:10:43.140167 132122617874240 max_utils.py:800] System Information: Jax Version: 0.9.2
I0424 07:10:43.140271 132122617874240 max_utils.py:801] System Information: Jaxlib Version: 0.9.2
I0424 07:10:43.140310 132122617874240 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Apr 6 2026 20:48:10 (1775533690) cl/895581894
I0424 07:10:43.143757 132122617874240 maxtext_utils.py:1604] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 07:10:43.804006 132122617874240 maxtext_utils.py:1604] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 07:10:44.919630 132122617874240 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 07:10:44.920157 132122617874240 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x78297b67c1d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 07:10:44.920221 132122617874240 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0424 07:10:45.467833 132122617874240 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items/_CHECKPOINT_METADATA
I0424 07:10:46.105762    1939 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0424 07:10:47.252435 132122617874240 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items.
W0424 07:10:49.080025 132122617874240 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0424 07:10:49.080400 132122617874240 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/to_nnx__rngs/aqt/count, params/params/decoder/to_nnx__rngs/aqt/key, params/params/decoder/to_nnx__rngs/dropout/count, params/params/decoder/to_nnx__rngs/dropout/key, params/params/decoder/to_nnx__rngs/params/count, params/params/decoder/to_nnx__rngs/params/key
I0424 07:10:49.597837 132122617874240 checkpointer.py:318] Finished restoring checkpoint in 2.71 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items.
I0424 07:10:49.669801 132122617874240 config.py:112] TensorFlow version 2.20.0 available.
I0424 07:10:49.670294 132122617874240 config.py:125] JAX version 0.9.2 available.
I0424 07:10:50.092720 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/README.md "HTTP/1.1 307 Temporary Redirect"
I0424 07:10:50.101165 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0424 07:10:50.110310 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0424 07:10:50.221190 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0424 07:10:50.524397 132122617874240 _client.py:1025] HTTP Request: HEAD https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets/HuggingFaceH4/ultrachat_200k/HuggingFaceH4/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0424 07:10:50.635056 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/revision/8049631c405ae6576f93f445c6b8166f76f5505a "HTTP/1.1 200 OK"
I0424 07:10:50.745300 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/.huggingface.yaml "HTTP/1.1 404 Not Found"
I0424 07:10:50.908565 132122617874240 _client.py:1025] HTTP Request: GET https://datasets-server.huggingface.co/info?dataset=HuggingFaceH4/ultrachat_200k "HTTP/1.1 200 OK"
I0424 07:10:51.016268 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a/data?recursive=true&expand=false "HTTP/1.1 200 OK"
I0424 07:10:51.122364 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a?recursive=false&expand=false "HTTP/1.1 200 OK"
I0424 07:10:51.245379 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/dataset_infos.json "HTTP/1.1 404 Not Found"
I0424 07:10:51.425088 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0424 07:10:51.536775 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0424 07:10:51.646545 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0424 07:10:51.752075 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0424 07:10:51.862768 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Found"
I0424 07:10:51.967383 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main?recursive=true&expand=false "HTTP/1.1 200 OK"
I0424 07:10:52.081294 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model "HTTP/1.1 302 Found"
I0424 07:10:52.193705 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/xet-read-token/f5db02db724555f92da89c216ac04704f23d4590 "HTTP/1.1 200 OK"
I0424 07:10:52.849299 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0424 07:10:52.958084 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0424 07:10:53.291311 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/added_tokens.json "HTTP/1.1 404 Not Found"
I0424 07:10:53.399936 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0424 07:10:53.505913 132122617874240 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0424 07:10:53.614464 132122617874240 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/chat_template.jinja "HTTP/1.1 404 Not Found"
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. Seegithub.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0424 07:10:53.717719 132122617874240 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0424 07:10:53.717928 132122617874240 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0424 07:10:54.133581 132122617874240 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 07:10:54.133731 132122617874240 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x78297b67c1d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 07:10:54.133778 132122617874240 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 07:10:54.133811 132122617874240 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x78297b67c1d0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 07:10:54.133853 132122617874240 checkpoint_manager.py:702] [process=5][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4147680>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4144a40>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x781ea414e6f0>}, handler_registry=None
I0424 07:10:54.134061 132122617874240 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4147680>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 07:10:54.134104 132122617874240 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4144a40>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0424 07:10:54.134135 132122617874240 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x781ea414e6f0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 07:10:54.134161 132122617874240 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x781ea4151ee0>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0424 07:10:54.134188 132122617874240 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4147680>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4147680>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4144a40>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x781ea4144a40>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x781ea414e6f0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x781ea414e6f0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x781ea4151ee0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x781ea4151ee0>}).
I0424 07:10:54.135199 132122617874240 async_checkpointer.py:177] [process=5][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x780fa07cf920> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0424 07:10:56.471759 132122617874240 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_sft_linen_xpk_main_20260424_070237_02_sft_linen_ckpt/checkpoints
I0424 07:10:56.483196 132122617874240 checkpoint_manager.py:921] [process=5][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_sft_linen_xpk_main_20260424_070237_02_sft_linen_ckpt/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x781ea414cd40>
I0424 07:10:56.483464 132122617874240 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0424 07:10:56.944943 132122617874240 peft_trainer.py:594] Compiled train_step cache size: 0
I0424 07:10:56.947030 132122617874240 metric_logger.py:301] number parameters: 0.000 billion
I0424 07:10:56.949223 131946427463424 grain_pool.py:367] Grain pool will use 1 processes.
I0424 07:10:56.999639 131946427463424 grain_pool.py:440] Grain pool will start child processes.
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
I0424 07:10:57.005520 131946427463424 grain_pool.py:448] Grain pool started all child processes.
2026-04-24 07:11:01.161669: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-24 07:11:01.206159: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-24 07:11:02.373361: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config. 
2026-04-24 07:11:08.102586: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 217, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 213, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 190, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 176, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 652, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 156, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 985, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1047, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 844, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2732, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 602, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 582, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 512, in _device_put_sharding_impl
    raise ValueError(
ValueError: When the second argument to `device_put` is a Device, the first argument must be a fully addressable array or a non-addressable array with a single device sharding. Got value with devices {TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0)}
I0424 07:11:13.445631 131946427463424 grain_pool.py:542] Grain pool is exiting.
I0424 07:11:13.445786 131946427463424 grain_pool.py:547] Shutting down multiprocessing system.
I0424 07:11:19.167035 131946427463424 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Fri Apr 24 07:11:30 UTC 2026
EXIT_CODE=1