# Post-Training Test Results (NNX)

## Environment

| Field | Value |
| --- | --- |
| Branch | feat/nnx-post-train-fixes |
| Commit | d8cde296b |
| Date | 2026-04-16 21:28 |
| Run ID | feat_nnx_post_train_fixes_20260416_210550 |
| NNX flags | pure_nnx=True enable_nnx=True pure_nnx_decoder=True |
| Hardware | V6e-8 TPU (8 devices, 31.25 GiB/device) |
| Python | Python 3.12.12 |
| JAX | 0.8.3 |
| Flax | 0.12.6 |
## Summary

5/7 passed.

| Test | Result | Log |
| --- | --- | --- |
| 01_sft_smoke | PASS | log |
| 02_sft_nnx_ckpt | PASS | log |
| 03_sft_linen_ckpt | PASS | log |
| 07_distill_smoke | FAIL | log |
| 04_rl_grpo_smoke | PASS | log |
| 05_rl_gspo_smoke | PASS | log |
| 06_rl_grpo_functional | FAIL | log |
## Output Paths

| Item | Path |
| --- | --- |
| Logs | ~/maxtext/venv_runs/feat_nnx_post_train_fixes_20260416_210550/nnx/logs/ |
| GCS checkpoints | gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550/ |
| NNX seed ckpt (PT-02 NNX) | gs://wanglance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items |
| Linen seed ckpt (PT-02/03) | gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items |
| Teacher ckpt (PT-07) | gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items |
| RL ckpt (PT-06) | gs://wanglance-maxtext/rl_ckpt_llama31_8b/0/items |
## Reproduction Commands
# Activate the post-training virtualenv and set the environment shared by all tests.
source ~/maxtext/maxtext_pt_venv/bin/activate
export PYTHONPATH=src
export DECOUPLE_GCLOUD=TRUE
export VLLM_TARGET_DEVICE=tpu # RL only
# Quote the placeholder: an unquoted <your-token> is parsed as shell redirections
# and is a syntax/runtime error. Replace with a real Hugging Face token.
export HF_TOKEN="<your-token>"
export RL_CKPT_PATH=gs://wanglance-maxtext/rl_ckpt_llama31_8b/0/items
### 01_sft_smoke
# 01_sft_smoke: SFT smoke test (gpt3-52k, pure-NNX path, 5 steps, no seed ckpt).
# Each continuation needs a space before the backslash, otherwise backslash-newline
# splicing glues adjacent arguments into one word (e.g. "sft.ymlmodel_name=...").
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
  model_name=gpt3-52k \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  max_target_length=1024 \
  steps=5 \
  eval_interval=-1 \
  gradient_accumulation_steps=1 \
  weight_dtype=float32 \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="$HF_TOKEN" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550 \
  run_name=pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_01_sft_smoke
### 02_sft_nnx_ckpt
# 02_sft_nnx_ckpt: same SFT smoke config as 01, but restores parameters from the
# NNX-format seed checkpoint (load_parameters_path).
# Space before each trailing backslash is required; without it, backslash-newline
# splicing concatenates adjacent arguments into a single word.
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
  model_name=gpt3-52k \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  max_target_length=1024 \
  steps=5 \
  eval_interval=-1 \
  gradient_accumulation_steps=1 \
  weight_dtype=float32 \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="$HF_TOKEN" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550 \
  run_name=pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_02_sft_nnx_ckpt \
  load_parameters_path=gs://wanglance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items
### 03_sft_linen_ckpt
# 03_sft_linen_ckpt: same SFT smoke config as 01, but restores parameters from the
# Linen-format seed checkpoint (cross-format load into the NNX model).
# Space before each trailing backslash is required; without it, backslash-newline
# splicing concatenates adjacent arguments into a single word.
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
  model_name=gpt3-52k \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  max_target_length=1024 \
  steps=5 \
  eval_interval=-1 \
  gradient_accumulation_steps=1 \
  weight_dtype=float32 \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="$HF_TOKEN" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550 \
  run_name=pt_sft_nnx_feat_nnx_post_train_fixes_20260416_210550_03_sft_linen_ckpt \
  load_parameters_path=gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items
