XPK Start: Sun Apr 19 18:32:43 UTC 2026
2026-04-19 18:33:12.541717: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0419 18:33:12.761007 137447105304384 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-19 18:33:21,802:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0419 18:33:21.802567 137447105304384 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-19 18:33:21,804:jax._src.distributed:166: Connecting to JAX distributed service on mt-02-sft-nnx-ckpt-2jpn5-slice-job-0-0.mt-02-sft-nnx-ckpt-2jpn5:8482
I0419 18:33:21.804981 137447105304384 distributed.py:166] Connecting to JAX distributed service on mt-02-sft-nnx-ckpt-2jpn5-slice-job-0-0.mt-02-sft-nnx-ckpt-2jpn5:8482
I0419 18:33:23.557123 137447105304384 max_utils.py:284] Jax distributed system initialized!
I0419 18:33:29.496924 137447105304384 max_utils.py:800] System Information: Jax Version: 0.8.3
I0419 18:33:29.497028 137447105304384 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0419 18:33:29.497070 137447105304384 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 216, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 212, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 186, in train
    trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 149, in setup_trainer_state
    model, mesh = model_creation_utils.create_nnx_model(mt_config)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 260, in create_nnx_model
    abstract_model = nnx.eval_shape(_create_model_partial)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/flax/nnx/transforms/transforms.py", line 272, in eval_shape
    out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/flax/nnx/transforms/transforms.py", line 269, in _eval_shape_fn
    out = f_call(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 255, in _create_model
    return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 196, in from_config
    mesh = maxtext_utils.get_mesh_from_config(config, devices)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/maxtext_utils.py", line 1703, in get_mesh_from_config
    devices_array = create_device_mesh(config, devices)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/maxtext_utils.py", line 1510, in create_device_mesh
    ici_parallelism = max_utils.fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/max_utils.py", line 450, in fill_unspecified_mesh_axes
    assert np.prod(parallelism_vals) == target_product, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Number of devices per slice 32 does not match the product of the ICI parallelism 8
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sun Apr 19 18:33:39 UTC 2026
EXIT_CODE=1