Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ python3 -m tunix.cli.grpo_main \
model_config.rng_seed=42 \
actor_model_config.lora_config.rank=64 \
actor_model_config.lora_config.alpha=64.0 \
actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
actor_model_config.mesh.shape="(2,4)" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
rollout_model_config.mesh.shape="(2,4)" \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ batch_size=${batch_size:-8}
num_train_epochs=${num_train_epochs:-1}
warmup_ratio=${warmup_ratio:-0.1}
train_fraction=${train_fraction:-0.8}
actor_mesh_shape=${actor_mesh_shape:-"(2,4)"}

echo "Using parameters:"
echo " Batch Size: $batch_size"
echo " Num Epochs: $num_train_epochs"
echo " Warmup Ratio: $warmup_ratio"
echo " Train Fraction: $train_fraction"
echo " Actor Mesh Shape: $actor_mesh_shape"

python3 -m tunix.cli.grpo_main \
base_config.yaml \
Expand All @@ -36,16 +38,13 @@ python3 -m tunix.cli.grpo_main \
model_config.use_flash_attention=true \
model_config.flash_attention_block_size=256 \
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
model_config.mesh.shape="(2,4)" \
model_config.mesh.axis_names="('fsdp','tp')" \
model_config.rng_seed=42 \
actor_model_config.lora_config.rank=64 \
actor_model_config.lora_config.alpha=64.0 \
actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
actor_model_config.mesh.shape="(2,4)" \
actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
actor_model_config.mesh.shape="$actor_mesh_shape" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
rollout_model_config.mesh.shape="(2,4)" \
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
rollout_model_config.same_mesh_as="actor" \
tokenizer_config.tokenizer_path=Qwen/${model_name} \
tokenizer_config.tokenizer_type=huggingface \
tokenizer_config.add_bos=false \
Expand Down
80 changes: 80 additions & 0 deletions examples/rl/grpo/gsm8k/run_qwen3_vllm_colocate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


set -x # Enable xtrace

# specify at cmd line to override defaults, e.g.
model_name=${model_name:-"Qwen3-1.7B-base"}
batch_size=${batch_size:-8}
num_train_epochs=${num_train_epochs:-1}
warmup_ratio=${warmup_ratio:-0.1}
train_fraction=${train_fraction:-0.8}
actor_mesh_shape=${actor_mesh_shape:-"(2,4)"}

echo "Using parameters:"
echo " Batch Size: $batch_size"
echo " Num Epochs: $num_train_epochs"
echo " Warmup Ratio: $warmup_ratio"
echo " Train Fraction: $train_fraction"
echo " Actor Mesh Shape: $actor_mesh_shape"

python3 -m tunix.cli.grpo_main \
base_config.yaml \
model_config.model_name=${model_name} \
model_config.model_id=Qwen/${model_name} \
model_config.model_source=huggingface \
model_config.use_flash_attention=true \
model_config.flash_attention_block_size=256 \
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
model_config.rng_seed=42 \
actor_model_config.mesh.shape=${actor_mesh_shape} \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
rollout_model_config.same_mesh_as="actor" \
tokenizer_config.tokenizer_path=Qwen/${model_name} \
tokenizer_config.tokenizer_type=huggingface \
tokenizer_config.add_bos=false \
dataset_name="gsm8k" \
batch_size=$batch_size \
num_test_batches=100 \
num_train_epochs=$num_train_epochs \
rl_training_config.actor_optimizer_config.opt_type="adamw" \
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
rl_training_config.actor_optimizer_config.init_value=0.0 \
rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
rl_training_config.actor_optimizer_config.b1=0.9 \
rl_training_config.actor_optimizer_config.b2=0.99 \
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
rl_training_config.eval_every_n_steps=10 \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
rl_training_config.checkpointing_options.save_interval_steps=500 \
rl_training_config.checkpointing_options.max_to_keep=4 \
rl_training_config.profiler_options={} \
rollout_config.total_generation_steps=768 \
rollout_config.max_prompt_length=256 \
rollout_config.temperature=0.9 \
rollout_config.top_p=1.0 \
rollout_config.top_k=50 \
rollout_engine="vllm" \
vllm_config.async_scheduling=false \
offload_to_cpu=false \
grpo_config.num_generations=4 \
grpo_config.num_iterations=1 \
grpo_config.beta=0.08 \
grpo_config.epsilon=0.2 \
reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
8 changes: 4 additions & 4 deletions tunix/cli/base_agentic_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ model_config: &base_model_config
lora_config: {}

