XPK Start: Sat Apr 25 07:18:40 UTC 2026
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config.
2026-04-25 07:19:11.933208: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0425 07:19:12.133394 136528987920192 max_utils.py:273] Attempting to initialize the jax distributed system...
I0425 07:19:21.170868 136528987920192 distributed.py:149] Starting JAX distributed service on [::]:8482
I0425 07:19:21.173221 136528987920192 distributed.py:172] Connecting to JAX distributed service on mt-02-sft-linen-ckpt-ycw6a-slice-job-0-0.mt-02-sft-linen-ckpt-ycw6a:8482
I0425 07:19:22.495081 136528987920192 max_utils.py:284] Jax distributed system initialized!
I0425 07:19:27.585396 136528987920192 max_utils.py:800] System Information: Jax Version: 0.9.2
I0425 07:19:27.585501 136528987920192 max_utils.py:801] System Information: Jaxlib Version: 0.9.2
I0425 07:19:27.585542 136528987920192 max_utils.py:802] System Information: Jax Backend: PJRT C API TFRT TPU v6 lite Built on Apr 6 2026 20:48:10 (1775533690) cl/895581894
I0425 07:19:27.588973 136528987920192 maxtext_utils.py:1604] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0425 07:19:28.259395 136528987920192 maxtext_utils.py:1604] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0425 07:19:29.397443 136528987920192 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 07:19:29.397956 136528987920192 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c2b6b59d460>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 07:19:29.398036 136528987920192 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0425 07:19:29.897495 136528987920192 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items/_CHECKPOINT_METADATA
I0425 07:19:30.557339 1906 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0425 07:19:31.841765 136528987920192 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items.
W0425 07:19:34.105015 136528987920192 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
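The `rope_parameters` messages above come from config validation that expects float-valued fields but receives the integers 40, 32, and 1. A minimal, hypothetical fix-up that coerces those fields before the config object is built might look like the sketch below; the config path and key layout are assumptions based only on the log text, not this job's actual setup.

    import json

    # Hypothetical pre-processing: cast the integer RoPE fields the validator
    # rejects ("factor ... got 40", "beta_fast ... got 32", "beta_slow ... got 1")
    # to floats. "config.json" is a placeholder path for the HF model config.
    with open("config.json") as f:
        cfg = json.load(f)

    rope = cfg.get("rope_scaling") or cfg.get("rope_parameters") or {}
    for field in ("factor", "beta_fast", "beta_slow"):
        if field in rope:
            rope[field] = float(rope[field])  # 40 -> 40.0, 32 -> 32.0, 1 -> 1.0

    # The DeepseekV32Config warning also asks for the standardized key name.
    cfg["rope_parameters"] = rope

    with open("config.json", "w") as f:
        json.dump(cfg, f, indent=2)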
I0425 07:19:34.105387 136528987920192 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/to_nnx__rngs/aqt/count, params/params/decoder/to_nnx__rngs/aqt/key, params/params/decoder/to_nnx__rngs/dropout/count, params/params/decoder/to_nnx__rngs/dropout/key, params/params/decoder/to_nnx__rngs/params/count, params/params/decoder/to_nnx__rngs/params/key
I0425 07:19:34.448001 136528987920192 checkpointer.py:318] Finished restoring checkpoint in 2.99 seconds from gs://lance-maxtext/pt_seed_ckpts/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items.
I0425 07:19:34.519785 136528987920192 config.py:112] TensorFlow version 2.20.0 available.
I0425 07:19:34.520316 136528987920192 config.py:125] JAX version 0.9.2 available.
I0425 07:19:34.958255 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/README.md "HTTP/1.1 307 Temporary Redirect"
I0425 07:19:34.967135 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0425 07:19:34.985082 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/api/resolve-cache/datasets/HuggingFaceH4/ultrachat_200k/8049631c405ae6576f93f445c6b8166f76f5505a/README.md "HTTP/1.1 200 OK"
I0425 07:19:35.092925 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0425 07:19:35.392481 136528987920192 _client.py:1025] HTTP Request: HEAD https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets/HuggingFaceH4/ultrachat_200k/HuggingFaceH4/ultrachat_200k.py "HTTP/1.1 404 Not Found"
I0425 07:19:35.504070 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/revision/8049631c405ae6576f93f445c6b8166f76f5505a "HTTP/1.1 200 OK"
I0425 07:19:35.613229 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/.huggingface.yaml "HTTP/1.1 404 Not Found"
I0425 07:19:35.773933 136528987920192 _client.py:1025] HTTP Request: GET https://datasets-server.huggingface.co/info?dataset=HuggingFaceH4/ultrachat_200k "HTTP/1.1 200 OK"
I0425 07:19:35.881526 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a/data?recursive=true&expand=false "HTTP/1.1 200 OK"
I0425 07:19:35.986166 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/HuggingFaceH4/ultrachat_200k/tree/8049631c405ae6576f93f445c6b8166f76f5505a?recursive=false&expand=false "HTTP/1.1 200 OK"
I0425 07:19:36.099577 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/8049631c405ae6576f93f445c6b8166f76f5505a/dataset_infos.json "HTTP/1.1 404 Not Found"
I0425 07:19:36.284096 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0425 07:19:36.391464 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json "HTTP/1.1 200 OK"
I0425 07:19:36.500261 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0425 07:19:36.613474 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0425 07:19:36.726759 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Found"
I0425 07:19:36.832251 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/tree/main?recursive=true&expand=false "HTTP/1.1 200 OK"
I0425 07:19:36.946741 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model "HTTP/1.1 302 Found"
I0425 07:19:37.059099 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-2-7b-chat-hf/xet-read-token/f5db02db724555f92da89c216ac04704f23d4590 "HTTP/1.1 200 OK"
I0425 07:19:37.699824 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0425 07:19:37.821932 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0425 07:19:38.237968 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/added_tokens.json "HTTP/1.1 404 Not Found"
I0425 07:19:38.349588 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0425 07:19:38.459389 136528987920192 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0425 07:19:38.566802 136528987920192 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/chat_template.jinja "HTTP/1.1 404 Not Found"
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0425 07:19:38.671681 136528987920192 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0425 07:19:38.671885 136528987920192 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
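The UserWarning above fires when the dataset exposes fewer file shards (3 here) than there are hosts loading data. As a generic illustration only (not necessarily how this pipeline consumes its input files), the Hugging Face `datasets` library can materialize at least one shard per host; the split name and output file names below are assumptions.

    import jax
    from datasets import load_dataset

    # Assumed goal: re-shard the split into >= one file per data-loading host.
    num_hosts = jax.process_count()
    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    for i in range(num_hosts):
        shard = ds.shard(num_shards=num_hosts, index=i, contiguous=True)
        shard.to_parquet(f"ultrachat_train_sft-{i:05d}-of-{num_hosts:05d}.parquet")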
I0425 07:19:39.086899 136528987920192 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 07:19:39.087044 136528987920192 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c2b6b59d460>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 07:19:39.087090 136528987920192 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0425 07:19:39.087125 136528987920192 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7c2b6b59d460>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0425 07:19:39.087167 136528987920192 checkpoint_manager.py:702] [process=5][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c29fc41a720>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c12802ba930>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c20921467b0>}, handler_registry=None
I0425 07:19:39.087372 136528987920192 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c29fc41a720>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 07:19:39.087414 136528987920192 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c12802ba930>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0425 07:19:39.087441 136528987920192 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c20921467b0>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 07:19:39.087469 136528987920192 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c2092149d30>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0425 07:19:39.087498 136528987920192 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c29fc41a720>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c29fc41a720>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c12802ba930>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7c12802ba930>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c20921467b0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c20921467b0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c2092149d30>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c2092149d30>}).
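The handler registrations above amount to an Orbax CheckpointManager with one handler per item. Purely for orientation, a rough sketch against the public orbax.checkpoint API, mirroring the item names from the log (model_params, optimizer_state, custom_metadata) and the options echoed a few lines below, could look as follows; the GCS path is a placeholder and this is not the job's actual construction code.

    import orbax.checkpoint as ocp

    # Options roughly matching the CheckpointManagerOptions dump below.
    options = ocp.CheckpointManagerOptions(
        save_interval_steps=10000,
        enable_async_checkpointing=True,
        create=True,
    )
    manager = ocp.CheckpointManager(
        "gs://my-bucket/checkpoints",  # placeholder directory
        item_handlers={
            "model_params": ocp.PyTreeCheckpointHandler(),
            "optimizer_state": ocp.PyTreeCheckpointHandler(),
            "custom_metadata": ocp.JsonCheckpointHandler(),
        },
        options=options,
    )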
I0425 07:19:39.088493 136528987920192 async_checkpointer.py:177] [process=5][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x7c0dbc103920> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0425 07:19:41.844458 136528987920192 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_main_20260425_071506/pt_sft_linen_xpk_main_20260425_071506_02_sft_linen_ckpt/checkpoints
I0425 07:19:42.328232 136528987920192 checkpoint_manager.py:921] [process=5][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_main_20260425_071506/pt_sft_linen_xpk_main_20260425_071506_02_sft_linen_ckpt/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7c2092147740>
I0425 07:19:42.328609 136528987920192 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0425 07:19:42.786882 136528987920192 peft_trainer.py:594] Compiled train_step cache size: 0
I0425 07:19:42.788960 136528987920192 metric_logger.py:301] number parameters: 0.000 billion
I0425 07:19:42.791187 136359884867328 grain_pool.py:367] Grain pool will use 1 processes.
I0425 07:19:42.841737 136359884867328 grain_pool.py:440] Grain pool will start child processes.
Per train step: Total TFLOPs: 0.00 split as 54.29% learnable weight flops and 45.71% attention flops
I0425 07:19:42.847474 136359884867328 grain_pool.py:448] Grain pool started all child processes.
2026-04-25 07:19:47.020041: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-25 07:19:47.064709: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-25 07:19:48.236395: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
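The trainer mesh logged above is, in effect, a 4-way 'data' by 8-way 'fsdp' layout over the 32 TPU devices; every other axis has size 1. A simplified sketch of building an equivalent mesh with the public JAX API (ignoring the size-1 axes, so not the trainer's literal construction) looks like this:

    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Simplified 2-D stand-in for the logged 13-axis mesh: 4-way data, 8-way fsdp.
    devices = mesh_utils.create_device_mesh((4, 8))
    mesh = Mesh(devices, axis_names=("data", "fsdp"))

    # Example shardings: split the batch over 'data', shard weights over 'fsdp'.
    batch_sharding = NamedSharding(mesh, P("data"))
    weight_sharding = NamedSharding(mesh, P("fsdp"))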
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config.
2026-04-25 07:19:53.987972: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 217, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 213, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 190, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 176, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 652, in train
    train_example = sharding_utils.shard_input(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 58, in shard_input
    return jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 156, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 373, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/sharding_utils.py", line 59, in <lambda>
    lambda x: jax.make_array_from_process_local_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 985, in make_array_from_process_local_data
    out = [_array_from_process_local_data(data, s, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 1047, in _array_from_process_local_data
    return make_array_from_callback(global_shape, sharding, cb)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/array.py", line 844, in make_array_from_callback
    per_device_values = api.device_put(per_device_values, devices)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/api.py", line 2732, in device_put
    out_flat = dispatch._batched_device_put_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 602, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 582, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/dispatch.py", line 512, in _device_put_sharding_impl
    raise ValueError(
ValueError: When the second argument to `device_put` is a Device, the first argument must be a fully addressable array or a non-addressable array with a single device sharding. Got value with devices {TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=24, process_index=6, coords=(0,6,0), core_on_chip=0), TpuDevice(id=19, process_index=5, coords=(3,4,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=29, process_index=6, coords=(1,7,0), core_on_chip=0), TpuDevice(id=31, process_index=7, coords=(3,7,0), core_on_chip=0), TpuDevice(id=20, process_index=4, coords=(0,5,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=26, process_index=7, coords=(2,6,0), core_on_chip=0), TpuDevice(id=21, process_index=4, coords=(1,5,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=22, process_index=5, coords=(2,5,0), core_on_chip=0), TpuDevice(id=17, process_index=4, coords=(1,4,0), core_on_chip=0), TpuDevice(id=18, process_index=5, coords=(2,4,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=28, process_index=6, coords=(0,7,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=30, process_index=7, coords=(2,7,0), core_on_chip=0), TpuDevice(id=16, process_index=4, coords=(0,4,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=25, process_index=6, coords=(1,6,0), core_on_chip=0), TpuDevice(id=27, process_index=7, coords=(3,6,0), core_on_chip=0), TpuDevice(id=23, process_index=5, coords=(3,5,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0)}
I0425 07:19:59.318971 136359884867328 grain_pool.py:542] Grain pool is exiting.
I0425 07:19:59.319071 136359884867328 grain_pool.py:547] Shutting down multiprocessing system.
I0425 07:20:05.112967 136359884867328 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Sat Apr 25 07:20:15 UTC 2026
EXIT_CODE=1