From f33e5c3066b5780a65df373a4c464c8a2bd0f647 Mon Sep 17 00:00:00 2001 From: Shadi Noghabi Date: Tue, 5 May 2026 12:34:56 -0700 Subject: [PATCH] fix mesh in previously collocated cli scripts PiperOrigin-RevId: 910848514 --- examples/rl/grpo/gsm8k/run_gemma3_12b.sh | 6 ++++-- examples/rl/grpo/gsm8k/run_gemma3_1b.sh | 6 ++++-- examples/rl/grpo/gsm8k/run_gemma3_4b.sh | 6 ++++-- examples/rl/grpo/gsm8k/run_gemma_7b.sh | 6 ++++-- examples/rl/grpo/gsm8k/run_llama3.1_8b.sh | 6 ++++-- examples/rl/grpo/gsm8k/run_llama3.2_1b.sh | 6 ++++-- examples/rl/grpo/gsm8k/run_qwen3.sh | 6 ++++-- examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh | 6 ++++-- examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh | 6 ++++-- 9 files changed, 36 insertions(+), 18 deletions(-) diff --git a/examples/rl/grpo/gsm8k/run_gemma3_12b.sh b/examples/rl/grpo/gsm8k/run_gemma3_12b.sh index 2b3f5cb67..4db932abf 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_12b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_12b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma3_1b.sh b/examples/rl/grpo/gsm8k/run_gemma3_1b.sh index 7066c2323..f5c9cee5c 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_1b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_1b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma3_4b.sh b/examples/rl/grpo/gsm8k/run_gemma3_4b.sh index 4648329b1..a94ab6701 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_4b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_4b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma_7b.sh b/examples/rl/grpo/gsm8k/run_gemma_7b.sh index 75f9f3f7a..8f6072ca0 100755 --- a/examples/rl/grpo/gsm8k/run_gemma_7b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma_7b.sh @@ -55,8 +55,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="/tmp/models/gemma-7b/models/google/gemma/flax/7b-it/2/tokenizer.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh b/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh index 8d8bbed6d..33f728bb3 100755 --- a/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh +++ b/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh @@ -52,8 +52,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.1-8B" \ tokenizer_config.tokenizer_type="huggingface" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh b/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh index 17b883e03..070e29c09 100755 --- a/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh +++ b/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh @@ -52,8 +52,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.2-1B" \ tokenizer_config.tokenizer_type="huggingface" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3.sh index e644dafa0..76985952b 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3.sh @@ -44,8 +44,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path=Qwen/${model_name} \ tokenizer_config.tokenizer_type=huggingface \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh index 714ab5f6b..8269e9080 100644 --- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path=Qwen/${model_name} \ tokenizer_config.tokenizer_type=huggingface \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh b/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh index b48072a98..0b49fa790 100755 --- a/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh +++ b/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh @@ -50,8 +50,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.mesh.shape="(4,1)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ actor_model_config.lora_config={} \ - rollout_model_config.mesh.shape="(4,1)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.2-1B-Instruct" \ tokenizer_config.tokenizer_type="huggingface" \ tokenizer_config.add_bos=false \