MaxView

Log Summary

XPK Start: Sun Apr 19 21:49:28 UTC 2026
2026-04-19 21:50:26.036433: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0419 21:50:26.250784 132119157753664 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-19 21:50:35,292:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-0rb2o-slice-job-0-0.mt-01-sft-smoke-0rb2o:8482
I0419 21:50:35.292925 132119157753664 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-0rb2o-slice-job-0-0.mt-01-sft-smoke-0rb2o:8482
I0419 21:50:50.470075 132119157753664 max_utils.py:284] Jax distributed system initialized!
I0419 21:50:56.556733 132119157753664 max_utils.py:800] System Information: Jax Version: 0.8.3
I0419 21:50:56.556847 132119157753664 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0419 21:50:56.556886 132119157753664 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 216, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 212, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 186, in train
    trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 149, in setup_trainer_state
    model, mesh = model_creation_utils.create_nnx_model(mt_config)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 337, in create_nnx_model
    model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 291, in create_nnx_sharded_model_hybrid
    abstract_model = nnx.eval_shape(_create_model_partial)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/flax/nnx/transforms/transforms.py", line 272, in eval_shape
    out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/flax/nnx/transforms/transforms.py", line 269, in _eval_shape_fn
    out = f_call(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 238, in _create_model
    return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 210, in from_config
    mesh = maxtext_utils.get_mesh_from_config(config, devices)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/maxtext_utils.py", line 1870, in get_mesh_from_config
    devices_array = create_device_mesh(config, devices)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/maxtext_utils.py", line 1677, in create_device_mesh
    ici_parallelism = max_utils.fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/max_utils.py", line 450, in fill_unspecified_mesh_axes
    assert np.prod(parallelism_vals) == target_product, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Number of devices per slice 32 does not match the product of the ICI parallelism 8
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sun Apr 19 21:51:03 UTC 2026
EXIT_CODE=1
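The run fails while building the device mesh. The slice reports 32 devices (the "TFRT TPU v6 lite" backend), but the configured ICI parallelism axes multiply to only 8. One plausible explanation, not confirmed by the log, is a config sized for an 8-chip slice being launched on a 32-chip one. The sketch below is a hypothetical re-implementation of the check that produces this AssertionError. `fill_unspecified_axes` is an illustrative name and is not MaxText's `fill_unspecified_mesh_axes`. It assumes, as that function's name suggests, that an axis set to -1 is inferred from the device count, and that at most one axis may be unspecified.

```python
import math


def fill_unspecified_axes(parallelism, target_product, kind="ICI"):
    """Hypothetical sketch of the mesh-axis check behind the AssertionError.

    `parallelism` is a list of per-axis sizes. At most one entry may be -1,
    meaning "infer this axis so the product equals target_product".
    """
    vals = list(parallelism)  # work on a copy so the caller's config is untouched
    if -1 in vals:
        if vals.count(-1) > 1:
            raise ValueError(f"At most one {kind} axis may be -1")
        known = math.prod(v for v in vals if v != -1)
        if target_product % known != 0:
            raise ValueError(f"{target_product} is not divisible by {known}")
        # The unspecified axis absorbs whatever factor is left over.
        vals[vals.index(-1)] = target_product // known
    product = math.prod(vals)
    if product != target_product:
        raise AssertionError(
            f"Number of devices per slice {target_product} does not match "
            f"the product of the {kind} parallelism {product}"
        )
    return vals


# Fully specified axes whose product is 8 on a 32-device slice reproduce the failure.
try:
    fill_unspecified_axes([1, 8, 1], 32)
except AssertionError as err:
    print(err)

# Marking one axis as -1 lets it absorb the remaining factor of 4.
print(fill_unspecified_axes([1, 8, -1], 32))
```

Under these assumptions, the fix is to make the ICI axes multiply to 32. You can do this by scaling one axis explicitly or by leaving exactly one axis unspecified (-1) so it is filled automatically. You can also run the job on an 8-device slice that matches the existing settings.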