MaxView

Case: 01_sft_smoke

Metrics: Linen vs NNX  ·  feat/nnx-post-train-fixes

Metric    Linen (574ad3fb9)    NNX (574ad3fb9)    Diff (NNX − Linen)
JAX       0.8.3                0.8.3

Diff = NNX value − Linen value. Green = NNX improved. Red = NNX regressed.

Linen  ·  574ad3fb9  ·  feat_nnx_post_train_fixes_20260418_235042
XPK Start: Sat Apr 18 23:51:17 UTC 2026
2026-04-18 23:51:46.179851: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0418 23:51:46.391475 137113772353344 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-18 23:51:55,432:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0418 23:51:55.432659 137113772353344 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-18 23:51:55,435:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-74lxb-slice-job-0-0.mt-01-sft-smoke-74lxb:8482
I0418 23:51:55.435101 137113772353344 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-74lxb-slice-job-0-0.mt-01-sft-smoke-74lxb:8482
I0418 23:51:58.485470 137113772353344 max_utils.py:284] Jax distributed system initialized!
I0418 23:52:04.899552 137113772353344 max_utils.py:800] System Information: Jax Version: 0.8.3
I0418 23:52:04.899655 137113772353344 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0418 23:52:04.899696 137113772353344 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0418 23:52:04.903053 137113772353344 maxtext_utils.py:1718] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0418 23:52:05.095957 137113772353344 maxtext_utils.py:1718] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 280, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 276, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 250, in train
    trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 213, in setup_trainer_state
    model, mesh = model_creation_utils.create_nnx_model(mt_config)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 337, in create_nnx_model
    model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 314, in create_nnx_sharded_model_hybrid
    sharded_state = create_sharded_state()
                    ^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert'), 'stage', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 16 (full shape: (16, 1, 64))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sat Apr 18 23:52:12 UTC 2026
EXIT_CODE=1
NNX  ·  574ad3fb9  ·  feat_nnx_post_train_fixes_20260418_235042
XPK Start: Sat Apr 18 23:58:28 UTC 2026
2026-04-18 23:58:57.071844: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0418 23:58:57.283660 132437727110976 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-18 23:59:06,326:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0418 23:59:06.326119 132437727110976 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-18 23:59:06,328:jax._src.distributed:166: Connecting to JAX distributed service on mt-01-sft-smoke-cmror-slice-job-0-0.mt-01-sft-smoke-cmror:8482
I0418 23:59:06.328450 132437727110976 distributed.py:166] Connecting to JAX distributed service on mt-01-sft-smoke-cmror-slice-job-0-0.mt-01-sft-smoke-cmror:8482
I0418 23:59:08.408328 132437727110976 max_utils.py:284] Jax distributed system initialized!
I0418 23:59:14.470565 132437727110976 max_utils.py:800] System Information: Jax Version: 0.8.3
I0418 23:59:14.470668 132437727110976 max_utils.py:801] System Information: Jaxlib Version: 0.8.3
I0418 23:59:14.470708 132437727110976 max_utils.py:802] System Information: Jax Backend: PJRT C API
TFRT TPU v6 lite
Built on Dec 15 2025 14:03:46 (1765836226) cl/844590465
I0418 23:59:14.474050 132437727110976 maxtext_utils.py:1718] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0418 23:59:14.567893 132437727110976 maxtext_utils.py:1718] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0418 23:59:14.667265 132437727110976 maxtext_utils.py:1718] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 280, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 276, in main
    train(mt_config, goodput_recorder)
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 250, in train
    trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/sft/train_sft.py", line 213, in setup_trainer_state
    model, mesh = model_creation_utils.create_nnx_model(mt_config)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 334, in create_nnx_model
    model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/maxtext_utils_nnx.py", line 174, in create_nnx_sharded_model
    sharded_state = create_sharded_state()
                    ^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('fsdp', 'sequence', 'context', 'expert'), 'stage', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 16 (full shape: (16, 1, 64))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sat Apr 18 23:59:24 UTC 2026
EXIT_CODE=1
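Both runs abort at the same step. The jitted `create_sharded_state()` call requests an output sharding that splits dimension 0 of a (16, 1, 64) array across every mesh axis listed in the first PartitionSpec entry. The product of those axis sizes is 32 (fsdp=32, every other axis 1), and 16 is not divisible by 32. The only difference between the two specs is that the Linen run also lists 'tensor_transpose', which has size 1 and so does not change the shard count. The sketch below is a pure-Python illustration of that divisibility rule, assuming the mesh sizes and specs copied from the two error messages. It does not call JAX, and the helper names (`shards_per_dim`, `divisibility_errors`) are hypothetical.

```python
from math import prod

# Mesh axis sizes copied from the error messages (fsdp=32, all others 1).
MESH_AXES = {
    "diloco": 1, "data": 1, "stage": 1, "fsdp": 32, "fsdp_transpose": 1,
    "sequence": 1, "context": 1, "context_autoregressive": 1, "tensor": 1,
    "tensor_transpose": 1, "tensor_sequence": 1, "expert": 1,
    "autoregressive": 1,
}


def shards_per_dim(spec, mesh_axes):
    """Hypothetical helper: how many ways each dimension is split by a
    PartitionSpec-like tuple (None = replicated, str = one axis,
    tuple = several axes whose sizes multiply)."""
    counts = []
    for entry in spec:
        if entry is None:
            counts.append(1)
        elif isinstance(entry, str):
            counts.append(mesh_axes[entry])
        else:
            counts.append(prod(mesh_axes[axis] for axis in entry))
    return counts


def divisibility_errors(shape, spec, mesh_axes):
    """Hypothetical helper: list (dim, size, shards) for every dimension whose
    size is not divisible by its shard count. Dimensions beyond the end of
    the spec are treated as replicated, so zip() simply skips them."""
    errors = []
    for dim, (size, shards) in enumerate(zip(shape, shards_per_dim(spec, mesh_axes))):
        if size % shards != 0:
            errors.append((dim, size, shards))
    return errors


# Specs and shape copied from the Linen and NNX ValueError messages above.
LINEN_SPEC = (("fsdp", "sequence", "tensor_transpose", "context", "expert"), "stage",
              ("fsdp_transpose", "tensor", "tensor_sequence", "autoregressive"))
NNX_SPEC = (("fsdp", "sequence", "context", "expert"), "stage",
            ("fsdp_transpose", "tensor", "tensor_sequence", "autoregressive"))
SHAPE = (16, 1, 64)

for name, spec in [("Linen", LINEN_SPEC), ("NNX", NNX_SPEC)]:
    print(name, shards_per_dim(spec, MESH_AXES), divisibility_errors(SHAPE, spec, MESH_AXES))
```

Running it prints `Linen [32, 1, 1] [(0, 16, 32)]` and `NNX [32, 1, 1] [(0, 16, 32)]`, which matches the logs: dimension 0 is split 32 ways in both specs. Because the extra 'tensor_transpose' axis has size 1, removing it from the spec cannot change the outcome.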