test/pipeline-scan-nnx
XPK Start: Sun Apr 19 09:06:46 UTC 2026
2026-04-19 09:07:28.447429: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0419 09:07:32.145170 136522548135744 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-19 09:07:41,184:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0419 09:07:41.184015 136522548135744 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-19 09:07:41,186:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-6mtf5-slice-job-0-0.mt-07-distill-smoke-6mtf5:8482
I0419 09:07:41.186305 136522548135744 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-6mtf5-slice-job-0-0.mt-07-distill-smoke-6mtf5:8482
I0419 09:07:49.637182 136522548135744 max_utils.py:284] Jax distributed system initialized!
I0419 09:07:55.905689 136522548135744 max_utils.py:244] Jax distributed system is already initialized.
I0419 09:07:56.373677 136522548135744 max_utils.py:244] Jax distributed system is already initialized.
I0419 09:07:56.375303 136522548135744 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0419 09:07:56.375357 136522548135744 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0419 09:08:00.274246 136522548135744 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0419 09:08:00.277145 136522548135744 maxtext_utils.py:1398] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0419 09:08:00.277267 136522548135744 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0419 09:08:00.277339 136522548135744 train_distill.py:586] Loading Student from ...
I0419 09:08:00.277369 136522548135744 train_distill.py:169] --- Student Configuration ---
I0419 09:08:00.277392 136522548135744 train_distill.py:170] Model Name: gpt3-52k
I0419 09:08:00.277414 136522548135744 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0419 09:08:00.277434 136522548135744 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0419 09:08:00.277452 136522548135744 train_distill.py:175] Vocab Size: 32000
I0419 09:08:00.277471 136522548135744 train_distill.py:176] Checkpoint:
I0419 09:08:00.277490 136522548135744 train_distill.py:460] Initializing model: gpt3-52k...
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 588, in train_distill
    student_model = get_maxtext_model(student_config, mesh)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 461, in get_maxtext_model
    model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 301, in create_nnx_model
    sharded_state = create_sharded_state()
                    ^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('fsdp', 'sequence', 'context', 'expert'), None, ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 16 (full shape: (16, 1, 64))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sun Apr 19 09:08:06 UTC 2026
EXIT_CODE=1
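Aside on the `_schedule.py:129` warning above: it comes from an optax-style polynomial learning-rate schedule, which degenerates to a constant schedule at `init_value` whenever `transition_steps` is not positive. A minimal pure-Python sketch of that behavior (this is illustrative, not the optax source; `polynomial_schedule_sketch` is a hypothetical name):

```python
def polynomial_schedule_sketch(init_value, end_value, power, transition_steps):
    """Sketch of a polynomial decay schedule, assuming optax-like semantics."""
    if transition_steps <= 0:
        # Matches the warning in the log: with non-positive `transition_steps`
        # the schedule is constant at `init_value` for every step.
        return lambda step: init_value

    def schedule(step):
        # Clamp step into [0, transition_steps], then decay polynomially
        # from init_value down to end_value.
        frac = 1.0 - min(max(step, 0), transition_steps) / transition_steps
        return (init_value - end_value) * (frac ** power) + end_value

    return schedule

constant = polynomial_schedule_sketch(1e-3, 0.0, 1.0, 0)   # triggers the fallback
print(constant(0), constant(10_000))                        # both 0.001
```

In the smoke test this is likely benign (a constant learning rate), but it usually means the schedule's step count was left unset in the config.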
XPK Start: Sun Apr 19 09:16:14 UTC 2026
2026-04-19 09:16:31.075858: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0419 09:16:34.632635 133016035387200 max_utils.py:273] Attempting to initialize the jax distributed system...
INFO:2026-04-19 09:16:43,669:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0419 09:16:43.669782 133016035387200 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-04-19 09:16:43,672:jax._src.distributed:166: Connecting to JAX distributed service on mt-07-distill-smoke-wwin2-slice-job-0-0.mt-07-distill-smoke-wwin2:8482
I0419 09:16:43.672029 133016035387200 distributed.py:166] Connecting to JAX distributed service on mt-07-distill-smoke-wwin2-slice-job-0-0.mt-07-distill-smoke-wwin2:8482
I0419 09:16:45.011039 133016035387200 max_utils.py:284] Jax distributed system initialized!
I0419 09:16:51.497201 133016035387200 max_utils.py:244] Jax distributed system is already initialized.
I0419 09:16:51.968220 133016035387200 max_utils.py:244] Jax distributed system is already initialized.
I0419 09:16:51.969876 133016035387200 tokenizer.py:245] Tokenizer path: meta-llama/Llama-2-7b-chat-hf
I0419 09:16:51.969932 133016035387200 tokenizer.py:224] Loading HF tokenizer: meta-llama/Llama-2-7b-chat-hf
I0419 09:16:55.908151 133016035387200 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
I0419 09:16:55.911183 133016035387200 maxtext_utils.py:1398] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1)
I0419 09:16:55.911320 133016035387200 train_distill.py:582] Applying logical axis rules for model initialization and training...
I0419 09:16:55.911395 133016035387200 train_distill.py:586] Loading Student from ...
I0419 09:16:55.911425 133016035387200 train_distill.py:169] --- Student Configuration ---
I0419 09:16:55.911449 133016035387200 train_distill.py:170] Model Name: gpt3-52k
I0419 09:16:55.911472 133016035387200 train_distill.py:171] Dimensions: 1 Layers, 16 Emb Dim, 8 Head Dim
I0419 09:16:55.911491 133016035387200 train_distill.py:174] Attention Heads: 2 Query, 2 KV
I0419 09:16:55.911510 133016035387200 train_distill.py:175] Vocab Size: 32000
I0419 09:16:55.911529 133016035387200 train_distill.py:176] Checkpoint:
I0419 09:16:55.911548 133016035387200 train_distill.py:460] Initializing model: gpt3-52k...
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 747, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 743, in main
    train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 588, in train_distill
    student_model = get_maxtext_model(student_config, mesh)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/distillation/train_distill.py", line 461, in get_maxtext_model
    model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/utils/model_creation_utils.py", line 301, in create_nnx_model
    sharded_state = create_sharded_state()
                    ^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('fsdp', 'sequence', 'context', 'expert'), None, ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 16 (full shape: (16, 1, 64))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
XPK End: Sun Apr 19 09:17:05 UTC 2026
EXIT_CODE=1
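Both runs fail identically at model initialization with the same sharding error, so this is a config problem, not a flake. JAX requires each dimension of a globally sharded array to be divisible by the product of the mesh-axis sizes named in its `PartitionSpec` entry. Here a (16, 1, 64) parameter (dim 0 matching the student's 16 Emb Dim) has dim 0 sharded over `('fsdp', 'sequence', 'context', 'expert')`, whose sizes multiply to 32, and 16 is not divisible by 32. A minimal pure-Python sketch of that divisibility check (illustrative only, not JAX code; `mesh_axis_sizes` and `shards_for` are hypothetical names, with the sizes and spec copied from the error message):

```python
from math import prod

# Mesh axis sizes from the NamedSharding in the traceback.
mesh_axis_sizes = {
    "diloco": 1, "data": 1, "stage": 1, "fsdp": 32, "fsdp_transpose": 1,
    "sequence": 1, "context": 1, "context_autoregressive": 1, "tensor": 1,
    "tensor_transpose": 1, "tensor_sequence": 1, "expert": 1, "autoregressive": 1,
}

def shards_for(spec_entry):
    """Number of shards one PartitionSpec entry implies for one dimension."""
    if spec_entry is None:
        return 1  # dimension is replicated, no constraint
    return prod(mesh_axis_sizes[axis] for axis in spec_entry)

full_shape = (16, 1, 64)  # the parameter shape from the error message
spec = (("fsdp", "sequence", "context", "expert"),
        None,
        ("fsdp_transpose", "tensor", "tensor_sequence", "autoregressive"))

for dim_size, entry in zip(full_shape, spec):
    n = shards_for(entry)
    print(f"dim={dim_size} shards={n} ok={dim_size % n == 0}")
# dim 0: 16 % 32 != 0 -> the ValueError
```

The likely fixes are either growing the student's embedding dimension to a multiple of 32 or shrinking the `fsdp` mesh axis (the smoke-test `gpt3-52k` model is far too small to need 32-way FSDP).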