################################## MESH ##################################
mesh:
shape: "(2,2)"
# "('fsdp',)"
axis_names: "('fsdp','tp')"
# Base config should not set mesh, as it can be different for each role.
# mesh:
# shape: "(2,2)"
# axis_names: "('fsdp','tp')"

actor_model_config:
<<: *base_model_config
Expand Down
10 changes: 6 additions & 4 deletions tunix/cli/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ model_config: &base_model_config
lora_config: {}

################################## MESH ##################################
mesh:
shape: "(2,2)"
# "('fsdp',)"
axis_names: "('fsdp','tp')"
# Base config should not set mesh, as it can be different for each role.
# mesh:
# shape: "(2,2)"
# axis_names: "('fsdp','tp')"

actor_model_config:
<<: *base_model_config
Expand Down Expand Up @@ -213,6 +213,8 @@ rollout_config:
############################### Other RL Config ###############################

rollout_engine: "vanilla"
vllm_config: {}
sglang_jax_config: {}

offload_to_cpu: false

Expand Down
33 changes: 20 additions & 13 deletions tunix/cli/grpo_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ def create_rollout_config(
max_prompt = rollout_cfg.get("max_prompt_length", 0)
max_response = rollout_cfg.get("total_generation_steps", 0)

kv_cache_size = 0
max_concurrency = 0
if mode == "agentic_grpo":
agentic_cfg = self.config.get("agentic_grpo_config", {})
max_turns = agentic_cfg.get("max_turns", 1)
Expand All @@ -290,26 +292,30 @@ def create_rollout_config(
filtered["kv_cache_size"] = kv_cache_size
logging.info("kv_cache_size: %d", kv_cache_size)

# Engine-specific extras
extra = self._agentic_engine_extra(
engine,
kv_cache_size,
agentic_cfg,
role_to_mesh=role_to_mesh,
)
filtered.update({k: v for k, v in extra.items() if k in valid_fields})
max_running_requests = agentic_cfg.get("max_concurrency", 16)
else:
# Standard: kv_cache_size = max_prompt + max_response + 256
if max_prompt and max_response:
filtered["kv_cache_size"] = max_prompt + max_response + 256

kv_cache_size = max_prompt + max_response + 256
filtered["kv_cache_size"] = kv_cache_size
# Defaults to global batch size * num_generations to allow full concurrency
max_running_requests = self.config.get("batch_size", 1) * self.config.get("grpo_config", {}).get("num_generations", 1)

# Engine-specific extras
extra = self._rollout_engine_extra(
engine,
kv_cache_size,
max_running_requests,
role_to_mesh=role_to_mesh,
)
filtered.update({k: v for k, v in extra.items() if k in valid_fields})
return base_rollout.RolloutConfig(**filtered)

def _agentic_engine_extra(
def _rollout_engine_extra(
self,
engine: str,
kv_cache_size: int,
agentic_cfg: dict,
max_running_requests: int,
role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None,
) -> dict:
"""Return engine-specific RolloutConfig fields for agentic mode."""
Expand All @@ -336,7 +342,7 @@ def _agentic_engine_extra(
),
rollout_sglang_jax_max_running_requests=sg.get(
"max_running_requests",
agentic_cfg.get("max_concurrency", 768),
max_running_requests,
),
rollout_sglang_jax_page_size=sg.get("page_size", 128),
rollout_sglang_jax_use_sort_for_toppk_minp=sg.get(
Expand Down Expand Up @@ -529,6 +535,7 @@ def compute_params(self, dataset):
num_batches = self.config.get("num_batches")
if not num_batches:
num_batches = dataset_length // batch_size
self.config["num_batches"] = num_batches
logging.info(
"Dynamically computed num_batches=%d with batch_size=%d",
num_batches,
Expand Down
Loading