main
XPK Start: Sat Apr 18 18:04:38 UTC 2026
2026-04-18 18:05:07.777123: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0418 18:05:07.991232 139117642696512 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-18 18:05:17,031:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0418 18:05:17.031873 139117642696512 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-18 18:05:17,034:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-dyx7y-slice-job-0-0.mt-01-sft-smoke-dyx7y:8482
I0418 18:05:17.034419 139117642696512 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-dyx7y-slice-job-0-0.mt-01-sft-smoke-dyx7y:8482
I0418 18:05:18.619306 139117642696512 max_utils.py:284] Jax distributed system initialized!
I0418 18:05:24.845857 139117642696512 max_utils.py:800] System Information: Jax Version: 0.8.3
I0418 18:05:24.845964 139117642696512 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0418 18:05:24.846005 139117642696512 max_utils.py:802] System Information: Jax Backend: PJRT C API TFRT TPU v6 lite Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0418 18:05:24.849439 139117642696512 maxtext_utils.py:1551] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0418 18:05:25.031054 139117642696512 maxtext_utils.py:1551] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 216, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 212, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 186, in train
    trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 149, in setup_trainer_state
    model, mesh = model_creation_utils.create_nnx_model(mt_config)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 294, in create_nnx_model
    sharded_state = create_sharded_state()
                    ^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert'), 'stage', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 16 (full shape: (16, 1, 64))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sat Apr 18 18:05:33 UTC 2026
EXIT_CODE=1
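The ValueError in this run comes from a divisibility rule: each array dimension is split into as many shards as the product of the sizes of the mesh axes named in its PartitionSpec entry. Dimension 0 of the failing parameter is mapped to ('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert'), whose sizes multiply to 32 × 1 × 1 × 1 × 1 = 32, but that dimension only has 16 entries. The minimal sketch below reproduces that arithmetic in plain Python, using the mesh, spec, and shape copied from the log. The helper names `shard_factor` and `check_divisibility` are hypothetical; this mirrors the message, it is not JAX's own check.

```python
from math import prod
from typing import Optional, Sequence, Union

# One entry per array dimension: None (replicated), a single mesh axis name,
# or a tuple of mesh axis names whose sizes multiply together.
SpecEntry = Optional[Union[str, tuple[str, ...]]]


def shard_factor(mesh_shape: dict[str, int], entry: SpecEntry) -> int:
    """Number of shards one dimension is split into under a single spec entry."""
    if entry is None:
        return 1
    axes = (entry,) if isinstance(entry, str) else entry
    return prod(mesh_shape[axis] for axis in axes)


def check_divisibility(mesh_shape: dict[str, int],
                       spec: Sequence[SpecEntry],
                       shape: Sequence[int]) -> list[str]:
    """Return one message per dimension whose size is not divisible by its shard count."""
    problems = []
    for dim, (size, entry) in enumerate(zip(shape, spec)):
        factor = shard_factor(mesh_shape, entry)
        if size % factor != 0:
            problems.append(f"dimension {dim}: size {size} is not divisible by {factor}")
    return problems


# Mesh and PartitionSpec copied from the error message above.
MESH = {
    "diloco": 1, "data": 1, "stage": 1, "fsdp": 32, "fsdp_transpose": 1,
    "sequence": 1, "context": 1, "context_autoregressive": 1, "tensor": 1,
    "tensor_transpose": 1, "tensor_sequence": 1, "expert": 1, "autoregressive": 1,
}
SPEC = (
    ("fsdp", "sequence", "tensor_transpose", "context", "expert"),
    "stage",
    ("fsdp_transpose", "tensor", "tensor_sequence", "autoregressive"),
)

print(check_divisibility(MESH, SPEC, (16, 1, 64)))
# → ['dimension 0: size 16 is not divisible by 32']
```

Only dimension 0 is flagged: dimension 1 is split by 'stage' (size 1) and dimension 2 by a group of size-1 axes, so both divide trivially.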
XPK Start: Sat Apr 18 18:26:59 UTC 2026
2026-04-18 18:27:27.604351: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0418 18:27:27.818871 140145825343296 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-18 18:27:36,858:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0418 18:27:36.858568 140145825343296 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-18 18:27:36,860:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-4tl6h-slice-job-0-0.mt-01-sft-smoke-4tl6h:8482
I0418 18:27:36.860877 140145825343296 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-4tl6h-slice-job-0-0.mt-01-sft-smoke-4tl6h:8482
I0418 18:27:38.492271 140145825343296 max_utils.py:284] Jax distributed system initialized!
I0418 18:27:43.534725 140145825343296 max_utils.py:800] System Information: Jax Version: 0.8.3
I0418 18:27:43.534830 140145825343296 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0418 18:27:43.534873 140145825343296 max_utils.py:802] System Information: Jax Backend: PJRT C API TFRT TPU v6 lite Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0418 18:27:43.538368 140145825343296 maxtext_utils.py:1551] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0418 18:27:43.638808 140145825343296 maxtext_utils.py:1551] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 216, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 212, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 186, in train
    trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 149, in setup_trainer_state
    model, mesh = model_creation_utils.create_nnx_model(mt_config)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 294, in create_nnx_model
    sharded_state = create_sharded_state()
                    ^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert'), 'stage', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 16 (full shape: (16, 1, 64))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sat Apr 18 18:27:49 UTC 2026
EXIT_CODE=1
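The second run (job mt-01-sft-smoke-4tl6h) fails with the identical message on the same 32-device mesh, with all 32 devices on 'fsdp'. One way to reason about a fix is to look for ways of splitting the 32 devices between 'fsdp', which shards dimension 0 of this parameter, and an axis that does not appear in its spec, such as 'data'. The sketch below enumerates (data, fsdp) splits for which dimension 0 (size 16) stays divisible. It assumes the other axes in that spec entry stay at size 1 and looks at this one parameter only; other parameters and the global batch size impose their own divisibility constraints. The helper `mesh_splits` is hypothetical, and mapping a chosen split onto MaxText's mesh settings (assumed here to be keys such as `ici_data_parallelism` and `ici_fsdp_parallelism`) should be checked against the MaxText config in use.

```python
def mesh_splits(num_devices: int, dim_size: int) -> list[tuple[int, int]]:
    """Candidate (data, fsdp) splits of num_devices where dim_size % fsdp == 0.

    Assumes only the 'fsdp' factor shards the dimension (every other axis in
    that spec entry has size 1) and that 'data' does not shard it at all.
    """
    splits = []
    for fsdp in range(1, num_devices + 1):
        # fsdp must tile the device count exactly and divide the dimension evenly.
        if num_devices % fsdp == 0 and dim_size % fsdp == 0:
            splits.append((num_devices // fsdp, fsdp))
    return splits


print(mesh_splits(32, 16))
# → [(32, 1), (16, 2), (8, 4), (4, 8), (2, 16)]
```

Under these assumptions the largest FSDP degree that satisfies this parameter is 16, with the remaining factor of 2 placed on 'data'; the (1, 32) split used in both runs is absent from the list, which matches the error.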