### 07_distill_smoke
# 07_distill_smoke: distillation smoke test (FAILED in this run). gpt3-52k student
# and teacher with 32k vocab; teacher weights restored from the PT-07 seed ckpt.
# Continuations reformatted one-argument-per-line with a space before each backslash
# (required: without it, backslash-newline splicing glues adjacent words together).
python3 -m maxtext.trainers.post_train.distillation.train_distill src/maxtext/configs/post_train/distillation.yml \
  student_overrides.model_name=gpt3-52k \
  student_overrides.vocab_size=32000 \
  teacher_overrides.model_name=gpt3-52k \
  teacher_overrides.vocab_size=32000 \
  teacher_overrides.load_parameters_path=gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items \
  teacher_overrides.skip_jax_distributed_system=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="$HF_TOKEN" \
  steps=5 \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  weight_dtype=float32 \
  gradient_accumulation_steps=1 \
  skip_jax_distributed_system=True \
  profiler=xplane \
  log_config=False \
  enable_goodput_recording=False \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550 \
  run_name=pt_distill_nnx_feat_nnx_post_train_fixes_20260416_210550_07_distill_smoke
### 04_rl_grpo_smoke
# 04_rl_grpo_smoke: RL smoke test with the GRPO loss (qwen3-0.6b, 2 batches).
# Space before each trailing backslash is required; without it, backslash-newline
# splicing concatenates adjacent arguments into a single word.
VLLM_TARGET_DEVICE=tpu python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=qwen3-0.6b \
  tokenizer_path=Qwen/Qwen3-0.6B \
  use_pathways=False \
  chips_per_vm=8 \
  num_batches=2 \
  rollout_data_parallelism=1 \
  async_scheduling=False \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  hf_access_token="$HF_TOKEN" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550 \
  run_name=pt_rl_nnx_feat_nnx_post_train_fixes_20260416_210550_04_rl_grpo_smoke \
  rl.loss_algo=grpo
### 05_rl_gspo_smoke
# 05_rl_gspo_smoke: RL smoke test with the GSPO-token loss; otherwise identical to 04.
# Space before each trailing backslash is required; without it, backslash-newline
# splicing concatenates adjacent arguments into a single word.
VLLM_TARGET_DEVICE=tpu python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=qwen3-0.6b \
  tokenizer_path=Qwen/Qwen3-0.6B \
  use_pathways=False \
  chips_per_vm=8 \
  num_batches=2 \
  rollout_data_parallelism=1 \
  async_scheduling=False \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  hf_access_token="$HF_TOKEN" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550 \
  run_name=pt_rl_nnx_feat_nnx_post_train_fixes_20260416_210550_05_rl_gspo_smoke \
  rl.loss_algo=gspo-token
### 06_rl_grpo_functional
# 06_rl_grpo_functional: functional GRPO run (FAILED in this run) restoring
# llama3.1-8b-Instruct from the RL seed checkpoint.
# NOTE(review): model_name, tokenizer_path, and num_batches are each specified
# twice (qwen3 values first, llama3.1 values later). Presumably the later value
# wins in maxtext's config override order, making the effective model
# llama3.1-8b-Instruct — confirm, and drop the earlier duplicates if so.
# Space before each trailing backslash is required; without it, backslash-newline
# splicing concatenates adjacent arguments into a single word.
VLLM_TARGET_DEVICE=tpu python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=qwen3-0.6b \
  tokenizer_path=Qwen/Qwen3-0.6B \
  use_pathways=False \
  chips_per_vm=8 \
  num_batches=2 \
  rollout_data_parallelism=1 \
  async_scheduling=False \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  hf_access_token="$HF_TOKEN" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260416_210550 \
  run_name=pt_rl_nnx_feat_nnx_post_train_fixes_20260416_210550_06_rl_grpo_functional \
  model_name=llama3.1-8b-Instruct \
  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
  load_parameters_path=gs://wanglance-maxtext/rl_ckpt_llama31_8b/0/items \
  num_batches=2 \
  rl.loss_algo=grpo