MaxView

Post-Training Test Results (NNX)

Environment

Log
Branch feat/nnx-post-train-fixes
Commit e27fc1e97
Date 2026-04-20 21:24
Run ID feat_nnx_post_train_fixes_20260420_205452
NNX flags pure_nnx=True enable_nnx=True pure_nnx_decoder=True
Hardware TPU v6e-8 (8 devices, 31.25 GiB/device)
Python 3.12.12
JAX 0.8.3
Flax 0.12.6

Summary

6/7 passed (07_distill_smoke failed).

Test Result Log
01_sft_smoke PASS log
02_sft_nnx_ckpt PASS log
03_sft_linen_ckpt PASS log
07_distill_smoke FAIL log
04_rl_grpo_smoke PASS log
05_rl_gspo_smoke PASS log
06_rl_grpo_functional PASS log

Output Paths

Log
Logs ~/maxtext/venv_runs/feat_nnx_post_train_fixes_20260420_205452/nnx/logs/
GCS checkpoints gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452/
NNX seed ckpt (PT-02 NNX) gs://wanglance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items
Linen seed ckpt (PT-02/03) gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items
Teacher ckpt (PT-07) gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items
RL ckpt (PT-06) gs://wanglance-maxtext/rl_ckpt_llama31_8b/0/items

Reproduction Commands

# Run everything below from the maxtext repository root: PYTHONPATH and the
# config paths passed to each command (src/maxtext/configs/...) are relative.
source ~/maxtext/maxtext_pt_venv/bin/activate
export PYTHONPATH=src
export DECOUPLE_GCLOUD=TRUE
export VLLM_TARGET_DEVICE=tpu   # RL only (the 04-06 commands also set it inline)
# Hugging Face token, forwarded as hf_access_token=... by every command below.
# Export your real token first (export HF_TOKEN=...). A literal placeholder such
# as <your-token> would be parsed as shell redirections, so the value is taken
# from the environment and the :? expansion reports an error if it is missing.
export HF_TOKEN="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}"
export RL_CKPT_PATH=gs://wanglance-maxtext/rl_ckpt_llama31_8b/0/items

01_sft_smoke

# 01_sft_smoke: 5-step SFT run on gpt3-52k with no load_parameters_path.
# The token is quoted and guarded so an unset HF_TOKEN fails here instead of
# being passed through as an empty hf_access_token.
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
  model_name=gpt3-52k \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  max_target_length=1024 \
  steps=5 \
  eval_interval=-1 \
  gradient_accumulation_steps=1 \
  weight_dtype=float32 \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452 \
  run_name=pt_sft_nnx_feat_nnx_post_train_fixes_20260420_205452_01_sft_smoke

02_sft_nnx_ckpt

# 02_sft_nnx_ckpt: the 01 SFT run, initialised via load_parameters_path from the
# NNX seed checkpoint (step 9). The token is quoted and guarded so an unset
# HF_TOKEN fails here instead of being passed through empty.
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
  model_name=gpt3-52k \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  max_target_length=1024 \
  steps=5 \
  eval_interval=-1 \
  gradient_accumulation_steps=1 \
  weight_dtype=float32 \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452 \
  run_name=pt_sft_nnx_feat_nnx_post_train_fixes_20260420_205452_02_sft_nnx_ckpt \
  load_parameters_path=gs://wanglance-maxtext/nnx_ckpt_feat_nnx_trainstate_and_training_loop_20260411_044231/nnx_feat_nnx_trainstate_and_training_loop_20260411_044231_08_checkpoint_async_true/checkpoints/9/items

03_sft_linen_ckpt

# 03_sft_linen_ckpt: the 01 SFT run, initialised via load_parameters_path from
# the Linen seed checkpoint (step 9). The token is quoted and guarded so an unset
# HF_TOKEN fails here instead of being passed through empty.
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
  model_name=gpt3-52k \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  max_target_length=1024 \
  steps=5 \
  eval_interval=-1 \
  gradient_accumulation_steps=1 \
  weight_dtype=float32 \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452 \
  run_name=pt_sft_nnx_feat_nnx_post_train_fixes_20260420_205452_03_sft_linen_ckpt \
  load_parameters_path=gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_linen/checkpoints/9/items

