From 71ca0d65131e5a6f64a729e63426f090dfb53cd8 Mon Sep 17 00:00:00 2001 From: Shadi Noghabi Date: Wed, 29 Apr 2026 10:47:10 -0700 Subject: [PATCH] yaml based configs for llama cli scripts PiperOrigin-RevId: 907669931 --- .../deepscaler/run_deepscaler_disagg_v5p16.sh | 7 +- examples/deepswe/run_deepswe_disagg_v5p_32.sh | 7 +- .../rl/grpo/gsm8k/configs/llama3.1_8b.yaml | 81 +++++++++++++++++++ .../rl/grpo/gsm8k/configs/llama3.2_1b.yaml | 81 +++++++++++++++++++ examples/rl/grpo/gsm8k/run_gemma3_12b.sh | 6 +- examples/rl/grpo/gsm8k/run_gemma3_1b.sh | 6 +- examples/rl/grpo/gsm8k/run_gemma3_4b.sh | 6 +- examples/rl/grpo/gsm8k/run_gemma_7b.sh | 6 +- examples/rl/grpo/gsm8k/run_llama3.1_8b.sh | 48 +---------- examples/rl/grpo/gsm8k/run_llama3.2_1b.sh | 48 +---------- examples/rl/grpo/gsm8k/run_qwen3.sh | 6 +- ...run_qwen3_8b.sh => run_qwen3_8b_disagg.sh} | 6 +- ...text.sh => run_qwen3_8b_disagg_maxtext.sh} | 6 +- .../rl/grpo/gsm8k/run_qwen3_simplereward.sh | 6 +- .../gsm8k/verl_compatible/run_llama3.2_1b.sh | 6 +- 15 files changed, 214 insertions(+), 112 deletions(-) create mode 100644 examples/rl/grpo/gsm8k/configs/llama3.1_8b.yaml create mode 100644 examples/rl/grpo/gsm8k/configs/llama3.2_1b.yaml rename examples/rl/grpo/gsm8k/{run_qwen3_8b.sh => run_qwen3_8b_disagg.sh} (97%) rename examples/rl/grpo/gsm8k/{run_qwen3_8b_maxtext.sh => run_qwen3_8b_disagg_maxtext.sh} (97%) diff --git a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh index 7feb5b3d9..e518d6ca3 100755 --- a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh +++ b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh @@ -32,8 +32,11 @@ warmup_ratio="${warmup_ratio:-0.1}" batch_size="${batch_size:-128}" mini_batch_size="${mini_batch_size:-128}" max_response_length="${max_response_length:-8192}" -trainer_mesh="${trainer_mesh:-(4,1)}" -rollout_mesh="${rollout_mesh:-(4,1)}" +total_tpus="${total_tpus:-8}" + +axis_size=$((total_tpus / 2)) 
+trainer_mesh="${trainer_mesh:-($axis_size,1)}" +rollout_mesh="${rollout_mesh:-($axis_size,1)}" checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" diff --git a/examples/deepswe/run_deepswe_disagg_v5p_32.sh b/examples/deepswe/run_deepswe_disagg_v5p_32.sh index 4eceb7ae2..ca6c3da68 100755 --- a/examples/deepswe/run_deepswe_disagg_v5p_32.sh +++ b/examples/deepswe/run_deepswe_disagg_v5p_32.sh @@ -42,9 +42,12 @@ rollout_micro_batch_size="${rollout_micro_batch_size:-1}" num_generations="${num_generations:-2}" max_response_length="${max_response_length:-8192}" +total_tpus="${total_tpus:-32}" -trainer_mesh="${trainer_mesh:-(8,2)}" -rollout_mesh="${rollout_mesh:-(2,8)}" +# mesh shapes are (2, X) or (X, 2) for each of trainer or rollout. +axis_size=$((total_tpus / 4)) +trainer_mesh="${trainer_mesh:-($axis_size,2)}" +rollout_mesh="${rollout_mesh:-(2,$axis_size)}" checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" diff --git a/examples/rl/grpo/gsm8k/configs/llama3.1_8b.yaml b/examples/rl/grpo/gsm8k/configs/llama3.1_8b.yaml new file mode 100644 index 000000000..c64744a97 --- /dev/null +++ b/examples/rl/grpo/gsm8k/configs/llama3.1_8b.yaml @@ -0,0 +1,81 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+model_config:  # GRPO/GSM8K overrides for Llama-3.1-8B; run_llama3.1_8b.sh passes this file as override_config_file
+  model_name: "llama-3.1-8b"
+  model_id: "meta-llama/Llama-3.1-8B"
+  model_source: "huggingface"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+  rng_seed: 42
+actor_model_config:
+  lora_config:
+    rank: 64
+    alpha: 64.0
+    module_path: ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+reference_model_config:
+  mesh: null
+  same_mesh_as: "actor"  # presumably reuses the actor mesh instead of building one - confirm in tunix.cli
+rollout_model_config:
+  mesh: null
+  same_mesh_as: "actor"
+tokenizer_config:
+  tokenizer_type: "huggingface"
+  add_bos: false  # canonical lowercase boolean, same spelling as the CLI overrides in the sibling scripts
+dataset_name: "gsm8k"
+batch_size: 8  # run_llama3.1_8b.sh also sets batch_size, num_batches and num_train_epochs on its command line
+num_batches: 3738
+num_test_batches: 100
+num_train_epochs: 1
+rl_training_config:
+  actor_optimizer_config:
+    opt_type: "adamw"
+    peak_value: 3.0e-6  # mantissa dot is deliberate: YAML 1.1 loaders can read a bare 3e-6 as a string
+    schedule_type: "warmup_cosine_decay_schedule"
+    init_value: 0.0
+    end_value: 0.0
+    warmup_ratio: 0.1  # warmup_ratio, warmup_steps and decay_steps are also passed by the run script
+    warmup_steps: 374
+    decay_steps: 3738
+    b1: 0.9
+    b2: 0.99
+    weight_decay: 0.1
+    max_grad_norm: 0.1
+  eval_every_n_steps: 10
+  max_steps: 3738  # also passed by the run script as $max_steps
+  metrics_logging_options:
+    flush_every_n_steps: 20
+  checkpointing_options:
+    save_interval_steps: 500
+    max_to_keep: 4
+  profiler_options: {}
+rollout_config:
+  total_generation_steps: 768
+  max_prompt_length: 256
+  temperature: 0.9
+  top_p: 1.0
+  top_k: 50
+rollout_engine: "vanilla"
+offload_to_cpu: false
+grpo_config:
+  num_generations: 4
+  num_iterations: 1
+  beta: 0.08
+  epsilon: 0.2
+reward_functions:
+  - "tunix/cli/reward_fn/gsm8k.py"
diff --git a/examples/rl/grpo/gsm8k/configs/llama3.2_1b.yaml b/examples/rl/grpo/gsm8k/configs/llama3.2_1b.yaml
new file mode 100644
index 000000000..ed4beee43
--- /dev/null
+++ b/examples/rl/grpo/gsm8k/configs/llama3.2_1b.yaml
@@ -0,0 +1,81 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+model_config:  # GRPO/GSM8K overrides for Llama-3.2-1B; run_llama3.2_1b.sh passes this file as override_config_file
+  model_name: "llama3.2-1b"
+  model_id: "meta-llama/Llama-3.2-1B"
+  model_source: "huggingface"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+  rng_seed: 42
+actor_model_config:
+  lora_config:
+    rank: 64
+    alpha: 64.0
+    module_path: ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+reference_model_config:
+  mesh: null
+  same_mesh_as: "actor"  # presumably reuses the actor mesh instead of building one - confirm in tunix.cli
+rollout_model_config:
+  mesh: null
+  same_mesh_as: "actor"
+tokenizer_config:
+  tokenizer_type: "huggingface"
+  add_bos: false  # canonical lowercase boolean, same spelling as the CLI overrides in the sibling scripts
+dataset_name: "gsm8k"
+batch_size: 1  # run_llama3.2_1b.sh also sets batch_size, num_batches and num_train_epochs on its command line
+num_batches: 3738
+num_test_batches: 100
+num_train_epochs: 1
+rl_training_config:
+  actor_optimizer_config:
+    opt_type: "adamw"
+    peak_value: 3.0e-6  # mantissa dot is deliberate: YAML 1.1 loaders can read a bare 3e-6 as a string
+    schedule_type: "warmup_cosine_decay_schedule"
+    init_value: 0.0
+    end_value: 0.0
+    warmup_ratio: 0.1  # warmup_ratio, warmup_steps and decay_steps are also passed by the run script
+    warmup_steps: 374
+    decay_steps: 3738
+    b1: 0.9
+    b2: 0.99
+    weight_decay: 0.1
+    max_grad_norm: 0.1
+  eval_every_n_steps: 10
+  max_steps: 3738  # also passed by the run script as $max_steps
+  metrics_logging_options:
+    flush_every_n_steps: 20
+  checkpointing_options:
+    save_interval_steps: 500
+    max_to_keep: 4
+  profiler_options: {}
+rollout_config:
+  total_generation_steps: 768
+  max_prompt_length: 256
+  temperature: 0.9
+  top_p: 1.0
+  top_k: 50
+rollout_engine: "vanilla"
+offload_to_cpu: false
+grpo_config:
+  num_generations: 4
+  num_iterations: 1
+  beta: 0.08
+  epsilon: 0.2
+reward_functions:
+  - "tunix/cli/reward_fn/gsm8k.py"
diff --git a/examples/rl/grpo/gsm8k/run_gemma3_12b.sh b/examples/rl/grpo/gsm8k/run_gemma3_12b.sh
index 2b3f5cb67..4db932abf
100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_12b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_12b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma3_1b.sh b/examples/rl/grpo/gsm8k/run_gemma3_1b.sh index 7066c2323..f5c9cee5c 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_1b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_1b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma3_4b.sh b/examples/rl/grpo/gsm8k/run_gemma3_4b.sh index 4648329b1..a94ab6701 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_4b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_4b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main 
\ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma_7b.sh b/examples/rl/grpo/gsm8k/run_gemma_7b.sh index 75f9f3f7a..8f6072ca0 100755 --- a/examples/rl/grpo/gsm8k/run_gemma_7b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma_7b.sh @@ -55,8 +55,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="/tmp/models/gemma-7b/models/google/gemma/flax/7b-it/2/tokenizer.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh b/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh index 8d8bbed6d..851c49b13 100755 --- a/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh +++ b/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh @@ -39,58 +39,16 @@ echo "Max steps: $max_steps" echo "Rounded warmup steps: $warmup_steps" python3 -m tunix.cli.grpo_main \ - base_config.yaml \ - 
model_config.model_name="llama-3.1-8b" \ - model_config.model_id="meta-llama/Llama-3.1-8B" \ - model_config.model_source="huggingface" \ + tunix/cli/base_config.yaml \ + override_config_file=examples/rl/grpo/gsm8k/configs/llama3.1_8b.yaml \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/1" \ - model_config.mesh.shape="(2,4)" \ - model_config.mesh.axis_names="('fsdp','tp')" \ - model_config.rng_seed=42 \ - actor_model_config.lora_config.rank=64 \ - actor_model_config.lora_config.alpha=64.0 \ - actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ - actor_model_config.mesh.shape="(2,4)" \ - actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.1-8B" \ - tokenizer_config.tokenizer_type="huggingface" \ - tokenizer_config.add_bos=false \ - dataset_name="gsm8k" \ batch_size=$batch_size \ num_batches=$num_batches \ - num_test_batches=100 \ num_train_epochs=$num_train_epochs \ train_fraction=$train_fraction \ - rl_training_config.actor_optimizer_config.opt_type="adamw" \ - rl_training_config.actor_optimizer_config.peak_value=3e-6 \ - rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ - rl_training_config.actor_optimizer_config.init_value=0.0 \ - rl_training_config.actor_optimizer_config.end_value=0.0 \ rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \ rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \ rl_training_config.actor_optimizer_config.decay_steps=$max_steps \ - rl_training_config.actor_optimizer_config.b1=0.9 \ - rl_training_config.actor_optimizer_config.b2=0.99 \ - rl_training_config.actor_optimizer_config.weight_decay=0.1 \ - rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ - rl_training_config.eval_every_n_steps=10 \ 
rl_training_config.max_steps=$max_steps \ - rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo" \ - rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ - rl_training_config.checkpointing_options.save_interval_steps=500 \ - rl_training_config.checkpointing_options.max_to_keep=4 \ - rl_training_config.profiler_options={} \ - rollout_config.total_generation_steps=768 \ - rollout_config.max_prompt_length=256 \ - rollout_config.temperature=0.9 \ - rollout_config.top_p=1.0 \ - rollout_config.top_k=50 \ - rollout_engine="vanilla" \ - offload_to_cpu=false \ - grpo_config.num_generations=4 \ - grpo_config.num_iterations=1 \ - grpo_config.beta=0.08 \ - grpo_config.epsilon=0.2 \ - reward_functions="['tunix/cli/reward_fn/gsm8k.py']" + rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo" diff --git a/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh b/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh index 17b883e03..182f71c34 100755 --- a/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh +++ b/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh @@ -39,58 +39,16 @@ echo "Max steps: $max_steps" echo "Rounded warmup steps: $warmup_steps" python3 -m tunix.cli.grpo_main \ - base_config.yaml \ - model_config.model_name="llama3.2-1b" \ - model_config.model_id="meta-llama/Llama-3.2-1B" \ - model_config.model_source="huggingface" \ + tunix/cli/base_config.yaml \ + override_config_file=examples/rl/grpo/gsm8k/configs/llama3.2_1b.yaml \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/1" \ - model_config.mesh.shape="(2,4)" \ - model_config.mesh.axis_names="('fsdp','tp')" \ - model_config.rng_seed=42 \ - actor_model_config.lora_config.rank=64 \ - actor_model_config.lora_config.alpha=64.0 \ - actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ - actor_model_config.mesh.shape="(2,4)" \ - actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - 
rollout_model_config.mesh.axis_names="('fsdp','tp')" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.2-1B" \ - tokenizer_config.tokenizer_type="huggingface" \ - tokenizer_config.add_bos=false \ - dataset_name="gsm8k" \ batch_size=$batch_size \ num_batches=$num_batches \ - num_test_batches=100 \ num_train_epochs=$num_train_epochs \ train_fraction=$train_fraction \ - rl_training_config.actor_optimizer_config.opt_type="adamw" \ - rl_training_config.actor_optimizer_config.peak_value=3e-6 \ - rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ - rl_training_config.actor_optimizer_config.init_value=0.0 \ - rl_training_config.actor_optimizer_config.end_value=0.0 \ rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \ rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \ rl_training_config.actor_optimizer_config.decay_steps=$max_steps \ - rl_training_config.actor_optimizer_config.b1=0.9 \ - rl_training_config.actor_optimizer_config.b2=0.99 \ - rl_training_config.actor_optimizer_config.weight_decay=0.1 \ - rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ - rl_training_config.eval_every_n_steps=10 \ rl_training_config.max_steps=$max_steps \ - rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo" \ - rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ - rl_training_config.checkpointing_options.save_interval_steps=500 \ - rl_training_config.checkpointing_options.max_to_keep=4 \ - rl_training_config.profiler_options={} \ - rollout_config.total_generation_steps=768 \ - rollout_config.max_prompt_length=256 \ - rollout_config.temperature=0.9 \ - rollout_config.top_p=1.0 \ - rollout_config.top_k=50 \ - rollout_engine="vanilla" \ - offload_to_cpu=false \ - grpo_config.num_generations=4 \ - grpo_config.num_iterations=1 \ - grpo_config.beta=0.08 \ - grpo_config.epsilon=0.2 \ - reward_functions="['tunix/cli/reward_fn/gsm8k.py']" + 
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo" diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3.sh index e644dafa0..76985952b 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3.sh @@ -44,8 +44,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path=Qwen/${model_name} \ tokenizer_config.tokenizer_type=huggingface \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh similarity index 97% rename from examples/rl/grpo/gsm8k/run_qwen3_8b.sh rename to examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh index dfc5e7c11..867c5028c 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh @@ -41,9 +41,11 @@ rollout_micro_batch_size="${rollout_micro_batch_size:-8}" compute_logps_micro_batch_size="${compute_logps_micro_batch_size:-1}" num_generations="${num_generations:-4}" +total_tpus="${total_tpus:-8}" -train_mesh="${train_mesh:-(8,1)}" -rollout_mesh="${rollout_mesh:-(1,8)}" +axis_size=$((total_tpus / 2)) +train_mesh="${train_mesh:-($axis_size,1)}" +rollout_mesh="${rollout_mesh:-(1,$axis_size)}" checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b_maxtext.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh similarity 
index 97% rename from examples/rl/grpo/gsm8k/run_qwen3_8b_maxtext.sh rename to examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh index 3fc3f1679..c456aeb3c 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3_8b_maxtext.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh @@ -39,9 +39,11 @@ rollout_micro_batch_size="${rollout_micro_batch_size:-8}" compute_logps_micro_batch_size="${compute_logps_micro_batch_size:-1}" num_generations="${num_generations:-4}" +total_tpus="${total_tpus:-8}" -train_mesh="${train_mesh:-(8,1)}" -rollout_mesh="${rollout_mesh:-(1,8)}" +axis_size=$((total_tpus / 2)) +train_mesh="${train_mesh:-($axis_size,1)}" +rollout_mesh="${rollout_mesh:-(1,$axis_size)}" checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh index 714ab5f6b..8269e9080 100644 --- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path=Qwen/${model_name} \ tokenizer_config.tokenizer_type=huggingface \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh b/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh index b48072a98..0b49fa790 100755 --- a/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh 
+++ b/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh @@ -50,8 +50,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.mesh.shape="(4,1)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ actor_model_config.lora_config={} \ - rollout_model_config.mesh.shape="(4,1)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.2-1B-Instruct" \ tokenizer_config.tokenizer_type="huggingface" \ tokenizer_config.add_bos=false \