XPK Start: Tue Apr 21 12:01:14 UTC 2026
2026-04-21 12:01:43.436662: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0421 12:01:43.650516 133160812083008 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-21 12:01:52,691:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0421 12:01:52.691835 133160812083008 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-21 12:01:52,694:jax._src.distributed:166: Connecting to JAX distributed service on mt-02-sft-nnx-ckpt-tj6gj-slice-job-0-0.mt-02-sft-nnx-ckpt-tj6gj:8482
I0421 12:01:52.694308 133160812083008 distributed.py:166] Connecting to JAX distributed service on mt-02-sft-nnx-ckpt-tj6gj-slice-job-0-0.mt-02-sft-nnx-ckpt-tj6gj:8482
I0421 12:01:55.608670 133160812083008 max_utils.py:284] Jax distributed system initialized!
I0421 12:02:01.849682 133160812083008 max_utils.py:800] System Information: Jax Version: 0.8.3
I0421 12:02:01.849787 133160812083008 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0421 12:02:01.849827 133160812083008 max_utils.py:802] System Information: Jax Backend: PJRT C API TFRT TPU v6 lite Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0421 12:02:01.853168 133160812083008 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0421 12:02:01.945938 133160812083008 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0421 12:02:02.044238 133160812083008 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0421 12:02:03.083804 133160812083008 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0421 12:02:03.084247 133160812083008 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=True, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x791b351ac140>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0421 12:02:03.084311 133160812083008 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0421 12:02:04.102452 133160812083008 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items/_CHECKPOINT_METADATA
I0421 12:02:04.662363 1950 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0421 12:02:05.878541 133160812083008 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items.
W0421 12:02:06.791819 133160812083008 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0421 12:02:06.792220 133160812083008 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: decoder/decoder_norm/bias/value, decoder/decoder_norm/scale/value, decoder/dropout/rngs/aqt/count/value, decoder/dropout/rngs/aqt/key/value, decoder/dropout/rngs/dropout/count/value, decoder/dropout/rngs/dropout/key/value, decoder/dropout/rngs/params/count/value, decoder/dropout/rngs/params/key/value, decoder/layers/dropout/rngs/aqt/count/value, decoder/layers/dropout/rngs/aqt/key/value, decoder/layers/dropout/rngs/dropout/count/value, decoder/layers/dropout/rngs/dropout/key/value, decoder/layers/dropout/rngs/params/count/value, decoder/layers/dropout/rngs/params/key/value, decoder/layers/mlp/dropout/rngs/aqt/count/value, decoder/layers/mlp/dropout/rngs/aqt/key/value, decoder/layers/mlp/dropout/rngs/dropout/count/value, decoder/layers/mlp/dropout/rngs/dropout/key/value, decoder/layers/mlp/dropout/rngs/params/count/value, decoder/layers/mlp/dropout/rngs/params/key/value, decoder/layers/mlp/mlp_layer_norm/bias/value, decoder/layers/mlp/mlp_layer_norm/scale/value, decoder/layers/mlp/wi/bias/value, decoder/layers/mlp/wi/kernel/value, decoder/layers/mlp/wo/bias/value, decoder/layers/mlp/wo/kernel/value, decoder/layers/pre_self_attention_norm/bias/value, decoder/layers/pre_self_attention_norm/scale/value, decoder/layers/rngs/aqt/count/value, decoder/layers/rngs/aqt/key/value, decoder/layers/rngs/dropout/count/value, decoder/layers/rngs/dropout/key/value, decoder/layers/rngs/params/count/value, decoder/layers/rngs/params/key/value, decoder/layers/self_attention/out/bias/value, decoder/layers/self_attention/out/kernel/value, decoder/layers/self_attention/qkv_proj/bias/value, decoder/layers/self_attention/qkv_proj/kernel/value, decoder/position_embedder/embedding/value, decoder/rngs/aqt/count/value, decoder/rngs/aqt/key/value, decoder/rngs/dropout/count/value, decoder/rngs/dropout/key/value, 
decoder/rngs/params/count/value, decoder/rngs/params/key/value, token_embedder/embedding/value I0421 12:02:08.118224 133160812083008 checkpointer.py:318] Finished restoring checkpoint in 2.61 seconds from gs://lance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items. I0421 12:02:08.187427 133160812083008 config.py:112] TensorFlow version 2.20.0 available. I0421 12:02:08.187973 133160812083008 config.py:125] JAX version 0.8.3 available. /deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading. Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. Seegithub.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice warnings.warn( E0421 12:02:13.792789 133160812083008 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead. I0421 12:02:13.793010 133160812083008 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform. 
I0421 12:02:14.181520 133160812083008 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0421 12:02:14.181685 133160812083008 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x791b351ac140>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0421 12:02:14.181732 133160812083008 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None I0421 12:02:14.181769 133160812083008 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x791b351ac140>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB) I0421 12:02:14.181812 133160812083008 checkpoint_manager.py:702] [process=4][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d4c0>, 'optimizer_state': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d7f0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79032090d700>}, handler_registry=None I0421 12:02:14.182019 133160812083008 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". 
Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d4c0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0421 12:02:14.182062 133160812083008 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d7f0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`. I0421 12:02:14.182090 133160812083008 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79032090d700>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. I0421 12:02:14.182114 133160812083008 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x790320914050>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`. 
I0421 12:02:14.182142 133160812083008 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d4c0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d4c0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d7f0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x79032090d7f0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79032090d700>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79032090d700>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x790320914050>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x790320914050>}). 
I0421 12:02:14.183660 133160812083008 async_checkpointer.py:177] [process=4][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x790320a282c0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0421 12:02:16.583933 133160812083008 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260421_114106/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260421_114106_02_sft_nnx_ckpt/checkpoints
I0421 12:02:17.018330 133160812083008 checkpoint_manager.py:921] [process=4][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False), root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260421_114106/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260421_114106_02_sft_nnx_ckpt/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x79032090da00>
I0421 12:02:17.018725 133160812083008 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0421 12:02:17.438088 133160812083008 peft_trainer.py:594] Compiled train_step cache size: 0
I0421 12:02:17.442322 133160812083008 metric_logger.py:301] number parameters: 0.000 billion
I0421 12:02:17.444804 133006949873408 grain_pool.py:367] Grain pool will use 1 processes.
I0421 12:02:17.470864 133006949873408 grain_pool.py:440] Grain pool will start child processes.
Per train step: Total TFLOPs: 0.00 split as 54.29% learnable weight flops and 45.71% attention flops
I0421 12:02:17.476335 133006949873408 grain_pool.py:448] Grain pool started all child processes.
2026-04-21 12:02:21.476458: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-21 12:02:21.521394: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-21 12:02:22.682172: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-21 12:02:26.507153: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 281, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 277, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 254, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 240, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 692, in train
    train_loss, aux, grad_norm = train_step(train_example)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: not enough values to unpack (expected 3, got 2)
I0421 12:02:35.521730 133006949873408 grain_pool.py:542] Grain pool is exiting.
I0421 12:02:35.521841 133006949873408 grain_pool.py:547] Shutting down multiprocessing system.
I0421 12:02:38.697963 133006949873408 grain_pool.py:547] Shutting down multiprocessing system.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Tue Apr 21 12:02:50 UTC 2026
EXIT_CODE=1