MaxView

Log Summary

XPK Start: Thu Apr 23 07:26:48 UTC 2026
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config. 
2026-04-23 07:27:47.988540: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0423 07:27:48.188416 135604324120384 max_utils.py:273] Attempting to initialize the jax distributed system...
I0423 07:27:57.227900 135604324120384 distributed.py:149] Starting JAX distributed service on [::]:8482
I0423 07:27:57.230302 135604324120384 distributed.py:172] Connecting to JAX distributed service on mt-02-sft-linen-ckpt-fa340-slice-job-0-0.mt-02-sft-linen-ckpt-fa340:8482
I0423 07:28:14.351838 135604324120384 max_utils.py:284] Jax distributed system initialized!
I0423 07:28:20.550373 135604324120384 max_utils.py:800] System Information: Jax Version: 0.9.2
I0423 07:28:20.550479 135604324120384 max_utils.py:801] System Information: Jaxlib Version: 0.9.2
I0423 07:28:20.550517 135604324120384 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Apr 6 2026 20:48:10 (1775533690) cl/895581894
I0423 07:28:20.554246 135604324120384 maxtext_utils.py:1604] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 07:28:21.211314 135604324120384 maxtext_utils.py:1604] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0423 07:28:22.318448 135604324120384 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 07:28:22.318955 135604324120384 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b54213751f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 07:28:22.319013 135604324120384 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0423 07:28:22.824115 135604324120384 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items/_CHECKPOINT_METADATA
I0423 07:28:23.409421    1948 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0423 07:28:24.314723 135604324120384 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items.
W0423 07:28:26.634971 135604324120384 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0423 07:28:26.635327 135604324120384 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/to_nnx__rngs/aqt/count, params/params/decoder/to_nnx__rngs/aqt/key, params/params/decoder/to_nnx__rngs/dropout/count, params/params/decoder/to_nnx__rngs/dropout/key, params/params/decoder/to_nnx__rngs/params/count, params/params/decoder/to_nnx__rngs/params/key
I0423 07:28:28.186471 135604324120384 checkpointer.py:318] Finished restoring checkpoint in 4.23 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items.
I0423 07:28:28.258221 135604324120384 config.py:112] TensorFlow version 2.20.0 available.
I0423 07:28:28.258726 135604324120384 config.py:125] JAX version 0.9.2 available.
I0423 07:28:28.691691 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/README.md "HTTP/1.1 307 Temporary Redirect"
I0423 07:28:28.699918 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0423 07:28:28.708855 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0423 07:28:28.823413 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0423 07:28:29.137765 135604324120384 _client.py:1025] HTTP Request: HEAD https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets/HuggingFaceH4/ultrachat_200k/HuggingFaceH4/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0423 07:28:29.254178 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/revision/8049631c405ae6576f93f445c6b8166f76f5505a "HTTP/1.1 200 OK"
I0423 07:28:29.357612 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/.huggingface.yaml "HTTP/1.1 404 Not Found"
I0423 07:28:29.512719 135604324120384 _client.py:1025] HTTP Request: GET https://datasets-server.huggingface.co/info?dataset=HuggingFaceH4/ultrachat_200k "HTTP/1.1 200 OK"
I0423 07:28:29.619787 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a/data?recursive=true&expand=false "HTTP/1.1 200 OK"
I0423 07:28:29.738644 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a?recursive=false&expand=false "HTTP/1.1 200 OK"
I0423 07:28:29.847123 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/dataset_infos.json "HTTP/1.1 404 Not Found"
I0423 07:28:30.048824 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0423 07:28:30.158365 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0423 07:28:30.279623 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0423 07:28:30.387032 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0423 07:28:30.511921 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Found"
I0423 07:28:30.612292 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main?recursive=true&expand=false "HTTP/1.1 200 OK"
I0423 07:28:30.722516 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model "HTTP/1.1 302 Found"
I0423 07:28:30.835324 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/xet-read-token/f5db02db724555f92da89c216ac04704f23d4590 "HTTP/1.1 200 OK"
I0423 07:28:31.475312 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0423 07:28:31.586866 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0423 07:28:31.963258 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/added_tokens.json "HTTP/1.1 404 Not Found"
I0423 07:28:32.072515 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0423 07:28:32.196173 135604324120384 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0423 07:28:32.304965 135604324120384 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/chat_template.jinja "HTTP/1.1 404 Not Found"
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. Seegithub.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0423 07:28:32.409493 135604324120384 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0423 07:28:32.409708 135604324120384 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0423 07:28:32.828326 135604324120384 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 07:28:32.828465 135604324120384 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b54213751f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 07:28:32.828512 135604324120384 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0423 07:28:32.828547 135604324120384 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7b54213751f0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0423 07:28:32.828592 135604324120384 checkpoint_manager.py:702] [process=5][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b52b1ea41a0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b4923c7df40>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b4923c87800>}, handler_registry=None
I0423 07:28:32.828808 135604324120384 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b52b1ea41a0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 07:28:32.828865 135604324120384 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b4923c7df40>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0423 07:28:32.828897 135604324120384 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b4923c87800>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 07:28:32.828924 135604324120384 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b52b1e8e780>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0423 07:28:32.829049 135604324120384 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b52b1ea41a0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b52b1ea41a0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b4923c7df40>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7b4923c7df40>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b4923c87800>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b4923c87800>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b52b1e8e780>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7b52b1e8e780>}).
I0423 07:28:32.830072 135604324120384 async_checkpointer.py:177] [process=5][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7b3ce558f920> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0423 07:28:35.386541 135604324120384 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_main_20260423_071551/pt_sft_linen_xpk_main_20260423_071551_02_sft_linen_ckpt/checkpoints
I0423 07:28:35.799897 135604324120384 checkpoint_manager.py:921] [process=5][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_main_20260423_071551/pt_sft_linen_xpk_main_20260423_071551_02_sft_linen_ckpt/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7b4923c87650>
I0423 07:28:35.800261 135604324120384 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0423 07:28:36.256852 135604324120384 peft_trainer.py:594] Compiled train_step cache size: 0
I0423 07:28:36.258992 135604324120384 metric_logger.py:301] number parameters: 0.000 billion
I0423 07:28:36.261217 135428125419264 grain_pool.py:367] Grain pool will use 1 processes.
I0423 07:28:36.311390 135428125419264 grain_pool.py:440] Grain pool will start child processes.
Per train step:
 Total TFLOPs: 0.00 
 split as 54.29% learnable weight flops and 45.71% attention flops
I0423 07:28:36.317158 135428125419264 grain_pool.py:448] Grain pool started all child processes.
2026-04-23 07:28:40.488742: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-23 07:28:40.533013: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-23 07:28:41.703015: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config. 
2026-04-23 07:28:47.384569: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 217, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 213, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 190, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 176, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 652, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 156, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 985, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1047, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 844, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2732, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 602, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 582, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 512, in _device_put_sharding_impl
    raise ValueError(
ValueError: When the second argument to `device_put` is a Device, the first argument must be a fully addressable array or a non-addressable array with a single device sharding. Got value with devices {TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0)}
I0423 07:28:52.735937 135428125419264 grain_pool.py:542] Grain pool is exiting.
I0423 07:28:52.736051 135428125419264 grain_pool.py:547] Shutting down multiprocessing system.
I0423 07:28:58.513347 135428125419264 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Thu Apr 23 07:29:08 UTC 2026
EXIT_CODE=1