# MaxView

## Pre-Training Tests

| ID | Name | Frameworks | Purpose | Pass Criteria |
|----|------|------------|---------|---------------|
| 01 | Base (TFDS c4) | Linen + NNX | Baseline training on the real c4 dataset | 10 steps complete; loss decreasing |
| 02 | Synthetic dataset | Linen + NNX | Verifies the synthetic data path (no real dataset needed) | 10 steps; loss drops from ~10.9 to ~5.9 |
| 03 | Dropout | Linen + NNX | Verifies dropout is wired correctly in NNX modules | 10 steps; no NaN; loss decreasing |
| 04 | int8 quantization | Linen + NNX | Verifies the int8 quantization path | 10 steps; loss close to baseline |
| 05 | fp8 quantization | Linen + NNX | Verifies fp8 quantization is numerically stable | 20 steps; no NaN |
| 06 | Gradient accumulation ×4 | Linen + NNX | Verifies NNX gradient accumulation via `jax.lax.scan` | 10 steps; `total_weights=65536` (4× base) confirms 4-step accumulation |
| 07 | Eval loop | Linen + NNX | Verifies `p_eval_step` runs without error alongside training | 10 training steps complete; eval runs at steps 5 and 10 without crash |
| 08 | Checkpointing, `async_checkpointing=True` (save + resume) | Linen + NNX | Save checkpoints at steps 5 and 10, then resume and continue to step 20 (the default async path) | Save exits 0; resume starts at step 10 with correct loss continuity |
| 09 | `per_device_batch_size < 1` | Linen + NNX | Fractional batch size (token packing) | 10 steps; `total_weights` varies per step (real token counts) |
| 10 | `shardy=False` | Linen + NNX | Legacy GSPMD dialect (`shardy=True` is the default, covered by 01_base) | 10 steps; loss matches 01_base |
| 11 | `optimizer_memory_host_offload=True` | Linen + NNX | Optimizer state in host RAM (`False` is the default, covered by 02_synthetic) | 10 steps; same loss as the synthetic baseline |
| 13 | `scan_layers=False` | Linen + NNX | Standard unrolled layers (`scan_layers=True` is the default, covered by 02_synthetic) | 10 steps; clean convergence |
| 14 | `async_checkpointing=False` (save + resume) | Linen + NNX | Synchronous checkpoint I/O path | Save exits 0; resume from step 10 |
| 15 | `checkpoint_storage_use_ocdbt` (both values, save + resume) | Linen + NNX | OCDBT format (the default) | Save exits 0; resume from step 10 |
| 16 | `checkpoint_storage_use_zarr3` (both values, save + resume) | Linen + NNX | Zarr3 format (the default) | Save exits 0; resume from step 10 |
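Test 06's pass criterion is pure arithmetic: with four microbatches accumulated per optimizer step, the per-step `total_weights` (token count) should be exactly 4× the single-batch baseline, i.e. 65536 = 4 × 16384. A minimal sketch of that accumulation pattern in plain Python — MaxText performs it inside `jax.lax.scan`, and the 16384-token microbatch size here is an assumption derived from 65536 / 4, not a value taken from the configs:

```python
# Sketch of 4-step gradient accumulation and the total_weights check
# used as the pass criterion for test 06. Plain-Python stand-in for
# the jax.lax.scan loop; numbers are illustrative.

ACCUM_STEPS = 4
MICROBATCH_TOKENS = 16384  # assumed base per-step token count (65536 / 4)

def train_step(microbatch_grads):
    """Accumulate gradients over ACCUM_STEPS microbatches, apply once."""
    accum_grad = 0.0
    total_weights = 0
    for grad, n_tokens in microbatch_grads:
        accum_grad += grad          # running gradient sum
        total_weights += n_tokens   # running token count
    mean_grad = accum_grad / ACCUM_STEPS  # applied in a single optimizer update
    return mean_grad, total_weights

# One optimizer step consumes 4 microbatches:
grads = [(0.5, MICROBATCH_TOKENS)] * ACCUM_STEPS
mean_grad, total_weights = train_step(grads)
assert total_weights == 65536  # 4x the single-batch baseline
```

The check is robust precisely because `total_weights` is summed, not averaged: if accumulation were silently skipped, the logged value would fall back to the 16384 baseline.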

## Post-Training Tests

| ID | Name | Frameworks | Purpose | Pass Criteria |
|----|------|------------|---------|---------------|
| PT-01 | SFT smoke (random init) | Linen + NNX | Verify the Tunix SFT pipeline, HF dataset loading, and loss logging end-to-end | 5 steps complete; `learning/loss` logged each step; no crash |
| PT-02 | SFT from checkpoint (same-format) | Linen + NNX | Verify SFT restores from an NNX pre-train checkpoint | `Loaded params from ...` in log; 5 steps complete; loss not NaN |
| PT-03 | SFT from Linen checkpoint into NNX model (cross-format) | NNX only (tests the Linen→NNX cross-format load in `create_nnx_model`) | Verify `create_nnx_model` loads a Linen checkpoint into an NNX model | `Loaded params from ...` in log; 5 steps complete; loss not NaN |
| PT-04 | RL GRPO smoke (random init) | NNX only (RL training uses vLLM, which only supports NNX) | Verify the GRPO pipeline runs end-to-end on v6e-8 (vLLM sampler + trainer loop); does not verify learning quality | 2 batches complete; reward logged (will be ~0 with random weights); no crash |
| PT-05 | RL GSPO smoke (random init) | NNX only (RL training uses vLLM, which only supports NNX) | Verify the GSPO variant pipeline runs; mirrors PT-04 with `rl.loss_algo=gspo-token` | 2 batches complete; no crash |
| PT-06 | RL GRPO functional (Llama3.1-8B-Instruct) | NNX only (RL training uses vLLM, which only supports NNX) | Verify RL produces meaningful reward improvement on GSM8K using a real instruct model | 2 batches complete; reward logged; policy loss logged; no crash |
| PT-07 | Distillation smoke (gpt3-52k teacher + student) | Linen + NNX | Verify the dual-model distillation pipeline (ModelBundle, frozen teacher forward pass, student optimizer) runs end-to-end; does not verify knowledge-transfer quality | 5 steps complete; `_train_loss` logged each step; no crash |
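Most smoke-test pass criteria above reduce to scanning the training log: the expected number of steps completed, a loss value logged at each step, and no NaN. A hedged sketch of such a checker — the `step: N, loss: X` line format is an illustrative assumption, not the actual MaxText/Tunix log format:

```python
import math
import re

# Hypothetical log-line format; real training logs will differ.
LOSS_RE = re.compile(r"step:\s*(\d+).*?loss:\s*([-\d.naN]+)", re.IGNORECASE)

def check_smoke(log_text, expected_steps):
    """Shared smoke-test criterion: every logged step has a finite
    loss, and at least expected_steps steps were logged."""
    losses = {}
    for line in log_text.splitlines():
        m = LOSS_RE.search(line)
        if m:
            step, loss = int(m.group(1)), float(m.group(2))
            if math.isnan(loss):
                return False  # violates the "loss not NaN" criterion
            losses[step] = loss
    return len(losses) >= expected_steps  # "N steps complete" criterion

# A log that satisfies PT-01's "5 steps complete" criterion:
log = "\n".join(f"step: {i}, loss: {10.9 - i}" for i in range(1, 6))
assert check_smoke(log, expected_steps=5)
```

Criteria that go beyond log scanning (loss continuity across a checkpoint resume, reward improvement in PT-06) need the actual metric values, not just their presence.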