| ID | Name | Frameworks | Purpose | Pass Criteria |
|---|---|---|---|---|
| 01 | Base (TFDS c4) | Linen + NNX | Baseline training on real c4 dataset | 10 steps complete, loss decreasing |
| 02 | Synthetic dataset | Linen + NNX | Verifies synthetic data path (no real dataset needed) | 10 steps, loss from ~10.9 to ~5.9 |
| 03 | Dropout | Linen + NNX | Verifies dropout is wired correctly in NNX modules | 10 steps, no NaN, loss decreasing |
| 04 | int8 quantization | Linen + NNX | Verifies int8 quantization path | 10 steps, loss close to baseline |
| 05 | fp8 quantization | Linen + NNX | Verifies fp8 quantization, numerically stable | 20 steps, no NaN |
| 06 | Gradient accumulation ×4 | Linen + NNX | Verifies NNX gradient accumulation via jax.lax.scan | 10 steps; total_weights=65536 (4× base) confirms 4-step accumulation |
| 07 | Eval loop | Linen + NNX | Verifies p_eval_step runs without error alongside training | 10 training steps complete; eval runs at steps 5, 10 without crash |
| 08 | Checkpointing async_checkpointing=True (save + resume) | Linen + NNX | Save checkpoint at steps 5 and 10, resume and continue to step 20 (default async path) | Save exits 0; resume starts at step 10 with correct loss continuity |
| 09 | per_device_batch_size < 1 | Linen + NNX | Fractional batch size (token packing) | 10 steps, total_weights varies per step (real token counts) |
| 10 | shardy=False | Linen + NNX | Legacy GSPMD dialect (shardy=True is default and covered by 01_base) | 10 steps, loss matches 01_base |
| 11 | optimizer_memory_host_offload=True | Linen + NNX | Optimizer state on host RAM (False is default and covered by 02_synthetic) | 10 steps, same loss as synthetic baseline |
| 13 | scan_layers=False | Linen + NNX | Standard unrolled layers (scan_layers=True is default and covered by 02_synthetic) | 10 steps, clean convergence |
| 14 | async_checkpointing=False (save + resume) | Linen + NNX | Synchronous checkpoint I/O path | Save exits 0; resume resumes from step 10 |
| 15 | checkpoint_storage_use_ocdbt (both values, save + resume) | Linen + NNX | OCDBT format (default) | Save exits 0; resume from step 10 |
| 16 | checkpoint_storage_use_zarr3 (both values, save + resume) | Linen + NNX | Zarr3 format (default) | Save exits 0; resume from step 10 |
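Case 06's ×4 accumulation (sum microbatch gradients inside jax.lax.scan, then average) can be sketched as follows. This is an illustrative stand-in, not MaxText's actual NNX implementation; loss_fn and accumulate_grads are hypothetical names, and the toy MSE loss stands in for the real training loss:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy MSE loss standing in for the real training loss.
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulate_grads(params, microbatches, num_steps=4):
    """Average gradients over `num_steps` microbatches via jax.lax.scan.

    `microbatches` is an (x, y) pair whose arrays have leading axis
    `num_steps`; scan slices one microbatch per step.
    """
    def body(acc, batch):
        # Add this microbatch's gradient into the running sum.
        return jax.tree_util.tree_map(jnp.add, acc, grad_fn(params, batch)), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, microbatches)
    return jax.tree_util.tree_map(lambda g: g / num_steps, total)
```

With equal-sized microbatches the averaged result equals the gradient of the full concatenated batch, which is why an accumulated run can be checked against the baseline loss curve.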

| ID | Name | Frameworks | Purpose | Pass Criteria |
|---|---|---|---|---|
| PT-01 | SFT smoke (random init) | Linen + NNX | Verify Tunix SFT pipeline, HF dataset, and loss logging end-to-end | 5 steps complete; learning/loss logged each step; no crash |
| PT-02 | SFT from checkpoint (same-format) | Linen + NNX | Verify SFT restores from NNX pre-train checkpoint | Loaded params from ... in log; 5 steps complete; loss not NaN |
| PT-03 | SFT from Linen checkpoint into NNX model (cross-format) | NNX only (tests the Linen→NNX cross-format load in create_nnx_model) | Verify create_nnx_model loads a Linen checkpoint into an NNX model | Loaded params from ... in log; 5 steps complete; loss not NaN |
| PT-04 | RL GRPO smoke (random init) | NNX only (RL training uses vLLM, which only supports NNX) | Verify GRPO pipeline runs end-to-end on V6e-8 (vLLM sampler + trainer loop); does not verify learning quality | 2 batches complete; reward logged (will be ~0 with random weights); no crash |
| PT-05 | RL GSPO smoke (random init) | NNX only (RL training uses vLLM, which only supports NNX) | Verify GSPO variant pipeline runs; mirrors PT-04 with rl.loss_algo=gspo-token | 2 batches complete; no crash |
| PT-06 | RL GRPO functional (Llama3.1-8B-Instruct) | NNX only (RL training uses vLLM, which only supports NNX) | Verify RL produces meaningful reward improvement on GSM8K using a real instruct model | 2 batches complete; reward logged; policy loss logged; no crash |
| PT-07 | Distillation smoke (gpt3-52k teacher + student) | Linen + NNX | Verify dual-model distillation pipeline (ModelBundle, teacher forward pass frozen, student optimizer) runs end-to-end; does not verify knowledge transfer quality | 5 steps complete; _train_loss logged each step; no crash |
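PT-04 through PT-06 exercise GRPO (and its GSPO variant), whose core idea is a group-relative advantage: each completion's reward is normalized against the other completions sampled for the same prompt. As background, a minimal sketch of that computation; this is illustrative only, not Tunix's implementation, and the function name is hypothetical:

```python
import jax.numpy as jnp

def group_relative_advantages(rewards, eps=1e-6):
    """rewards: (num_prompts, group_size) array, one row per prompt's
    sampled completion group. Returns zero-mean, unit-scale advantages
    computed within each group."""
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)
```

With random-init weights (PT-04/PT-05) rewards sit near zero across the whole group, so the smoke tests only check that batches complete and reward is logged, not that the advantages are meaningful.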