07_distill_smoke

# 07_distill_smoke: 5-step distillation run. Student and teacher are both
# gpt3-52k with vocab_size=32000; the teacher loads the v32k Linen seed
# checkpoint (step 4). Reported as FAIL for commit e27fc1e97.
# Overrides are listed one per line in their original order.
python3 -m maxtext.trainers.post_train.distillation.train_distill src/maxtext/configs/post_train/distillation.yml \
  student_overrides.model_name=gpt3-52k \
  student_overrides.vocab_size=32000 \
  teacher_overrides.model_name=gpt3-52k \
  teacher_overrides.vocab_size=32000 \
  teacher_overrides.load_parameters_path=gs://wanglance-maxtext/pt_seed_ckpts/pt_seed_ckpt_gpt352k_v32k_linen/checkpoints/4/items \
  teacher_overrides.skip_jax_distributed_system=True \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf \
  tokenizer_type=huggingface \
  hf_access_token="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}" \
  steps=5 \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=8 \
  weight_dtype=float32 \
  gradient_accumulation_steps=1 \
  skip_jax_distributed_system=True \
  profiler=xplane \
  log_config=False \
  enable_goodput_recording=False \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452 \
  run_name=pt_distill_nnx_feat_nnx_post_train_fixes_20260420_205452_07_distill_smoke

04_rl_grpo_smoke

# 04_rl_grpo_smoke: RL smoke run on qwen3-0.6b (num_batches=2) with
# rl.loss_algo=grpo. The token is quoted and guarded so an unset HF_TOKEN
# fails here instead of being passed through empty.
VLLM_TARGET_DEVICE=tpu python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=qwen3-0.6b \
  tokenizer_path=Qwen/Qwen3-0.6B \
  use_pathways=False \
  chips_per_vm=8 \
  num_batches=2 \
  rollout_data_parallelism=1 \
  async_scheduling=False \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  hf_access_token="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452 \
  run_name=pt_rl_nnx_feat_nnx_post_train_fixes_20260420_205452_04_rl_grpo_smoke \
  rl.loss_algo=grpo

05_rl_gspo_smoke

# 05_rl_gspo_smoke: the 04 smoke run with rl.loss_algo=gspo-token instead of
# grpo. The token is quoted and guarded so an unset HF_TOKEN fails here instead
# of being passed through empty.
VLLM_TARGET_DEVICE=tpu python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=qwen3-0.6b \
  tokenizer_path=Qwen/Qwen3-0.6B \
  use_pathways=False \
  chips_per_vm=8 \
  num_batches=2 \
  rollout_data_parallelism=1 \
  async_scheduling=False \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  hf_access_token="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452 \
  run_name=pt_rl_nnx_feat_nnx_post_train_fixes_20260420_205452_05_rl_gspo_smoke \
  rl.loss_algo=gspo-token

06_rl_grpo_functional

# 06_rl_grpo_functional: GRPO run on llama3.1-8b-Instruct loaded from the RL
# checkpoint. Each override is given exactly once (model, tokenizer and
# num_batches were previously repeated with conflicting qwen values).
# The checkpoint path honours an exported RL_CKPT_PATH and falls back to the
# reported location otherwise.
VLLM_TARGET_DEVICE=tpu python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=llama3.1-8b-Instruct \
  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
  use_pathways=False \
  chips_per_vm=8 \
  num_batches=2 \
  rollout_data_parallelism=1 \
  async_scheduling=False \
  skip_jax_distributed_system=True \
  log_config=False \
  enable_goodput_recording=False \
  profiler=xplane \
  pure_nnx=True \
  enable_nnx=True \
  pure_nnx_decoder=True \
  hf_access_token="${HF_TOKEN:?set HF_TOKEN to your Hugging Face access token}" \
  base_output_directory=gs://wanglance-maxtext/pt_ckpt_feat_nnx_post_train_fixes_20260420_205452 \
  run_name=pt_rl_nnx_feat_nnx_post_train_fixes_20260420_205452_06_rl_grpo_functional \
  load_parameters_path="${RL_CKPT_PATH:-gs://wanglance-maxtext/rl_ckpt_llama31_8b/0/items}" \
  rl.loss_algo=grpo