MaxView

Log Summary

XPK Start: Sun Apr 19 03:38:54 UTC 2026
2026-04-19 03:39:10.821087: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0419 03:39:14.397938 133786137376576 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-19 03:39:23,437:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0419 03:39:23.437945 133786137376576 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-19 03:39:23,440:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-jwpgn-slice-job-0-0.mt-07-distill-smoke-jwpgn:8482
I0419 03:39:23.440464 133786137376576 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-jwpgn-slice-job-0-0.mt-07-distill-smoke-jwpgn:8482
I0419 03:39:24.779937 133786137376576 max_utils.py:284] Jax distributed system initialized!
I0419 03:39:30.171854 133786137376576 max_utils.py:244] Jax distributed system is already initialized.
I0419 03:39:30.651923 133786137376576 max_utils.py:244] Jax distributed system is already initialized.
I0419 03:39:30.653125 133786137376576 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0419 03:39:30.653180 133786137376576 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0419 03:39:34.574856 133786137376576 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0419 03:39:34.577899 133786137376576 maxtext_utils.py:1631] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0419 03:39:34.578013 133786137376576 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0419 03:39:34.578086 133786137376576 train_distill.py:586] Loading Student from ...
I0419 03:39:34.578124 133786137376576 train_distill.py:169] --- Student Configuration ---
I0419 03:39:34.578146 133786137376576 train_distill.py:170]   Model Name:      gpt3-52k
I0419 03:39:34.578166 133786137376576 train_distill.py:171]   Dimensions:      1 Layers, 16 Emb Dim, 8 Head Dim
I0419 03:39:34.578186 133786137376576 train_distill.py:174]   Attention Heads: 2 Query, 2 KV
I0419 03:39:34.578203 133786137376576 train_distill.py:175]   Vocab Size:      32000
I0419 03:39:34.578221 133786137376576 train_distill.py:176]   Checkpoint:      
I0419 03:39:34.578371 133786137376576 train_distill.py:460] Initializing model: gpt3-52k...
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 588, in train_distill
    student_model = get_maxtext_model(student_config, mesh)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 461, in get_maxtext_model
    model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 335, in create_nnx_model
    model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/maxtext_utils_nnx.py", line 171, in create_nnx_sharded_model
    sharded_state = create_sharded_state()
                    ^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('fsdp', 'sequence', 'context', 'expert'), 'stage', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 16 (full shape: (16, 1, 64))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sun Apr 19 03:39:44 UTC 2026
EXIT_CODE=1
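The run fails during student model initialization on a divisibility check. Line 20 shows all 32 devices placed on a single mesh axis, and the ValueError names that axis as 'fsdp' (size 32). The PartitionSpec shards dimension 0 of a parameter over ('fsdp', 'sequence', 'context', 'expert'), whose sizes multiply to 32. That parameter's full shape is (16, 1, 64), and 16 is not divisible by 32. The leading 16 appears to correspond to the student's "16 Emb Dim" from the configuration block, though the log does not name the parameter.

The sketch below redoes that arithmetic in plain Python, using the mesh axis sizes and PartitionSpec copied from the error message. The helper names `shard_count` and `indivisible_dims` are hypothetical and for illustration only; they are not JAX or MaxText APIs. The alternative mesh with fsdp=16 is likewise an assumed reconfiguration, not a setting taken from this run.

```python
from math import prod

# Mesh axis sizes copied from the NamedSharding in the ValueError above
# (every axis is 1 except 'fsdp', which is 32).
MESH_SHAPE = {
    "diloco": 1, "data": 1, "stage": 1, "fsdp": 32, "fsdp_transpose": 1,
    "sequence": 1, "context": 1, "context_autoregressive": 1, "tensor": 1,
    "tensor_transpose": 1, "tensor_sequence": 1, "expert": 1, "autoregressive": 1,
}

# PartitionSpec from the error: each entry lists the mesh axes that shard the
# matching array dimension. None (not used here) would mean "not sharded".
SPEC = (
    ("fsdp", "sequence", "context", "expert"),
    "stage",
    ("fsdp_transpose", "tensor", "tensor_sequence", "autoregressive"),
)


def shard_count(entry, mesh_shape):
    """Number of pieces one PartitionSpec entry splits a dimension into."""
    if entry is None:
        return 1
    axes = (entry,) if isinstance(entry, str) else entry
    return prod(mesh_shape[axis] for axis in axes)


def indivisible_dims(shape, spec, mesh_shape):
    """Return (dim, size, shards) for each dimension that cannot be split evenly.

    zip() stops at the shorter sequence, so dimensions beyond the end of the
    spec are treated as unsharded, matching PartitionSpec semantics.
    """
    bad = []
    for dim, (size, entry) in enumerate(zip(shape, spec)):
        shards = shard_count(entry, mesh_shape)
        if size % shards != 0:
            bad.append((dim, size, shards))
    return bad


# The failing parameter from the log: dimension 0 (size 16) over 32 shards.
print(indivisible_dims((16, 1, 64), SPEC, MESH_SHAPE))  # [(0, 16, 32)]

# Hypothetical mesh with fsdp=16 and data=2 (same 32 devices): 16 % 16 == 0.
alt_mesh = dict(MESH_SHAPE, fsdp=16, data=2)
print(indivisible_dims((16, 1, 64), SPEC, alt_mesh))  # []
```

Under these assumptions, the check passes once the product of the mesh axes sharding dimension 0 divides 16. That can happen either by giving 'fsdp' fewer devices or by using a student whose sharded dimensions are multiples of 32.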