XPK Start: Fri Apr 24 07:35:16 UTC 2026
`rope_parameters`'s factor field must be a float >= 1, got 40
`rope_parameters`'s beta_fast field must be a float, got 32
`rope_parameters`'s beta_slow field must be a float, got 1
DeepseekV32Config got `key=rope_scaling` in kwargs but hasn't set it as attribute. For RoPE standardization you need to set `self.rope_parameters` in model's config.
2026-04-24 07:35:41.170050: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0424 07:35:41.170173 138242845357888 max_utils.py:800] System Information: Jax Version: 0.9.2
I0424 07:35:41.170268 138242845357888 max_utils.py:801] System Information: Jaxlib Version: 0.9.2
I0424 07:35:47.175641 138242845357888 max_utils.py:802] System Information: Jax Backend: PJRT C API TFRT TPU v6 lite Built on Apr 6 2026 20:48:10 (1775533690) cl/895581894
I0424 07:35:47.377693 138242845357888 max_utils.py:238] Skipping jax distributed system due to skip_jax_distributed_system=True flag.
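The three `rope_parameters` warnings at the top of this log flag the values 40, 32 and 1 for YaRN's `factor`, `beta_fast` and `beta_slow`. The same values appear as plain ints (`rope_factor=40, beta_fast=32, beta_slow=1`) in the config dump further down. A plausible reading is that the config validator requires Python floats and flags ints. Below is a minimal sketch under that assumption that coerces those fields before they reach a config. The helper `normalize_yarn_rope_parameters` and the field list are hypothetical and are not part of transformers or MaxText.

```python
from typing import Any

# Assumption: the validator behind the warnings checks isinstance(value, float),
# so ints such as 40 / 32 / 1 are flagged even though their values are valid.
_FLOAT_FIELDS = ("factor", "beta_fast", "beta_slow", "mscale", "mscale_all_dim")


def normalize_yarn_rope_parameters(params: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of a YaRN rope-parameter dict with numeric fields as floats.

    Hypothetical helper for illustration only.
    """
    out = dict(params)
    for name in _FLOAT_FIELDS:
        value = out.get(name)
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(value, int) and not isinstance(value, bool):
            out[name] = float(value)
    factor = out.get("factor")
    if factor is not None and factor < 1.0:
        raise ValueError(f"factor must be >= 1, got {factor}")
    return out


if __name__ == "__main__":
    raw = {"rope_type": "yarn", "factor": 40, "beta_fast": 32, "beta_slow": 1}
    print(normalize_yarn_rope_parameters(raw))
```

Coercing the values in one place up front, rather than at each call site, keeps the numeric values unchanged while satisfying a strict float check.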
I0424 07:35:47.379275 138242845357888 model_creation_utils.py:269] Running on a single slice
I0424 07:35:47.379331 138242845357888 model_creation_utils.py:356] Creating reference model and also meshes for reference and rollout
I0424 07:35:47.382698 138242845357888 maxtext_utils.py:1604] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 07:35:47.502054 138242845357888 maxtext_utils.py:1604] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 07:36:16.056151 138242845357888 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 07:36:16.056630 138242845357888 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=False, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7dba26500770>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0424 07:36:16.056689 138242845357888 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.28
W0424 07:36:16.581766 138242845357888 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items/_CHECKPOINT_METADATA
I0424 07:36:16.978165 1715 google_auth_provider.cc:181] Running on GCE, using service account 562977990677-compute@developer.gserviceaccount.com
I0424 07:36:18.048366 138242845357888 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items.
W0424 07:36:47.304528 138242845357888 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
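Both `Num_devices: 32, shape (1, 1, 1, 32, 1, ...)` records match the `ici_parallelism=[1, 1, 1, -1, 1, ...]` and `num_target_devices=32` values in the config dump further down. The single `-1` axis appears to absorb whatever device count the other axes leave. Below is a minimal sketch of that arithmetic, assuming MaxText fills `-1` this way. The helper `resolve_parallelism` is hypothetical and is not MaxText's API.

```python
import math
from typing import Sequence


def resolve_parallelism(axes: Sequence[int], num_devices: int) -> tuple[int, ...]:
    """Fill at most one -1 entry so the product of the axes equals num_devices.

    Hypothetical helper for illustration; MaxText's own logic may differ.
    """
    wildcard = [i for i, size in enumerate(axes) if size == -1]
    if len(wildcard) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly sized axes (math.prod of nothing is 1).
    known = math.prod(size for size in axes if size != -1)
    resolved = list(axes)
    if wildcard:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices not divisible by {known}")
        resolved[wildcard[0]] = num_devices // known
    if math.prod(resolved) != num_devices:
        raise ValueError(f"mesh {resolved} does not use {num_devices} devices")
    return tuple(resolved)


ici = [1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1]
print(resolve_parallelism(ici, 32))  # (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1)
```

Under the same assumption, `dcn_parallelism=[1, -1, 1, ...]` on this single slice resolves to all ones. The batch settings are consistent with this mesh: `per_device_batch_size=12.0` times 32 devices gives the logged `global_batch_size_to_train_on=384`.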
I0424 07:36:47.304836 138242845357888 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/self_attention/attention_op/rngs/aqt/count, params/params/decoder/layers/self_attention/attention_op/rngs/aqt/key, params/params/decoder/layers/self_attention/attention_op/rngs/dropout/count, params/params/decoder/layers/self_attention/attention_op/rngs/dropout/key, params/params/decoder/layers/self_attention/attention_op/rngs/params/count, params/params/decoder/layers/self_attention/attention_op/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key
I0424 07:37:09.634540 138242845357888 checkpointer.py:318] Finished restoring checkpoint in 51.95 seconds from gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items.
I0424 07:37:09.662802 138242845357888 maxtext_utils.py:1604] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 07:37:09.663257 138242845357888 model_creation_utils.py:373] Creating policy model with same config as reference model on trainer mesh
I0424 07:37:09.665741 138242845357888 maxtext_utils.py:1604] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1)
I0424 07:37:09.730589 138242845357888 maxtext_utils.py:1604] Num_devices: 32, shape (1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1)
W0424 07:37:09.956352 11 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
I0424 07:37:10.037447 138242845357888 pytree_checkpoint_handler.py:577] save_device_host_concurrent_bytes=None
I0424 07:37:10.037610 138242845357888 base_pytree_checkpoint_handler.py:411] Created BasePyTreeCheckpointHandler: use_ocdbt=False, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7dba26500770>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
W0424 07:37:10.548193 138242845357888 checkpoint.py:202] Metadata file does not exist: gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items/_CHECKPOINT_METADATA
I0424 07:37:12.044370 138242845357888 checkpointer.py:304] Restoring checkpoint from gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items.
W0424 07:37:41.201411 138242845357888 transform_utils.py:230] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0424 07:37:41.201729 138242845357888 transform_utils.py:288] The following keys are not loaded from the original tree after applying specified transforms: params/params/decoder/dropout/rngs/aqt/count, params/params/decoder/dropout/rngs/aqt/key, params/params/decoder/dropout/rngs/dropout/count, params/params/decoder/dropout/rngs/dropout/key, params/params/decoder/dropout/rngs/params/count, params/params/decoder/dropout/rngs/params/key, params/params/decoder/layers/dropout/rngs/aqt/count, params/params/decoder/layers/dropout/rngs/aqt/key, params/params/decoder/layers/dropout/rngs/dropout/count, params/params/decoder/layers/dropout/rngs/dropout/key, params/params/decoder/layers/dropout/rngs/params/count, params/params/decoder/layers/dropout/rngs/params/key, params/params/decoder/layers/mlp/dropout/rngs/aqt/count, params/params/decoder/layers/mlp/dropout/rngs/aqt/key, params/params/decoder/layers/mlp/dropout/rngs/dropout/count, params/params/decoder/layers/mlp/dropout/rngs/dropout/key, params/params/decoder/layers/mlp/dropout/rngs/params/count, params/params/decoder/layers/mlp/dropout/rngs/params/key, params/params/decoder/layers/self_attention/attention_op/rngs/aqt/count, params/params/decoder/layers/self_attention/attention_op/rngs/aqt/key, params/params/decoder/layers/self_attention/attention_op/rngs/dropout/count, params/params/decoder/layers/self_attention/attention_op/rngs/dropout/key, params/params/decoder/layers/self_attention/attention_op/rngs/params/count, params/params/decoder/layers/self_attention/attention_op/rngs/params/key, params/params/decoder/rngs/aqt/count, params/params/decoder/rngs/aqt/key, params/params/decoder/rngs/dropout/count, params/params/decoder/rngs/dropout/key, params/params/decoder/rngs/params/count, params/params/decoder/rngs/params/key
I0424 07:37:45.345584 138242845357888 checkpointer.py:318] Finished restoring checkpoint in 33.67 seconds from gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items.
I0424 07:37:48.710641 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/config.json "HTTP/1.1 200 OK"
I0424 07:37:48.824737 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/config.json "HTTP/1.1 200 OK"
I0424 07:37:48.932451 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0424 07:37:49.045423 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/tokenizer_config.json "HTTP/1.1 200 OK"
I0424 07:37:49.166863 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-3.1-8B-Instruct/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Found"
I0424 07:37:49.274016 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-3.1-8B-Instruct/tree/main?recursive=true&expand=false "HTTP/1.1 200 OK"
I0424 07:37:49.382920 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0424 07:37:49.515920 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/tokenizer.json "HTTP/1.1 200 OK"
I0424 07:37:50.092268 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/tokenizer.model "HTTP/1.1 404 Not Found"
I0424 07:37:50.214736 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/added_tokens.json "HTTP/1.1 404 Not Found"
I0424 07:37:50.321582 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0424 07:37:50.431528 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/special_tokens_map.json "HTTP/1.1 200 OK"
I0424 07:37:50.542529 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/chat_template.jinja "HTTP/1.1 404 Not Found"
I0424 07:37:51.300463 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/models/meta-llama/Llama-3.1-8B-Instruct "HTTP/1.1 200 OK"
I0424 07:37:51.419019 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/main/README.md "HTTP/1.1 307 Temporary Redirect"
I0424 07:37:51.428622 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/api/resolve-cache/datasets/openai/gsm8k/740312add88f781978c0658806c59bc2815b9866/README.md "HTTP/1.1 200 OK"
I0424 07:37:51.437468 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/resolve-cache/datasets/openai/gsm8k/740312add88f781978c0658806c59bc2815b9866/README.md "HTTP/1.1 200 OK"
I0424 07:37:51.543738 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/gsm8k.py "HTTP/1.1 404 Not Found"
I0424 07:37:51.850031 138242845357888 _client.py:1025] HTTP Request: HEAD https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets/openai/gsm8k/openai/gsm8k.py "HTTP/1.1 404 Not Found"
I0424 07:37:52.098366 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/openai/gsm8k/revision/740312add88f781978c0658806c59bc2815b9866 "HTTP/1.1 200 OK"
I0424 07:37:52.235486 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/.huggingface.yaml "HTTP/1.1 404 Not Found"
I0424 07:37:52.395419 138242845357888 _client.py:1025] HTTP Request: GET https://datasets-server.huggingface.co/info?dataset=openai/gsm8k "HTTP/1.1 200 OK"
I0424 07:37:52.545275 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/openai/gsm8k/tree/740312add88f781978c0658806c59bc2815b9866/main?recursive=true&expand=false "HTTP/1.1 200 OK"
I0424 07:37:52.660396 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/openai/gsm8k/tree/740312add88f781978c0658806c59bc2815b9866?recursive=false&expand=false "HTTP/1.1 200 OK"
I0424 07:37:52.768895 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/dataset_infos.json "HTTP/1.1 404 Not Found"
I0424 07:37:52.928611 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/main/train-00000-of-00001.parquet "HTTP/1.1 302 Found"
I0424 07:37:53.042311 138242845357888 _client.py:1025] HTTP Request: GET https://huggingface.co/api/datasets/openai/gsm8k/xet-read-token/740312add88f781978c0658806c59bc2815b9866 "HTTP/1.1 200 OK"
I0424 07:37:53.707674 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/main/test-00000-of-00001.parquet "HTTP/1.1 302 Found"
Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 701287.25 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 471635.72 examples/s]
I0424 07:37:53.949980 138242845357888 train_rl.py:96] Loaded Hugging Face dataset openai/gsm8k with split train. Size: 7473
I0424 07:37:54.058989 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/main/README.md "HTTP/1.1 307 Temporary Redirect"
I0424 07:37:54.068182 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/api/resolve-cache/datasets/openai/gsm8k/740312add88f781978c0658806c59bc2815b9866/README.md "HTTP/1.1 200 OK"
I0424 07:37:54.172403 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/gsm8k.py "HTTP/1.1 404 Not Found"
I0424 07:37:54.263882 138242845357888 _client.py:1025] HTTP Request: HEAD https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets/openai/gsm8k/openai/gsm8k.py "HTTP/1.1 404 Not Found"
I0424 07:37:54.371627 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/.huggingface.yaml "HTTP/1.1 404 Not Found"
I0424 07:37:54.485863 138242845357888 _client.py:1025] HTTP Request: GET https://datasets-server.huggingface.co/info?dataset=openai/gsm8k "HTTP/1.1 200 OK"
I0424 07:37:54.589873 138242845357888 _client.py:1025] HTTP Request: HEAD https://huggingface.co/datasets/openai/gsm8k/resolve/740312add88f781978c0658806c59bc2815b9866/dataset_infos.json "HTTP/1.1 404 Not Found"
I0424 07:37:54.594204 138242845357888 train_rl.py:96] Loaded Hugging Face dataset openai/gsm8k with split test. Size: 1319
I0424 07:37:54.595404 138242845357888 train_rl.py:562] Train dataset samples:
I0424 07:37:54.626063 138242845357888 train_rl.py:568] Test dataset samples:
{'answer': array(['["3", "3"]'], dtype='<U10'), 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning.
Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nMaria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?<end_of_turn>\n<start_of_turn>model<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'], dtype='<U650'), 'question': array(['Maria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?'], dtype='<U142')} {'answer': array(['["34", "34"]'], dtype='<U12'), 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nA wildlife team is monitoring the number of birds in a park. There are 3 blackbirds in each of the park’s 7 trees. There are also 13 magpies roaming around the park. How many birds are in the park in total?<end_of_turn>\n<start_of_turn>model<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'], dtype='<U714'), 'question': array(['A wildlife team is monitoring the number of birds in a park. There are 3 blackbirds in each of the park’s 7 trees. There are also 13 magpies roaming around the park. How many birds are in the park in total?'], dtype='<U206')} {'answer': array(['["300", "300"]'], dtype='<U14'), 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. 
Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nMr Hezekiah had 20 trucks from his store supplying fertiliser to different farmers in his hometown dispatched for delivery on a particular day. Each truck was carrying 20 tons of fertiliser packed in bags. Two hours after the trucks had departed for delivery, Mr Hezekiah got the news that a quarter of the number of lorries dispatched for delivery had mechanical failures on the road and could not deliver the fertilisers to the farmers. Calculate the total number of tons of fertiliser that reached the farmers that day?<end_of_turn>\n<start_of_turn>model<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'], dtype='<U1030'), 'question': array(['Mr Hezekiah had 20 trucks from his store supplying fertiliser to different farmers in his hometown dispatched for delivery on a particular day. Each truck was carrying 20 tons of fertiliser packed in bags. Two hours after the trucks had departed for delivery, Mr Hezekiah got the news that a quarter of the number of lorries dispatched for delivery had mechanical failures on the road and could not deliver the fertilisers to the farmers. Calculate the total number of tons of fertiliser that reached the farmers that day?'], dtype='<U522')} {'answer': array(['["450", "450"]'], dtype='<U14'), 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nGrandpa loves to eat jelly beans, but how many jelly beans he can eat depends on the size of the beans. It takes 75 large jelly beans to fill Grandpa up. 
He can eat twice as many medium-sized beans as large beans. And eating 3 small beans is the same as eating 1 medium-sized bean. How many small beans can Grandpa eat?<end_of_turn>\n<start_of_turn>model<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'], dtype='<U830'), 'question': array(['Grandpa loves to eat jelly beans, but how many jelly beans he can eat depends on the size of the beans. It takes 75 large jelly beans to fill Grandpa up. He can eat twice as many medium-sized beans as large beans. And eating 3 small beans is the same as eating 1 medium-sized bean. How many small beans can Grandpa eat?'], dtype='<U322')} {'answer': array(['["320", "320"]'], dtype='<U14'), 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nMr. Maxim works at The Best Cookeries Around restaurant. On a particular day, 50 people entered the restaurant in the morning to eat. At around 10:00, 40 more people entered the restaurant and ordered the same amount of food as the first people. After a while, twice the number of people who entered the restaurant at 10:00 came in and ordered lunch. By evening, an additional 3 times as many people as the number that came in first had entered the restaurant. Calculate the total number of people that entered the restaurant on that day.<end_of_turn>\n<start_of_turn>model<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'], dtype='<U1046'), 'question': array(['Mr. Maxim works at The Best Cookeries Around restaurant. On a particular day, 50 people entered the restaurant in the morning to eat. 
At around 10:00, 40 more people entered the restaurant and ordered the same amount of food as the first people. After a while, twice the number of people who entered the restaurant at 10:00 came in and ordered lunch. By evening, an additional 3 times as many people as the number that came in first had entered the restaurant. Calculate the total number of people that entered the restaurant on that day.'], dtype='<U538')}
I0424 07:37:54.630827 138242845357888 train_rl.py:575] Reference Model initialized successfully
{'answer': array(['["9", "9"]'], dtype='<U10'), 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nJackson is planting tulips. He can fit 6 red tulips in a row and 8 blue tulips in a row. If Jackson buys 36 red tulips and 24 blue tulips, how many rows of flowers will he plant?<end_of_turn>\n<start_of_turn>model<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'], dtype='<U686'), 'question': array(['Jackson is planting tulips. He can fit 6 red tulips in a row and 8 blue tulips in a row. If Jackson buys 36 red tulips and 24 blue tulips, how many rows of flowers will he plant?'], dtype='<U178')}
{'answer': array(['["60", "60"]'], dtype='<U12'), 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>.
Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nThere are five phones on a phone plan. The main phone costs twice as much as each additional phone. If the main phone plan costs $20, how much does the whole phone plan cost?<end_of_turn>\n<start_of_turn>model<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'], dtype='<U682'), 'question': array(['There are five phones on a phone plan. The main phone costs twice as much as each additional phone. If the main phone plan costs $20, how much does the whole phone plan cost?'], dtype='<U174')} TunixMaxTextAdapter( # Param: 8,030,261,248 (16.1 GB), RngState: 588 (3.5 KB), Total: 8,030,261,836 (16.1 GB) base=Transformer( # Param: 8,030,261,248 (16.1 GB), RngState: 588 (3.5 KB), Total: 8,030,261,836 (16.1 GB) audio_encoder=None, config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', 
rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', 
vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', 
jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, 
grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, 
penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, 
muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, 
optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', 
['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 
'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', 
hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, 
moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, 
engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, 
checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), decoder=NNXDecoder( # Param: 7,504,924,672 (15.0 GB), RngState: 588 (3.5 KB), Total: 7,504,925,260 (15.0 GB) config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', 
tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', 
video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, 
xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, 
add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', 
reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, 
distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, 
query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', 
['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 
'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 
'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, 
num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, 
normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, 
load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), decoder_norm=RMSNorm( # Param: 4,096 (8.2 KB) dtype=<DType.BFLOAT16: 'bfloat16'>, epsilon=1e-05, kernel_axes=('norm',), num_features=4096, parameter_memory_host_offload=False, scale=Param( # 4,096 (8.2 KB) value=Array(shape=(4096,), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('norm',) ), scale_init=<function ones at 0x7dba2a2d6980>, scale_offset=0.0, shard_mode=<ShardMode.AUTO: 'auto'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, with_scale=True ), dropout=Dropout( # RngState: 6 (36 B) broadcast_dims=(-2,), deterministic=False, rate=0.0, rng_collection='dropout', rngs=Rngs( # RngState: 6 (36 B) aqt=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(0, 
dtype=uint32), eager_sharding=False, tag='aqt' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [2799984767 1105366846], eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(0, dtype=uint32), eager_sharding=False, tag='dropout' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [346279018 360566543], eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(0, dtype=uint32),
I0424 07:37:54.648406 138242845357888 train_rl.py:577] Reference mesh shape: OrderedDict({'diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1})
I0424 07:37:54.648467 138242845357888 train_rl.py:578] Policy Model initialized successfully
eager_sharding=False, tag='params' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [2839387376 2467677468], eager_sharding=False, tag='params' ), tag='params' ) ) ), is_deepseek=False, is_gemma3=False, layers=LlamaDecoderLayer( # RngState: 576 (3.5 KB), Param: 6,979,584,000 (14.0 GB), Total: 6,979,584,576 (14.0 GB) activation_axis_names=('activation_batch', 'activation_norm_length', 'activation_embed'), config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, 
checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], 
vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, 
dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, 
enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, 
rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, 
num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, 
mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, 
logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], 
['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], 
['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, 
wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, 
float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', 
local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), dropout=Dropout( # RngState: 192 (1.2 KB) broadcast_dims=(-2,), deterministic=False, rate=0.0, rng_collection='dropout', rngs=Rngs( # RngState: 192 (1.2 KB) aqt=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='aqt' ), key=RngKey( # 32 
(256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='dropout' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='params' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='params' ), tag='params' ) ) ), mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), mlp=MlpBlock( # RngState: 192 (1.2 KB), Param: 5,637,144,576 (11.3 GB), Total: 5,637,144,768 (11.3 GB) activations=['silu', 'linear'], config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', 
tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', 
video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, 
xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, 
add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', 
reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, 
distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, 
query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', 
['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 
'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 
'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, 
num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, 
normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, 
load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), dropout=Dropout( # RngState: 192 (1.2 KB) broadcast_dims=(-2,), deterministic=False, rate=0.0, rng_collection='dropout', rngs=Rngs( # RngState: 192 (1.2 KB) aqt=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='aqt' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='dropout' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='dropout' ), 
tag='dropout' ), params=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='params' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='params' ), tag='params' ) ) ), dtype=<DType.BFLOAT16: 'bfloat16'>, in_features=4096, intermediate_dim=14336, intermediate_dropout_rate=0.0, intermediate_logical=('activation_batch', 'activation_length', 'activation_mlp'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), mlp_layer_norm=None, model_mode='train', quant=None, use_bias=False, use_pre_norm=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, wi_0=DenseGeneral( # Param: 1,879,048,192 (3.8 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 1,879,048,192 (3.8 GB) value=Array(shape=(4096, 32, 14336), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'mlp') ), kernel_axes=('embed', 'mlp'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(14336,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), wi_1=DenseGeneral( # Param: 1,879,048,192 (3.8 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 1,879,048,192 (3.8 GB) value=Array(shape=(4096, 32, 14336), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'mlp') ), kernel_axes=('embed', 'mlp'), 
kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(14336,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), wo=DenseGeneral( # Param: 1,879,048,192 (3.8 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(14336,), kernel=Param( # 1,879,048,192 (3.8 GB) value=Array(shape=(14336, 32, 4096), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('mlp', 'layers', 'embed') ), kernel_axes=('mlp', 'embed'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(4096,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ) ), post_self_attention_layer_norm=RMSNorm( # Param: 131,072 (262.1 KB) dtype=<DType.BFLOAT16: 'bfloat16'>, epsilon=1e-05, kernel_axes=('norm',), num_features=4096, parameter_memory_host_offload=False, scale=Param( # 131,072 (262.1 KB) value=Array(shape=(4096, 32), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('norm', 'layers') ), scale_init=<function ones at 0x7dba2a2d6980>, scale_offset=0.0, shard_mode=<ShardMode.AUTO: 'auto'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, with_scale=True ), pre_self_attention_layer_norm=RMSNorm( # Param: 131,072 (262.1 KB) dtype=<DType.BFLOAT16: 'bfloat16'>, epsilon=1e-05, kernel_axes=('norm',), num_features=4096, parameter_memory_host_offload=False, scale=Param( # 131,072 (262.1 KB) value=Array(shape=(4096, 32), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('norm', 'layers') ), scale_init=<function ones at 0x7dba2a2d6980>, scale_offset=0.0, shard_mode=<ShardMode.AUTO: 'auto'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, with_scale=True ), quant=None, self_attention=Attention( # 
RngState: 192 (1.2 KB), Param: 1,342,177,280 (2.7 GB), Total: 1,342,177,472 (2.7 GB) KVCache_0=None, ar_cache_axis_order=(1, 2, 0, 3), attention_kernel='dot_product', attention_op=AttentionOp( # RngState: 192 (1.2 KB) AqtEinsum_0=<function einsum at 0x7dba2a842480>, AqtEinsum_1=<function einsum at 0x7dba2a842480>, AqtEinsum_2=<function einsum at 0x7dba2a842480>, AqtEinsum_3=<function einsum at 0x7dba2a842480>, attention_kernel='dot_product', attention_type=<AttentionType.GLOBAL: 'global'>, attn_logits_soft_cap=None, cache_logical_axis_names=('cache_batch', 'cache_sequence', 'cache_heads', 'cache_kv'), cache_scale_logical_axis_names=('cache_scale_batch', 'cache_scale_sequence', 'cache_scale_heads', 'cache_scale_kv'), chunk_attn_window_size=0, compute_axis_order=(0, 1, 2, 3), config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], 
tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', 
report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, 
jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, 
grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, 
penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, 
adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, 
parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 
'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 
'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 
'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, 
use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, 
decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, 
checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), dropout_rate=0.0, dtype=<DType.BFLOAT16: 'bfloat16'>, flash_axis_names_kv=('activation_batch_attn', 'activation_heads', 'activation_kv_length', 'activation_kv'), flash_axis_names_q=('activation_batch_attn', 'activation_heads', 'activation_length', 'activation_kv'), flash_axis_names_splash_kernel=('activation_heads', 'activation_length'), float32_logits=False, float32_qk_product=False, key_axis_order=(2, 0, 1, 3), kv_quant=None, max_prefill_predict_length=256, max_target_length=1024, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), num_kv_heads=8, num_query_heads=32, prefill_cache_logical_axis_names=('cache_batch_prefill', 'cache_sequence', 'cache_heads', 'cache_kv'), quant=None, ragged_block_size=256, ragged_lengths_names=('cache_batch',), 
ragged_qkv_axis_names=('cache_batch', 'cache_heads', 'cache_sequence', 'cache_kv'), reshape_q=False, rngs=Rngs( # RngState: 192 (1.2 KB) aqt=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='aqt' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='dropout' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='params' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='params' ), tag='params' ) ), sliding_window_size=None, use_ragged_attention=False ), attention_type=<AttentionType.GLOBAL: 'global'>, attn_logits_soft_cap=None, compute_axis_order=(0, 1, 2, 3), config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', 
tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', 
video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, 
xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, 
add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', 
reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, 
distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, 
query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', 
['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 
'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 
'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, 
num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, 
normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, 
load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), decode_input_axis_names=('decode_batch', 'decode_length', 'activation_embed_attn'), decode_out_axis_names=('decode_batch', 'decode_length', 'activation_heads', 'activation_kv'), dropout_rate=0.0, dtype=<DType.BFLOAT16: 'bfloat16'>, float32_logits=False, float32_qk_product=False, head_dim=128, input_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_embed_attn'), is_nope_layer=False, is_qwen2=False, is_qwen3_next=False, is_vision=False, kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, key=DenseGeneral( # Param: 134,217,728 (268.4 MB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 
134,217,728 (268.4 MB) value=Array(shape=(4096, 32, 8, 128), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'kv_heads', 'kv_head_dim') ), kernel_axes=('embed', 'kv_heads', 'kv_head_dim'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(8, 128), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), key_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), key_norm=None, kv_quant=None, max_prefill_predict_length=256, max_target_length=1024, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), model_mode='train', mrope_section=None, num_kv_heads=8, num_query_heads=32, out=DenseGeneral( # Param: 536,870,912 (1.1 GB) axis=(-2, -1), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(32, 128), kernel=Param( # 536,870,912 (1.1 GB) value=Array(shape=(32, 32, 128, 4096), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('heads', 'layers', 'kv', 'embed') ), kernel_axes=('heads', 'kv', 'embed'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(4096,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), out_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_heads', 'activation_kv'), partial_rotary_factor=None, prefill_cache_axis_order=(1, 2, 0, 3), 
prefill_input_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_embed_attn'), prefill_key_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_out_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_heads', 'activation_kv'), prefill_query_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_value_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), quant=None, query=DenseGeneral( # Param: 536,870,912 (1.1 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 536,870,912 (1.1 GB) value=Array(shape=(4096, 32, 32, 128), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'q_heads', 'kv') ), kernel_axes=('embed', 'q_heads', 'kv'), kernel_init=<function Attention.init_query_w.<locals>.query_init at 0x7db77cd3cb80>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(32, 128), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), query_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), query_norm=None, query_pre_attn_scalar=None, ragged_block_size=256, reshape_q=False, rngs=Rngs(...), rope_max_timescale=500000, rope_type='default', rotary_embedding=LLaMARotaryEmbedding( cast_as_fprop_dtype=True, embedding_dims=128, fprop_dtype=<DType.BFLOAT16: 'bfloat16'>, max_timescale=500000, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, 
Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), min_timescale=1, rope_linear_scaling_factor=1.0, shard_mode=<ShardMode.AUTO: 'auto'>, use_scale=True ), share_kv_projections=False, sinks=None, sliding_window_size=None, temperature_tuning=False, temperature_tuning_floor_scale=8192.0, temperature_tuning_scale=0.1, use_bias_in_projections=False, use_mrope=False, use_qk_norm=False, use_ragged_attention=False, use_v_norm=False, value=DenseGeneral( # Param: 134,217,728 (268.4 MB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 134,217,728 (268.4 MB) value=Array(shape=(4096, 32, 8, 128), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'kv_heads', 'kv_head_dim') ), kernel_axes=('embed', 'kv_heads', 'kv_head_dim'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(8, 128), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), value_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), value_norm=None, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ) ), logits_dense=DenseGeneral( # Param: 525,336,576 (1.1 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 525,336,576 (1.1 GB) value=Array(shape=(4096, 128256), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed_vocab', 'vocab') ), kernel_axes=('embed_vocab', 'vocab'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d58a0>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(128256,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 
1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), model_mode='train', positional_embedding=PositionalEmbedding( cast_as_fprop_dtype=False, embedding_dims=4096, fprop_dtype=bfloat16, max_wavelength=10000, rngs=None ), quant=None, rngs=Rngs( # RngState: 6 (36 B) aqt=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(2, dtype=uint32), eager_sharding=False, tag='aqt' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [4146024105 2718843009], eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(2, dtype=uint32), eager_sharding=False, tag='dropout' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [ 928981903 3453687069], eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(5, dtype=uint32), eager_sharding=False, tag='params' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [1797259609 2579123966], eager_sharding=False, tag='params' ), tag='params' ) ), scanned_layers=None ), hidden_states=None, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), model_mode='train', quant=None, token_embedder=Embed( # Param: 525,336,576 (1.1 GB) attend_dtype=<DType.BFLOAT16: 'bfloat16'>, cast_input_dtype=None, config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 
1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, 
num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', 
dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, 
decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, 
per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, 
async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, 
dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, 
dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], 
['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 
'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, 
routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, 
moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 
'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), 
dtype=<DType.BFLOAT16: 'bfloat16'>, embedding=Param( # 525,336,576 (1.1 GB) value=Array(shape=(128256, 4096), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('vocab', 'embed_vocab') ), mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), num_embeddings=128256, num_features=4096 ), vision_encoder=None ), use_no_op_mappings=False, config=None ) TunixMaxTextAdapter( # Param: 8,030,261,248 (16.1 GB), RngState: 588 (3.5 KB), Total: 8,030,261,836 (16.1 GB) base=Transformer( # Param: 8,030,261,248 (16.1 GB), RngState: 588 (3.5 KB), Total: 8,030,261,836 (16.1 GB) audio_encoder=None, config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', 
managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, 
managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, 
xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', 
grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, 
reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, 
use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 
'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 
'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 
'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], 
sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, 
use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, 
mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, 
async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), decoder=NNXDecoder( # Param: 7,504,924,672 (15.0 GB), RngState: 588 (3.5 KB), Total: 7,504,925,260 (15.0 GB) config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, 
metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', 
image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, 
profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, 
chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, 
use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', 
distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 
'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], 
['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 
'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], 
data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, 
wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, 
use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, 
checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), decoder_norm=RMSNorm( # Param: 4,096 (8.2 KB) dtype=<DType.BFLOAT16: 'bfloat16'>, epsilon=1e-05, kernel_axes=('norm',), num_features=4096, parameter_memory_host_offload=False, scale=Param( # 4,096 (8.2 KB) value=Array(shape=(4096,), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('norm',) ), scale_init=<function ones at 0x7dba2a2d6980>, scale_offset=0.0, shard_mode=<ShardMode.AUTO: 'auto'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, with_scale=True ),
I0424 07:37:54.664475 138242845357888 train_rl.py:580] Policy mesh shape: OrderedDict({'diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1})
I0424 07:37:54.664538 138242845357888 train_rl.py:581] Rollout_mesh shape: OrderedDict({'diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 32, 'fsdp_transpose': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1})
I0424 07:37:54.664583 138242845357888 _schedule.py:129] A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
dropout=Dropout( # RngState: 6 (36 B) broadcast_dims=(-2,), deterministic=False, rate=0.0, rng_collection='dropout', rngs=Rngs( # RngState: 6 (36 B) aqt=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(0, dtype=uint32), eager_sharding=False, tag='aqt' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [2799984767 1105366846], eager_sharding=False, tag='aqt' ), tag='aqt' ),
dropout=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(0, dtype=uint32), eager_sharding=False, tag='dropout' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [346279018 360566543], eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(0, dtype=uint32), eager_sharding=False, tag='params' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [2839387376 2467677468], eager_sharding=False, tag='params' ), tag='params' ) ) ), is_deepseek=False, is_gemma3=False, layers=LlamaDecoderLayer( # RngState: 576 (3.5 KB), Param: 6,979,584,000 (14.0 GB), Total: 6,979,584,576 (14.0 GB) activation_axis_names=('activation_batch', 'activation_norm_length', 'activation_embed'), config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], 
tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', 
report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, 
jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, 
grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, 
penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, 
adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, 
parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 
'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 
'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 
'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, 
use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, 
decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, 
checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), dropout=Dropout( # RngState: 192 (1.2 KB) broadcast_dims=(-2,), deterministic=False, rate=0.0, rng_collection='dropout', rngs=Rngs( # RngState: 192 (1.2 KB) aqt=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='aqt' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='dropout' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='params' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='params' ), tag='params' ) ) ), mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), 
axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), mlp=MlpBlock( # RngState: 192 (1.2 KB), Param: 5,637,144,576 (11.3 GB), Total: 5,637,144,768 (11.3 GB) activations=['silu', 'linear'], config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, 
activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, 
metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', 
inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', 
train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, 
eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, 
lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, 
set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], 
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 
'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', 
context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, 
sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, 
fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', 
run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), dropout=Dropout( # RngState: 192 (1.2 KB) broadcast_dims=(-2,), deterministic=False, rate=0.0, rng_collection='dropout', rngs=Rngs( # RngState: 192 (1.2 KB) aqt=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='aqt' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='dropout' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='params' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='params' ), tag='params' ) ) ), dtype=<DType.BFLOAT16: 'bfloat16'>, in_features=4096, intermediate_dim=14336, intermediate_dropout_rate=0.0, intermediate_logical=('activation_batch', 'activation_length', 'activation_mlp'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 
'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), mlp_layer_norm=None, model_mode='train', quant=None, use_bias=False, use_pre_norm=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, wi_0=DenseGeneral( # Param: 1,879,048,192 (3.8 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 1,879,048,192 (3.8 GB) value=Array(shape=(4096, 32, 14336), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'mlp') ), kernel_axes=('embed', 'mlp'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(14336,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), wi_1=DenseGeneral( # Param: 1,879,048,192 (3.8 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 1,879,048,192 (3.8 GB) value=Array(shape=(4096, 32, 14336), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'mlp') ), kernel_axes=('embed', 'mlp'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(14336,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), wo=DenseGeneral( # Param: 1,879,048,192 (3.8 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(14336,), kernel=Param( # 1,879,048,192 (3.8 GB) value=Array(shape=(14336, 32, 4096), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('mlp', 'layers', 'embed') ), kernel_axes=('mlp', 'embed'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d5e40>, 
matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(4096,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ) ), post_self_attention_layer_norm=RMSNorm( # Param: 131,072 (262.1 KB) dtype=<DType.BFLOAT16: 'bfloat16'>, epsilon=1e-05, kernel_axes=('norm',), num_features=4096, parameter_memory_host_offload=False, scale=Param( # 131,072 (262.1 KB) value=Array(shape=(4096, 32), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('norm', 'layers') ), scale_init=<function ones at 0x7dba2a2d6980>, scale_offset=0.0, shard_mode=<ShardMode.AUTO: 'auto'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, with_scale=True ), pre_self_attention_layer_norm=RMSNorm( # Param: 131,072 (262.1 KB) dtype=<DType.BFLOAT16: 'bfloat16'>, epsilon=1e-05, kernel_axes=('norm',), num_features=4096, parameter_memory_host_offload=False, scale=Param( # 131,072 (262.1 KB) value=Array(shape=(4096, 32), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('norm', 'layers') ), scale_init=<function ones at 0x7dba2a2d6980>, scale_offset=0.0, shard_mode=<ShardMode.AUTO: 'auto'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, with_scale=True ), quant=None, self_attention=Attention( # RngState: 192 (1.2 KB), Param: 1,342,177,280 (2.7 GB), Total: 1,342,177,472 (2.7 GB) KVCache_0=None, ar_cache_axis_order=(1, 2, 0, 3), attention_kernel='dot_product', attention_op=AttentionOp( # RngState: 192 (1.2 KB) AqtEinsum_0=<function einsum at 0x7dba2a842480>, AqtEinsum_1=<function einsum at 0x7dba2a842480>, AqtEinsum_2=<function einsum at 0x7dba2a842480>, AqtEinsum_3=<function einsum at 0x7dba2a842480>, attention_kernel='dot_product', attention_type=<AttentionType.GLOBAL: 'global'>, attn_logits_soft_cap=None, cache_logical_axis_names=('cache_batch', 'cache_sequence', 'cache_heads', 'cache_kv'), cache_scale_logical_axis_names=('cache_scale_batch', 'cache_scale_sequence', 'cache_scale_heads', 
'cache_scale_kv'), chunk_attn_window_size=0, compute_axis_order=(0, 1, 2, 3), config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, 
max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, 
dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', 
stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', 
max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, 
max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, 
diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, 
ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 
'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], 
['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, 
gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, 
use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, 
use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), 
rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), dropout_rate=0.0, dtype=<DType.BFLOAT16: 'bfloat16'>, flash_axis_names_kv=('activation_batch_attn', 'activation_heads', 'activation_kv_length', 'activation_kv'), flash_axis_names_q=('activation_batch_attn', 'activation_heads', 'activation_length', 'activation_kv'), flash_axis_names_splash_kernel=('activation_heads', 'activation_length'), float32_logits=False, float32_qk_product=False, key_axis_order=(2, 0, 1, 3), kv_quant=None, max_prefill_predict_length=256, max_target_length=1024, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), num_kv_heads=8, num_query_heads=32, prefill_cache_logical_axis_names=('cache_batch_prefill', 'cache_sequence', 'cache_heads', 'cache_kv'), quant=None, ragged_block_size=256, ragged_lengths_names=('cache_batch',), ragged_qkv_axis_names=('cache_batch', 'cache_heads', 'cache_sequence', 'cache_kv'), reshape_q=False, rngs=Rngs( # RngState: 192 (1.2 KB) aqt=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='aqt' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='aqt' ), tag='aqt' ), dropout=RngStream( # RngState: 64 (384 B) count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='dropout' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 64 (384 B) 
count=RngCount( # 32 (128 B) value=Array(shape=(32,), dtype=dtype('uint32')), eager_sharding=False, tag='params' ), key=RngKey( # 32 (256 B) value=Array(shape=(32,), dtype=key<fry>), eager_sharding=False, tag='params' ), tag='params' ) ), sliding_window_size=None, use_ragged_attention=False ), attention_type=<AttentionType.GLOBAL: 'global'>, attn_logits_soft_cap=None, compute_axis_order=(0, 1, 2, 3), config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, 
activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, 
metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', 
inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', 
train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, 
eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, 
lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, 
set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], 
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 
'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', 
context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, 
sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, 
fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', 
run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), decode_input_axis_names=('decode_batch', 'decode_length', 'activation_embed_attn'), decode_out_axis_names=('decode_batch', 'decode_length', 'activation_heads', 'activation_kv'), dropout_rate=0.0, dtype=<DType.BFLOAT16: 'bfloat16'>, float32_logits=False, float32_qk_product=False, head_dim=128, input_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_embed_attn'), is_nope_layer=False, is_qwen2=False, is_qwen3_next=False, is_vision=False, kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, key=DenseGeneral( # Param: 134,217,728 (268.4 MB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 134,217,728 (268.4 MB) value=Array(shape=(4096, 32, 8, 128), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'kv_heads', 'kv_head_dim') ), kernel_axes=('embed', 'kv_heads', 'kv_head_dim'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(8, 128), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), key_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), key_norm=None, kv_quant=None, max_prefill_predict_length=256, max_target_length=1024, 
mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), model_mode='train', mrope_section=None, num_kv_heads=8, num_query_heads=32, out=DenseGeneral( # Param: 536,870,912 (1.1 GB) axis=(-2, -1), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(32, 128), kernel=Param( # 536,870,912 (1.1 GB) value=Array(shape=(32, 32, 128, 4096), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('heads', 'layers', 'kv', 'embed') ), kernel_axes=('heads', 'kv', 'embed'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(4096,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), out_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_heads', 'activation_kv'), partial_rotary_factor=None, prefill_cache_axis_order=(1, 2, 0, 3), prefill_input_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_embed_attn'), prefill_key_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_out_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_heads', 'activation_kv'), prefill_query_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_value_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), quant=None, query=DenseGeneral( # Param: 536,870,912 (1.1 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, 
in_features_shape=(4096,), kernel=Param( # 536,870,912 (1.1 GB) value=Array(shape=(4096, 32, 32, 128), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'q_heads', 'kv') ), kernel_axes=('embed', 'q_heads', 'kv'), kernel_init=<function Attention.init_query_w.<locals>.query_init at 0x7da1d0016020>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(32, 128), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), query_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), query_norm=None, query_pre_attn_scalar=None, ragged_block_size=256, reshape_q=False, rngs=Rngs(...), rope_max_timescale=500000, rope_type='default', rotary_embedding=LLaMARotaryEmbedding( cast_as_fprop_dtype=True, embedding_dims=128, fprop_dtype=<DType.BFLOAT16: 'bfloat16'>, max_timescale=500000, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), min_timescale=1, rope_linear_scaling_factor=1.0, shard_mode=<ShardMode.AUTO: 'auto'>, use_scale=True ), share_kv_projections=False, sinks=None, sliding_window_size=None, temperature_tuning=False, temperature_tuning_floor_scale=8192.0, temperature_tuning_scale=0.1, use_bias_in_projections=False, use_mrope=False, use_qk_norm=False, use_ragged_attention=False, use_v_norm=False, value=DenseGeneral( # Param: 134,217,728 (268.4 MB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 134,217,728 (268.4 MB) value=Array(shape=(4096, 32, 8, 128), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed', 'layers', 'kv_heads', 'kv_head_dim') 
), kernel_axes=('embed', 'kv_heads', 'kv_head_dim'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f6039940>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(8, 128), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), value_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), value_norm=None, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ) ), logits_dense=DenseGeneral( # Param: 525,336,576 (1.1 GB) axis=(-1,), bias=None, dtype=<DType.BFLOAT16: 'bfloat16'>, in_features_shape=(4096,), kernel=Param( # 525,336,576 (1.1 GB) value=Array(shape=(4096, 128256), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('embed_vocab', 'vocab') ), kernel_axes=('embed_vocab', 'vocab'), kernel_init=<function nd_dense_init.<locals>.init_fn at 0x7db8f65d58a0>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, out_features_shape=(128256,), parameter_memory_host_offload=False, quant=None, shard_mode=<ShardMode.AUTO: 'auto'>, use_bias=False, weight_dtype=<DType.BFLOAT16: 'bfloat16'> ), mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), model_mode='train', positional_embedding=PositionalEmbedding( cast_as_fprop_dtype=False, embedding_dims=4096, fprop_dtype=bfloat16, max_wavelength=10000, rngs=None ), quant=None, rngs=Rngs( # RngState: 6 (36 B) aqt=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(2, dtype=uint32), eager_sharding=False, tag='aqt' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [4146024105 2718843009], eager_sharding=False, tag='aqt' ), tag='aqt' ), 
dropout=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(2, dtype=uint32), eager_sharding=False, tag='dropout' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [ 928981903 3453687069], eager_sharding=False, tag='dropout' ), tag='dropout' ), params=RngStream( # RngState: 2 (12 B) count=RngCount( # 1 (4 B) value=Array(5, dtype=uint32), eager_sharding=False, tag='params' ), key=RngKey( # 1 (8 B) value=Array((), dtype=key<fry>) overlaying: [1797259609 2579123966], eager_sharding=False, tag='params' ), tag='params' ) ), scanned_layers=None ), hidden_states=None, mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), model_mode='train', quant=None, token_embedder=Embed( # Param: 525,336,576 (1.1 GB) attend_dtype=<DType.BFLOAT16: 'bfloat16'>, cast_input_dtype=None, config=MaxTextConfig(emb_dim=4096, mlp_dim=14336, moe_mlp_dim=-1, num_decoder_layers=32, num_kv_heads=8, num_query_heads=32, num_diloco_replicas=1, ici_parallelism=[1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1], dcn_parallelism=[1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], using_pipeline_parallelism=False, context_parallel_size=1, num_target_devices=32, global_batch_size_to_train_on=384, global_batch_size_to_eval_on=384, global_batch_size_to_load=384, global_batch_size_to_load_eval=384, micro_batch_size_to_train_on=384, micro_batch_size_to_eval_on=384, checkpoint_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/checkpoints/', convert_checkpoint_if_possible=True, metrics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/metrics/', 
tensorboard_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/tensorboard/', managed_mldiagnostics_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/managed-mldiagnostics', rampup_end_step=0, tensors_on_device=[], tensors_to_offload=['decoder_layer_input', 'query_proj', 'key_proj', 'value_proj'], global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', 
video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file='', gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/xla_dump', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237/pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional/jaxpr_dump', profiler=<ProfilerType.XPLANE: 'xplane'>, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, 
xprof_tpu_power_trace_level=<XProfTPUPowerTraceMode.POWER_TRACE_NONE: 0>, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='~/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=True, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=<SamplingStrategy.GREEDY: 'greedy'>, decode_sampling_nucleus_p=1.0, decode_sampling_top_k=50, decode_sampling_temperature=0.9, max_target_length=1024, max_prefill_predict_length=256, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=128256, tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_type=<TokenizerType.HUGGINGFACE: 'huggingface'>, use_chat_template=False, chat_template_path='maxtext/examples/chat_templates/gsm8k_rl.json', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, 
add_eos=True, use_truncation=True, num_vocab_tiling=1, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name='main', hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token='<redacted>', dataset_path='', dataset_name='openai/gsm8k', eval_dataset_name='openai/gsm8k', train_split='train', eval_split='test', dataset_type=<DatasetType.TFDS: 'tfds'>, per_device_batch_size=12.0, eval_per_device_batch_size=12.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=<RopeType.DEFAULT: 'default'>, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=500000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', 
reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=1.0, reward_exact_format_match=0.1, reward_white_space_format_match=1.0, reward_partial_format_match=0.0, reward_ratio_guess_to_answer_high=0.0, reward_ratio_guess_to_answer_low=0.0, penalty_incorrect_format=0.0, penalty_incorrect_answer=0.0, math_verify_timeout=120, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs={'greedy': {'eval_temperature': 0.01, 'eval_top_k': 1, 'eval_top_p': 1.0}, 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}}, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=2, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config={}, vllm_hf_overrides={}, vllm_hf_config_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=False, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=2, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides={}, teacher_overrides={}, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, 
distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs={}, use_grpo=True, muon_beta=0.95, muon_weight_decay=0.0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.99, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=[], mu_dtype=<DType.BFLOAT16: 'bfloat16'>, opt_type=<OptimizerType.ADAMW: 'adamw'>, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=0.1, learning_rate=3e-06, lr_schedule_type=<LearningRateScheduleType.COSINE: 'cosine'>, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=<WsdDecayStyle.LINEAR: 'linear'>, warmup_steps_fraction=0.1, learning_rate_schedule_steps=150001, trainable_parameters_mask=[], enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=20, eval_interval=10, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=False, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=42, init_weights_seed=0, remat_policy='custom', remat_policy_for_vit='minimal', decoder_layer_input=<RematLocation.OFFLOAD: 'offload'>, context=<RematLocation.REMAT: 'remat'>, mlpwi=<RematLocation.REMAT: 'remat'>, mlpwi_0=<RematLocation.REMAT: 'remat'>, mlpwi_1=<RematLocation.REMAT: 'remat'>, mlpwo=<RematLocation.REMAT: 'remat'>, moe_mlpwi_0=<RematLocation.REMAT: 'remat'>, moe_mlpwi_1=<RematLocation.REMAT: 'remat'>, moe_mlpwo=<RematLocation.REMAT: 'remat'>, query_proj=<RematLocation.OFFLOAD: 'offload'>, key_proj=<RematLocation.OFFLOAD: 'offload'>, value_proj=<RematLocation.OFFLOAD: 'offload'>, 
query_wa_proj=<RematLocation.REMAT: 'remat'>, kv_wa_proj=<RematLocation.REMAT: 'remat'>, qkv_proj=<RematLocation.REMAT: 'remat'>, out_proj=<RematLocation.REMAT: 'remat'>, mla_q=<RematLocation.REMAT: 'remat'>, mla_kv=<RematLocation.REMAT: 'remat'>, attention_out=<RematLocation.REMAT: 'remat'>, engram=<RematLocation.REMAT: 'remat'>, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=32, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=False, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['vocab', 
['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_length_attn', ['context']], ['activation_q_length', ['context']], ['activation_kv_length', []], ['activation_embed_attn', ['tensor', 'tensor_transpose']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['qkv', []], ['kv', []], ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'context', 'expert']], ['q_lora_up_proj', []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['kv_lora', ['fsdp', 'context', 'expert']], ['kv_lora_up_proj', []], ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_exp', ['expert']], ['exp', 'expert'], ['mlp_moe', ['fsdp_transpose', 'tensor', 
'tensor_sequence', 'autoregressive']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'tensor_transpose', 'context']], ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']], ['embed_moe', ['fsdp', 'context']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'context']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_stage', 'stage'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], ['embed', ['fsdp', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['diloco', 'diloco'], ['engram_dim', ['tensor']], ['dense_layers', []], ['moe_layers', []], ['mhc', []], ['prefill_activation_length', ['context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], ['cache_kv', []], ['cache_sequence', []], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp']], data_sharding=[['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 
'tensor_sequence', 'expert', 'autoregressive']], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=1, mesh_axes=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode=<ShardMode.AUTO: 'auto'>, inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=<ReorderStrategy.AUTO: 'auto'>, custom_mesh='', custom_mesh_and_rule='', allow_split_physical_axes=False, enable_nnx=True, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=True, pure_nnx=True, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, 
num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='dot_product', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, 
normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=[], engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=[], engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block=<DecoderBlockType.LLAMA2: 'llama2'>, global_parameter_scale=1, base_emb_dim=4096, base_num_query_heads=32, base_num_kv_heads=8, base_mlp_dim=14336, dense_init_scale=1.0, base_num_decoder_layers=32, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=<QuantizationType.NONE: ''>, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=<KvQuantAxis.HEADS_AND_DKV: 'heads_and_dkv'>, kv_quant_dtype='int8', quantization_local_shard_count=4, use_qwix_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', dtype=<DType.BFLOAT16: 'bfloat16'>, grad_dtype=<DType.FLOAT32: 'float32'>, weight_dtype=<DType.BFLOAT16: 'bfloat16'>, matmul_precision=<MatmulPrecision.DEFAULT: 'default'>, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=False, checkpoint_storage_use_zarr3=False, checkpoint_storage_concurrent_gb=96, 
load_parameters_path='gs://lance-maxtext/rl_ckpt_llama31_8b/rl_ckpt_llama31_8b/0/items', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=False, checkpoint_period=50, max_num_checkpoints_to_keep=10, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config='../base.yml', run_name='pt_rl_nnx_xpk_main_20260424_070237_06_rl_grpo_functional', model_name='llama3.1-8b-Instruct', override_model_config=False, override_logical_axis_rules=False, log_config=False, debug_sharding=False, base_output_directory='gs://lance-maxtext/pt_ckpt_xpk_main_20260424_070237', sharding_strategy=None, debug=Debug(rl=True), rl=RL(num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None)), dtype=<DType.BFLOAT16: 'bfloat16'>, embedding=Param( # 525,336,576 (1.1 GB) value=Array(shape=(128256, 4096), dtype=dtype(bfloat16)), eager_sharding=False, out_sharding=('vocab', 'embed_vocab') ), mesh=Mesh(axis_sizes=(1, 1, 1, 32, 1, 1, 1, 1, 1, 1, 1, 1), axis_names=('diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), num_embeddings=128256, num_features=4096 ), vision_encoder=None ), use_no_op_mappings=False, config=None W0424 07:37:54.802840 138242845357888 pyconfig.py:111] 
base_output_directory is not provided; Using local directory called maxtext_output I0424 07:37:54.841005 138242845357888 max_utils.py:238] Skipping jax distributed system due to skip_jax_distributed_system=True flag. I0424 07:37:54.842449 138242845357888 train_rl.py:426] Creating RL cluster... ) ERROR 04-24 07:37:55 [tpu_info.py:40] Unable to poll TPU GCE Metadata. Got status code: 404 and content: <!DOCTYPE html> ERROR 04-24 07:37:55 [tpu_info.py:40] <html lang=en> ERROR 04-24 07:37:55 [tpu_info.py:40] <meta charset=utf-8> ERROR 04-24 07:37:55 [tpu_info.py:40] <meta name=viewport content="initial-scale=1, minimum-scale=1, width=device-width"> ERROR 04-24 07:37:55 [tpu_info.py:40] <title>Error 404 (Not Found)!!1</title> ERROR
04-24 07:37:55 [tpu_info.py:40] <a href=//www.google.com/><span id=logo aria-label=Google></span></a>
ERROR 04-24 07:37:55 [tpu_info.py:40] <p><b>404.</b> <ins>That’s an error.</ins>
ERROR 04-24 07:37:55 [tpu_info.py:40] <p>The requested URL <code>/computeMetadata/v1/instance/attributes/instance-id</code> was not found on this server. <ins>That’s all we know.</ins>
ERROR 04-24 07:37:55 [tpu_info.py:40]
INFO 04-24 07:37:55 [__init__.py:59] TPU info: node_name=None | tpu_type=v6e-32 | worker_id=0 | num_chips=4 | num_cores_per_chip=1
/usr/local/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1777016278.504277 1987 descriptor_database.cc:633] File already exists in database: google/protobuf/timestamp.proto
F0000 00:00:1777016278.504355 1987 descriptor.cc:2236] Check failed: GeneratedDatabase()->Add(encoded_file_descriptor, size)
*** Check failure stack trace: ***
    @ 0x7db8f52a2fe4 absl::lts_20250127::log_internal::LogMessage::SendToLog()
    @ 0x7db8f52a2976 absl::lts_20250127::log_internal::LogMessage::Flush()
    @ 0x7db8f52a3539 absl::lts_20250127::log_internal::LogMessageFatal::~LogMessageFatal()
    @ 0x7db8f51955cb google::protobuf::DescriptorPool::InternalAddGeneratedFile()
    @ 0x7db8f520f308 google::protobuf::internal::AddDescriptors()
    @ 0x7db8f520f2fa google::protobuf::internal::AddDescriptors()
    @ 0x7dbaa7895b9f __static_initialization_and_destruction_0()
    @ 0x7dbaa7895bd2 _GLOBAL__sub_I.00102_tpu_metric_service.pb.cc
    @ 0x7dbb2c4d3fe2 (unknown)
Fatal Python error: Aborted

Current thread 0x00007dbb2bb0b740 (most recent call first):
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 1293 in create_module
  File "<frozen importlib._bootstrap>", line 813 in module_from_spec
  File "<frozen importlib._bootstrap>", line 921 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1415 in _handle_fromlist
  File "/usr/local/lib/python3.12/site-packages/tpu_info/cli_helper.py", line 56 in _check_library_safety
  File "/usr/local/lib/python3.12/multiprocessing/process.py", line 108 in run
  File "/usr/local/lib/python3.12/multiprocessing/process.py", line 314 in _bootstrap
  File "/usr/local/lib/python3.12/multiprocessing/popen_fork.py", line 71 in _launch
  File "/usr/local/lib/python3.12/multiprocessing/popen_fork.py", line 19 in __init__
  File "/usr/local/lib/python3.12/multiprocessing/context.py", line 282 in _Popen
  File "/usr/local/lib/python3.12/multiprocessing/context.py", line 224 in _Popen
  File "/usr/local/lib/python3.12/multiprocessing/process.py", line 121 in start
  File "/usr/local/lib/python3.12/site-packages/tpu_info/cli_helper.py", line 96 in _initialize_libtpu_safely
  File "/usr/local/lib/python3.12/site-packages/tpu_info/cli_helper.py", line 132 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1415 in _handle_fromlist
  File "/usr/local/lib/python3.12/site-packages/tpu_info/cli.py", line 27 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1415 in _handle_fromlist
  File "/usr/local/lib/python3.12/site-packages/tpu_info/__init__.py", line 16 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/tpu_inference/platforms/tpu_platform.py", line 142 in get_device_name
  File "/usr/local/lib/python3.12/site-packages/tpu_inference/platforms/tpu_platform.py", line 151 in fp8_dtype
  File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/utils/quant_utils.py", line 20 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/attention/backend.py", line 13 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/vllm/forward_context.py", line 17 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/vllm/compilation/cuda_graph.py", line 19 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/metrics/stats.py", line 10 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/vllm/outputs.py", line 16 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/tunix/generate/vllm_async_driver.py", line 35 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "/usr/local/lib/python3.12/site-packages/tunix/generate/vllm_sampler.py", line 32 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1415 in _handle_fromlist
  File "/usr/local/lib/python3.12/site-packages/tunix/rl/rollout/vllm_rollout.py", line 23 in <module>
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 999 in exec_module
  File "<frozen importlib._bootstrap>", line 935 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
  File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1415 in _handle_fromlist
  File "/usr/local/lib/python3.12/site-packages/tunix/rl/rl_cluster.py", line 392 in _init_cluster
  ...
Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, pyarrow.lib, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pyarrow._compute, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, pandas._libs.testing, pyarrow._acero, pyarrow._fs, pyarrow._csv, pyarrow._json, pyarrow._substrait, pyarrow._dataset, pyarrow._dataset_orc, pyarrow._parquet, pyarrow._parquet_encryption, pyarrow._dataset_parquet_encryption, pyarrow._dataset_parquet, zstandard.backend_c, yaml._yaml, pyarrow._azurefs, pyarrow._hdfs, pyarrow._gcsfs, pyarrow._s3fs, charset_normalizer.md, simplejson._speedups, requests.packages.charset_normalizer.md, requests.packages.chardet.md, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, xxhash._xxhash, 
jaxlib.cpu_feature_guard, google._upb._message, msgpack._cmsgpack, grpc._cython.cygrpc, _cffi_backend, regex._regex, markupsafe._speedups, PIL._imaging, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, psutil._psutil_linux, sentencepiece._sentencepiece, h5py._errors, h5py.defs, h5py._objects, h5py.h5, h5py.utils, h5py.h5t, h5py.h5s, h5py.h5ac, h5py.h5p, h5py.h5r, h5py._npystrings, h5py._proxy, h5py._conv, h5py.h5z, h5py.h5a, h5py.h5d, h5py.h5ds, h5py.h5g, h5py.h5i, h5py.h5o, h5py.h5f, h5py.h5fd, h5py.h5pl, h5py.h5l, h5py._selector, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, _cyutility, scipy._cyutility, scipy.sparse._csparsetools, kiwisolver._cext, PIL._imagingft, scipy.io.matlab._mio_utils, scipy.io.matlab._streams, scipy.io.matlab._mio5_utils, msgspec._core, _cbor2, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_schur_sqrtm, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.optimize._group_columns, scipy._lib.messagestream, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._slsqplib, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy._lib._uarray._uarray, scipy.special._ufuncs_cxx, scipy.special._ellip_harm_2, scipy.special._special_ufuncs, scipy.special._gufuncs, scipy.special._ufuncs, scipy.special._specfun, 
scipy.special._comb, scipy.linalg._decomp_interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.spatial._ckdtree, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._hausdorff, scipy.spatial._distance_wrap, scipy.spatial.transform._rotation, scipy.spatial.transform._rigid_transform, scipy.optimize._direct, zmq.backend.cython._zmq, pybase64._pybase64, scipy.signal._sigtools, scipy.signal._max_len_seq_inner, scipy.signal._upfirdn_apply, scipy.signal._spline, scipy.interpolate._fitpack, scipy.interpolate._dfitpack, scipy.interpolate._dierckx, scipy.interpolate._ppoly, scipy.interpolate._interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.ndimage._nd_image, scipy.ndimage._rank_filter_1d, _ni_label, scipy.ndimage._ni_label, scipy.signal._sosfilt, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats._biasedurn, scipy.stats._stats_pythran, scipy.stats._levy_stable.levyst, scipy.stats._ansari_swilk_statistics, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._rcont.rcont, scipy.stats._qmvnt_cy, scipy.signal._peak_finding_utils, uvloop.loop (total: 230)
Check failed with unknown exit code: -6.
INFO 04-24 07:38:33 [tpu_platform.py:152] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
I0424 07:38:34.187158 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'model' with value 'meta-llama/Llama-3.1-8B-Instruct'.
I0424 07:38:34.187288 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'max_model_len' with value '1280'.
I0424 07:38:34.187316 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'async_scheduling' with value 'False'.
I0424 07:38:34.187337 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'max_num_batched_tokens' with value 'None'.
I0424 07:38:34.187357 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'max_num_seqs' with value 'None'.
I0424 07:38:34.187375 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'hf_config_path' with value ''.
I0424 07:38:34.187391 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'max_logprobs' with value '1'.
I0424 07:38:34.187408 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'hf_overrides' with value '{}'.
I0424 07:38:34.187424 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'enable_expert_parallel' with value 'False'.
I0424 07:38:34.187441 138242845357888 vllm_sampler.py:102] Engine kwargs setting key 'enable_prefix_caching' with value 'True'.
INFO 04-24 07:38:34 [attention_interface.py:53] Using default RPA kernel
INFO 04-24 07:38:34 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 04-24 07:38:34 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 04-24 07:38:34 [interface.py:240] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 04-24 07:38:34 [tpu_platform.py:152] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 04-24 07:38:34 [tpu_platform.py:152] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 04-24 07:38:34 [tpu_platform.py:152] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 04-24 07:38:34 [tpu_platform.py:152] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 04-24 07:38:35 [nixl_utils.py:20] Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare memory leak in UCX when using NIXL.
WARNING 04-24 07:38:35 [nixl_utils.py:34] NIXL is not available
WARNING 04-24 07:38:35 [nixl_utils.py:44] NIXL agent config is not available
INFO 04-24 07:38:35 [__init__.py:110] Registered model loader `<class 'tpu_inference.models.jax.utils.weight_utils.JaxDummyModelLoader'>` with load format `jax_dummy`
INFO 04-24 07:38:35 [__init__.py:110] Registered model loader `<class 'tpu_inference.models.common.pathways_dummy_loader.PathwaysDummyModelLoader'>` with load format `pathways_dummy`
WARNING 04-24 07:38:35 [__init__.py:85] The quantization method 'awq' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.awq.VllmAWQConfig'>.
WARNING 04-24 07:38:36 [__init__.py:85] The quantization method 'compressed-tensors' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors.VllmCompressedTensorsConfig'>.
WARNING 04-24 07:38:36 [__init__.py:85] The quantization method 'fp8' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.fp8.VllmFp8Config'>.
WARNING 04-24 07:38:36 [__init__.py:85] The quantization method 'gpt_oss_mxfp4' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.mxfp4.VllmMxfp4Config'>.
INFO 04-24 07:38:36 [__init__.py:110] Registered model loader `<class 'tpu_inference.models.vllm.vllm_model_loader.IncrementalModelLoader'>` with load format `tpu_streaming_loader`
WARNING 04-24 07:38:36 [__init__.py:99] Load format `runai_streamer` is already registered, and will be overwritten by the new loader class `<class 'tpu_inference.models.vllm.vllm_model_loader.RunaiIncrementalModelLoader'>`.
INFO 04-24 07:38:36 [__init__.py:110] Registered model loader `<class 'tpu_inference.models.vllm.vllm_model_loader.RunaiIncrementalModelLoader'>` with load format `runai_streamer`
WARNING 04-24 07:38:36 [interface.py:240] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
WARNING 04-24 07:38:36 [interface.py:240] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
WARNING 04-24 07:38:36 [interface.py:240] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
W0424 07:38:36.665107 138242845357888 ops_registry.py:52] Duplicate op registration for aten.__and__
WARNING 04-24 07:38:36 [tpu_platform.py:317] Pin memory is not supported on TPU.
INFO 04-24 07:38:36 [__init__.py:31] Registering MaxTextForCausalLM model with tpu_inference and vllm.
INFO 04-24 07:38:36 [model_loader.py:682] Registered JAX model MaxTextForCausalLM with tpu_inference and vLLM registries.
INFO 04-24 07:38:36 [__init__.py:33] Successfully registered MaxTextForCausalLM model.
INFO 04-24 07:38:36 [utils.py:233] non-default args: {'hf_config_path': '', 'load_format': 'dummy', 'max_model_len': 1280, 'tensor_parallel_size': 16, 'data_parallel_size': 2, 'enable_prefix_caching': True, 'gpu_memory_utilization': 0.72, 'max_logprobs': 1, 'disable_log_stats': True, 'additional_config': {'sharding': {'sharding_strategy': {'expert_parallelism': 1, 'device_indexes': [0, 4, 8, 12, 16, 20, 24, 28, 1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31], 'enable_dp_attention': False}}}, 'async_scheduling': False, 'model': 'meta-llama/Llama-3.1-8B-Instruct'}
WARNING 04-24 07:38:36 [arg_utils.py:1440] The global random seed is set to 0. Since VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may affect the random state of the Python process that launched vLLM.
INFO 04-24 07:38:55 [model.py:554] Resolved architecture: LlamaForCausalLM
INFO 04-24 07:38:55 [model.py:1685] Using max model len 1280
INFO 04-24 07:38:56 [scheduler.py:239] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 04-24 07:38:56 [vllm.py:845] Asynchronous scheduling is disabled.
INFO 04-24 07:38:56 [kernel.py:199] Final IR op priority after setting platform defaults: IrOpPriorityConfig(rms_norm=['native'])
INFO 04-24 07:38:56 [tpu_platform.py:190] Initialized sharding configuration: ShardingConfigManager(total_devices=32, sharding_strategy=ShardingStrategy(tensor_parallelism=16, expert_parallelism=1, sequence_parallelism=1, data_parallelism=2, attention_data_parallelism=1, attention_data_expert_parallelism=1), device_indexes=[0, 4, 8, 12, 16, 20, 24, 28, 1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31])
INFO 04-24 07:38:56 [tpu_platform.py:245] Using KV cache block size: 128
INFO 04-24 07:38:56 [tpu_platform.py:256] Force using UniProcExecutor for JAX on single host without pipeline parallelism.
INFO 04-24 07:38:56 [compilation.py:303] Enabled custom fusions: norm_quant, act_quant
INFO 04-24 07:38:58 [core.py:107] Initializing a V1 LLM engine (v0.19.2rc1.dev43+g595562651) with config: model='meta-llama/Llama-3.1-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1280, download_dir=None, load_format=dummy, tensor_parallel_size=16, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=True, quantization=None, quantization_config=None, enforce_eager=False, enable_return_routed_experts=False, kv_cache_dtype=auto, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=meta-llama/Llama-3.1-8B-Instruct, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.DYNAMO_TRACE_ONCE: 2>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'ir_enable_torch_wrap': False, 'splitting_ops': [], 'compile_mm_encoder': False, 'cudagraph_mm_encoder': False, 'encoder_cudagraph_token_budgets': [], 'encoder_cudagraph_max_vision_items_per_batch': 0, 'encoder_cudagraph_max_frames_per_batch': 0, 'compile_sizes': None, 'compile_ranges_endpoints': [8192], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'size_asserts': False, 'alignment_asserts': False, 'scalar_asserts': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': True, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False}, 'max_cudagraph_capture_size': None, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}, kernel_config=KernelConfig(ir_op_priority=IrOpPriorityConfig(rms_norm=['native']), enable_flashinfer_autotune=True, moe_backend='auto')
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/maxtext/trainers/post_train/rl/train_rl.py", line 661, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/rl/train_rl.py", line 657, in main
    rl_train(argv, kwargs)
  File "/deps/src/maxtext/trainers/post_train/rl/train_rl.py", line 583, in rl_train
    rl_cluster, rl_trainer, _ = create_rl_components(
                                ^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/maxtext/trainers/post_train/rl/train_rl.py", line 441, in create_rl_components
    rl_cluster = rl_cluster_lib.RLCluster(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/rl/rl_cluster.py", line 250, in __init__
    self._init_cluster()
  File "/usr/local/lib/python3.12/site-packages/tunix/rl/rl_cluster.py", line 415, in _init_cluster
    self._rollout = vllm_rollout.VllmRollout(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/rl/rollout/vllm_rollout.py", line 43, in __init__
    self._sampler = vllm_sampler.VllmSampler(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tunix/generate/vllm_sampler.py", line 156, in __init__
    self.llm = LLM(**self.args)
               ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 381, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 171, in from_engine_args
    return cls(
           ^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 105, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 103, in make_client
    return InprocClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 285, in __init__
    self.engine_core = EngineCore(*args, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 116, in __init__
    self.model_executor = executor_class(vllm_config)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 109, in __init__
    self._init_executor()
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/executor/uniproc_executor.py", line 47, in _init_executor
    self.driver_worker.init_device()
  File "/usr/local/lib/python3.12/site-packages/vllm/v1/worker/worker_base.py", line 317, in init_device
    self.worker.init_device() # type: ignore
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/tpu_inference/worker/tpu_worker.py", line 241, in init_device
    device = device_dict[device_index]
             ~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 0
XPK End: Fri Apr 24 07:39:11 UTC 2026
EXIT_CODE=1