From 9bbfbb360b9a1d3256498e7c26762d8617f2bb50 Mon Sep 17 00:00:00 2001 From: wang2yn84 Date: Thu, 30 Apr 2026 02:39:39 +0000 Subject: [PATCH] Fix non agentic colocate mode in cli. --- .../rl/grpo/gsm8k/run_qwen3_simplereward.sh | 2 +- ...run_qwen3.sh => run_qwen3_vanilla_lora.sh} | 11 ++- .../rl/grpo/gsm8k/run_qwen3_vllm_colocate.sh | 80 +++++++++++++++++++ tunix/cli/base_agentic_config.yaml | 8 +- tunix/cli/base_config.yaml | 10 ++- tunix/cli/grpo_main.py | 33 +++++--- 6 files changed, 116 insertions(+), 28 deletions(-) rename examples/rl/grpo/gsm8k/{run_qwen3.sh => run_qwen3_vanilla_lora.sh} (90%) create mode 100644 examples/rl/grpo/gsm8k/run_qwen3_vllm_colocate.sh diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh index 4ddca5486..be92ce667 100644 --- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh @@ -50,7 +50,7 @@ python3 -m tunix.cli.grpo_main \ model_config.rng_seed=42 \ actor_model_config.lora_config.rank=64 \ actor_model_config.lora_config.alpha=64.0 \ - actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ + actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ rollout_model_config.mesh.shape="(2,4)" \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3_vanilla_lora.sh similarity index 90% rename from examples/rl/grpo/gsm8k/run_qwen3.sh rename to examples/rl/grpo/gsm8k/run_qwen3_vanilla_lora.sh index 7f0afafa5..b8f52fb69 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_vanilla_lora.sh @@ -21,12 +21,14 @@ batch_size=${batch_size:-8} num_train_epochs=${num_train_epochs:-1} warmup_ratio=${warmup_ratio:-0.1} train_fraction=${train_fraction:-0.8} +actor_mesh_shape=${actor_mesh_shape:-"(2,4)"} echo "Using parameters:" echo " Batch Size: $batch_size" echo " Num Epochs: $num_train_epochs" echo " Warmup Ratio: $warmup_ratio" echo " Train Fraction: $train_fraction" +echo " Actor Mesh Shape: $actor_mesh_shape" python3 -m tunix.cli.grpo_main \ base_config.yaml \ @@ -36,16 +38,13 @@ python3 -m tunix.cli.grpo_main \ model_config.use_flash_attention=true \ model_config.flash_attention_block_size=256 \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \ - model_config.mesh.shape="(2,4)" \ - model_config.mesh.axis_names="('fsdp','tp')" \ model_config.rng_seed=42 \ actor_model_config.lora_config.rank=64 \ actor_model_config.lora_config.alpha=64.0 \ - actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ - actor_model_config.mesh.shape="(2,4)" \ + actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ + actor_model_config.mesh.shape="$actor_mesh_shape" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path=Qwen/${model_name} \ tokenizer_config.tokenizer_type=huggingface \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3_vllm_colocate.sh b/examples/rl/grpo/gsm8k/run_qwen3_vllm_colocate.sh new file mode 100644 index 000000000..cbfaa54d5 --- /dev/null +++ b/examples/rl/grpo/gsm8k/run_qwen3_vllm_colocate.sh @@ -0,0 +1,80 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -x # Enable xtrace + +# specify at cmd line to override defaults, e.g. +model_name=${model_name:-"Qwen3-1.7B-base"} +batch_size=${batch_size:-8} +num_train_epochs=${num_train_epochs:-1} +warmup_ratio=${warmup_ratio:-0.1} +train_fraction=${train_fraction:-0.8} +actor_mesh_shape=${actor_mesh_shape:-"(2,4)"} + +echo "Using parameters:" +echo " Batch Size: $batch_size" +echo " Num Epochs: $num_train_epochs" +echo " Warmup Ratio: $warmup_ratio" +echo " Train Fraction: $train_fraction" +echo " Actor Mesh Shape: $actor_mesh_shape" + +python3 -m tunix.cli.grpo_main \ + base_config.yaml \ + model_config.model_name=${model_name} \ + model_config.model_id=Qwen/${model_name} \ + model_config.model_source=huggingface \ + model_config.use_flash_attention=true \ + model_config.flash_attention_block_size=256 \ + model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \ + model_config.rng_seed=42 \ + actor_model_config.mesh.shape=${actor_mesh_shape} \ + actor_model_config.mesh.axis_names="('fsdp','tp')" \ + rollout_model_config.same_mesh_as="actor" \ + tokenizer_config.tokenizer_path=Qwen/${model_name} \ + tokenizer_config.tokenizer_type=huggingface \ + tokenizer_config.add_bos=false \ + dataset_name="gsm8k" \ + batch_size=$batch_size \ + num_test_batches=100 \ + num_train_epochs=$num_train_epochs \ + rl_training_config.actor_optimizer_config.opt_type="adamw" \ + rl_training_config.actor_optimizer_config.peak_value=3e-6 \ + rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ + rl_training_config.actor_optimizer_config.init_value=0.0 \ + rl_training_config.actor_optimizer_config.end_value=0.0 \ + rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \ + rl_training_config.actor_optimizer_config.b1=0.9 \ + rl_training_config.actor_optimizer_config.b2=0.99 \ + rl_training_config.actor_optimizer_config.weight_decay=0.1 \ + rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ + rl_training_config.eval_every_n_steps=10 \ + rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \ + rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ + rl_training_config.checkpointing_options.save_interval_steps=500 \ + rl_training_config.checkpointing_options.max_to_keep=4 \ + rl_training_config.profiler_options={} \ + rollout_config.total_generation_steps=768 \ + rollout_config.max_prompt_length=256 \ + rollout_config.temperature=0.9 \ + rollout_config.top_p=1.0 \ + rollout_config.top_k=50 \ + rollout_engine="vllm" \ + vllm_config.async_scheduling=false \ + offload_to_cpu=false \ + grpo_config.num_generations=4 \ + grpo_config.num_iterations=1 \ + grpo_config.beta=0.08 \ + grpo_config.epsilon=0.2 \ + reward_functions="['tunix/cli/reward_fn/gsm8k.py']" diff --git a/tunix/cli/base_agentic_config.yaml b/tunix/cli/base_agentic_config.yaml index dca76e0cd..b95072f43 100644 --- a/tunix/cli/base_agentic_config.yaml +++ b/tunix/cli/base_agentic_config.yaml @@ -62,10 +62,10 @@ model_config: &base_model_config lora_config: {} ################################## MESH ################################## - mesh: - shape: "(2,2)" - # "('fsdp',)" - axis_names: "('fsdp','tp')" + # Base config should not set mesh, as it can be different for each role. + # mesh: + # shape: "(2,2)" + # axis_names: "('fsdp','tp')" actor_model_config: <<: *base_model_config diff --git a/tunix/cli/base_config.yaml b/tunix/cli/base_config.yaml index 728c43a95..a0272c902 100644 --- a/tunix/cli/base_config.yaml +++ b/tunix/cli/base_config.yaml @@ -62,10 +62,10 @@ model_config: &base_model_config lora_config: {} ################################## MESH ################################## - mesh: - shape: "(2,2)" - # "('fsdp',)" - axis_names: "('fsdp','tp')" + # Base config should not set mesh, as it can be different for each role. + # mesh: + # shape: "(2,2)" + # axis_names: "('fsdp','tp')" actor_model_config: <<: *base_model_config @@ -213,6 +213,8 @@ rollout_config: ############################### Other RL Config ############################### rollout_engine: "vanilla" +vllm_config: {} +sglang_jax_config: {} offload_to_cpu: false diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 5de335bf7..114cc6eff 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -279,6 +279,8 @@ def create_rollout_config( max_prompt = rollout_cfg.get("max_prompt_length", 0) max_response = rollout_cfg.get("total_generation_steps", 0) + kv_cache_size = 0 + max_concurrency = 0 if mode == "agentic_grpo": agentic_cfg = self.config.get("agentic_grpo_config", {}) max_turns = agentic_cfg.get("max_turns", 1) @@ -290,26 +292,30 @@ def create_rollout_config( filtered["kv_cache_size"] = kv_cache_size logging.info("kv_cache_size: %d", kv_cache_size) - # Engine-specific extras - extra = self._agentic_engine_extra( - engine, - kv_cache_size, - agentic_cfg, - role_to_mesh=role_to_mesh, - ) - filtered.update({k: v for k, v in extra.items() if k in valid_fields}) + max_running_requests = agentic_cfg.get("max_concurrency", 16) else: # Standard: kv_cache_size = max_prompt + max_response + 256 if max_prompt and max_response: - filtered["kv_cache_size"] = max_prompt + max_response + 256 - + kv_cache_size = max_prompt + max_response + 256 + filtered["kv_cache_size"] = kv_cache_size + # Defaults to global batch size * num_generations to allow full concurrency + max_running_requests = self.config.get("batch_size", 1) * self.config.get("grpo_config", {}).get("num_generations", 1) + + # Engine-specific extras + extra = self._rollout_engine_extra( + engine, + kv_cache_size, + max_running_requests, + role_to_mesh=role_to_mesh, + ) + filtered.update({k: v for k, v in extra.items() if k in valid_fields}) return base_rollout.RolloutConfig(**filtered) - def _agentic_engine_extra( + def _rollout_engine_extra( self, engine: str, kv_cache_size: int, - agentic_cfg: dict, + max_running_requests: int, role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None, ) -> dict: """Return engine-specific RolloutConfig fields for agentic mode.""" @@ -336,7 +342,7 @@ def _agentic_engine_extra( ), rollout_sglang_jax_max_running_requests=sg.get( "max_running_requests", - agentic_cfg.get("max_concurrency", 768), + max_running_requests, ), rollout_sglang_jax_page_size=sg.get("page_size", 128), rollout_sglang_jax_use_sort_for_toppk_minp=sg.get( @@ -529,6 +535,7 @@ def compute_params(self, dataset): num_batches = self.config.get("num_batches") if not num_batches: num_batches = dataset_length // batch_size + self.config["num_batches"] = num_batches logging.info( "Dynamically computed num_batches=%d with batch_size=%d", num_batches,