XPK Start: Wed Apr 22 12:57:23 UTC 2026
2026-04-22 12:57:51.780689: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0422 12:57:51.998446 133422147196736 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-22 12:58:01,039:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0422 12:58:01.039395 133422147196736 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-22 12:58:01,041:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-031k7-slice-job-0-0.mt-01-sft-smoke-031k7:8482
I0422 12:58:01.041793 133422147196736 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-031k7-slice-job-0-0.mt-01-sft-smoke-031k7:8482
I0422 12:58:02.788225 133422147196736 max_utils.py:284] Jax distributed system initialized!
I0422 12:58:09.040838 133422147196736 max_utils.py:800] System Information: Jax Version: 0.8.3
I0422 12:58:09.040946 133422147196736 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0422 12:58:09.040986 133422147196736 max_utils.py:802] System Information: Jax Backend: PJRT C API TFRT TPU v6 lite Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0422 12:58:09.044381 133422147196736 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:58:09.138288 133422147196736 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:58:09.238423 133422147196736 maxtext_utils.py:1718] Num_devices: 32, shape (1, 4, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0422 12:58:10.349864 133422147196736 config.py:112] TensorFlow version 2.20.0 available.
I0422 12:58:10.350386 133422147196736 config.py:125] JAX version 0.8.3 available.
/deps/src/maxtext/input_pipeline/input_pipeline_utils.py:467: UserWarning: WARNING: Inefficient dataloading.
Your train or eval dataset contains 3 shards, smaller than number of host loading data. This is known to lead to inefficient dataloading. See github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice
  warnings.warn(
E0422 12:58:15.889910 133422147196736 packing.py:209] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0422 12:58:15.890145 133422147196736 data_loader.py:408] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0422 12:58:16.286330 133422147196736 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0422 12:58:16.286839 133422147196736 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79580df1c200>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0422 12:58:16.286889 133422147196736 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0422 12:58:16.286928 133422147196736 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x79580df1c200>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0422 12:58:16.286975 133422147196736 checkpoint_manager.py:702] [process=6][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'model_params': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>, 'optimizer_state':
<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>, 'custom_metadata': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>}, handler_registry=None
I0422 12:58:16.287200 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "model_params". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>` for item "model_params" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287246 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "optimizer_state". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>` for item "optimizer_state" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287275 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "custom_metadata". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>` for item "custom_metadata" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287301 133422147196736 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics".
Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64028e00>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0422 12:58:16.287329 133422147196736 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>, ('model_params', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b8e7e0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>, ('optimizer_state', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x793ff9b6fbc0>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>, ('custom_metadata', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64223380>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64028e00>, ('metrics', <class 
'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x793a64028e00>}).
I0422 12:58:16.287533 133422147196736 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
I0422 12:58:16.287588 133422147196736 async_checkpointer.py:177] [process=6][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>._fn at 0x793a6418d300> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0422 12:58:18.633499 133422147196736 checkpoint_manager.py:1788] Found 0 checkpoint steps in gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints
I0422 12:58:19.052539 133422147196736 checkpoint_manager.py:921] [process=6][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_hns=False, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False),
root_directory=gs://lance-maxtext/pt_ckpt_xpk_feat_nnx_post_train_fixes_20260422_123915/pt_sft_nnx_xpk_feat_nnx_post_train_fixes_20260422_123915_01_sft_smoke/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x793a64223980>
I0422 12:58:19.052906 133422147196736 peft_trainer.py:584] Training with mesh: Mesh('diloco': 1, 'data': 4, 'stage': 1, 'fsdp': 8, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto))
I0422 12:58:19.480984 133422147196736 peft_trainer.py:594] Compiled train_step cache size: 0
I0422 12:58:19.485069 133422147196736 metric_logger.py:301] number parameters: 0.000 billion
I0422 12:58:19.487508 133268733159168 grain_pool.py:367] Grain pool will use 1 processes.
I0422 12:58:19.514011 133268733159168 grain_pool.py:440] Grain pool will start child processes.
Per train step: Total TFLOPs: 0.00 split as 54.29% learnable weight flops and 45.71% attention flops
I0422 12:58:19.519266 133268733159168 grain_pool.py:448] Grain pool started all child processes.
2026-04-22 12:58:23.535922: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-22 12:58:23.581019: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-22 12:58:24.750117: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on.
You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-04-22 12:58:28.617521: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 281, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 277, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 254, in train
    trainer = train_model(mt_config, trainer, mesh)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 240, in train_model
    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
  File "/usr/local/lib/python3.12/site-packages/tunix/sft/peft_trainer.py", line 692, in train
    train_loss, aux, grad_norm = train_step(train_example)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: not enough values to unpack (expected 3, got 2)
I0422 12:58:37.634279 133268733159168 grain_pool.py:542] Grain pool is exiting.
I0422 12:58:37.634390 133268733159168 grain_pool.py:547] Shutting down multiprocessing system.
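Editor's note on the crash in the traceback above: the `ValueError` is an ordinary Python tuple-unpacking arity mismatch. The caller at `peft_trainer.py:692` unpacks three values (`train_loss, aux, grad_norm`) from `train_step(train_example)`, but the step function returned only two, which suggests the trainer loop and the step function come from mismatched versions (one side gained or lost a returned `grad_norm`). A minimal, self-contained sketch of the failure mode — the two stand-in functions below are illustrative, not the actual tunix/MaxText API:

```python
# Illustrative stand-ins only; not the real tunix train_step implementations.

def train_step_two_outputs(example):
    # Returns (loss, aux) only -- no grad_norm, like the failing step here.
    return 0.25, {"tokens": 128}

def train_step_three_outputs(example):
    # Returns (loss, aux, grad_norm), matching what the caller unpacks.
    return 0.25, {"tokens": 128}, 1.7

# The call pattern from peft_trainer.py:692 applied to a two-output step
# reproduces the exact error seen in the log:
try:
    train_loss, aux, grad_norm = train_step_two_outputs({})
except ValueError as err:
    print(err)  # not enough values to unpack (expected 3, got 2)

# The same unpack succeeds once the step's return arity matches:
train_loss, aux, grad_norm = train_step_three_outputs({})
```

Verifying that the installed `tunix` release and the MaxText trainer code expect the same `train_step` signature would be the first thing to check before rerunning the job.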
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 27 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 20 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
XPK End: Wed Apr 22 12:58:44 UTC 2026
EXIT_CODE=1