Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/cli/grpo_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import os
import pathlib
import tempfile
from typing import Any
from typing import cast
from unittest import mock

from absl.testing import absltest
Expand Down Expand Up @@ -578,6 +580,26 @@ def test_standard_grpo_kv_cache(self):
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)


class ComputeParamsTest(absltest.TestCase):
  """Tests for dynamic parameter derivation in compute_params."""

  def test_compute_params_persists_dynamic_num_batches(self):
    # Leave num_batches / max_steps at 0 so both must be derived from the
    # dataset size instead of the config.
    pipeline = _make_pipeline("")
    overrides = {
        "batch_size": 8,
        "num_batches": 0,
        "num_train_epochs": 1,
        "train_fraction": 0.8,
    }
    for name, value in overrides.items():
      pipeline.config[name] = value
    rl_training_config = cast(dict[str, Any], pipeline.config["rl_training_config"])
    rl_training_config["max_steps"] = 0

    raw_dataset = mock.Mock()
    raw_dataset.__len__ = mock.Mock(return_value=7473)

    pipeline.compute_params(raw_dataset)

    # 7473 // 8 == 934 batches; int(934 * 0.8) == 747 training steps.
    self.assertEqual(pipeline.config["num_batches"], 934)
    self.assertEqual(rl_training_config["max_steps"], 747)


# ---------------------------------------------------------------------------
# GRPOConfig construction
# ---------------------------------------------------------------------------
Expand Down
34 changes: 34 additions & 0 deletions tests/cli/utils/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,40 @@ def test_filters_by_prompt_length(self):
self.assertLen(batches, 1)
self.assertEqual(batches[0], [{"prompts": "short", "answer": 1}])

def test_raises_when_prompt_length_filter_removes_all_examples(self):
    """Every prompt exceeds max_prompt_length, so filtering empties the set."""
    examples = [
        {"prompts": "this is too long", "answer": 1},
        {"prompts": "also too long", "answer": 2},
    ]
    source = _BaseDataset(examples)
    fake_tokenizer = _FakeTokenizer()

    with self.assertRaisesRegex(
        ValueError, "empty after post_init_dataset filtering"
    ):
      data_lib.post_init_dataset(
          source,
          tokenizer=fake_tokenizer,  # pytype: disable=wrong-arg-types
          batch_size=2,
          num_batches=None,
          max_prompt_length=2,
      )

def test_raises_when_fraction_makes_training_split_empty(self):
    """A tiny dataset times a small fraction leaves zero training examples."""
    source = _BaseDataset([{"prompts": "short", "answer": 1}])
    fake_tokenizer = _FakeTokenizer()

    with self.assertRaisesRegex(ValueError, "empty after post_init_dataset split"):
      data_lib.post_init_dataset(
          source,
          tokenizer=fake_tokenizer,  # pytype: disable=wrong-arg-types
          batch_size=1,
          num_batches=None,
          max_prompt_length=None,
          fraction=0.5,
      )

def test_limits_num_batches(self):
tokenizer = _FakeTokenizer()
dataset = _BaseDataset(
Expand Down
138 changes: 101 additions & 37 deletions tunix/cli/grpo_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
python -m tunix.cli.grpo_main examples/deepswe/configs/qwen3_32b.yaml
"""

import collections
from collections.abc import Mapping
from collections.abc import MutableMapping
import dataclasses
import importlib
import os
Expand Down Expand Up @@ -89,6 +90,61 @@ def __init__(self, argv: list[str], **kwargs):
self.data_module: ModuleType | None = None
super().__init__(argv, **kwargs)

def _config_mapping(self, key: str) -> dict[str, Any]:
    """Returns the config section under ``key`` as a plain ``dict``.

    Top-level YAML values may be scalars or mappings, so pytype would
    otherwise infer broad unions for nested sections; narrowing the type
    once here keeps call sites simple.

    Args:
      key: Name of the top-level config section.

    Returns:
      A shallow copy of the section, or an empty dict when the section is
      absent.

    Raises:
      TypeError: If the value stored under ``key`` is not a mapping.
    """
    section = self.config.get(key)
    if section is None:
      return {}
    if isinstance(section, Mapping):
      return dict(section)
    raise TypeError(
        f"Expected config section {key!r} to be a mapping, got"
        f" {type(section).__name__}."
    )

def _mutable_config_mapping(self, key: str) -> MutableMapping[str, Any]:
    """Returns the config section under ``key`` for in-place updates.

    A missing section is materialized as an empty dict and stored back into
    ``self.config`` so later mutations are persisted.

    Args:
      key: Name of the top-level config section.

    Returns:
      The live (mutable) mapping stored in the config.

    Raises:
      TypeError: If the stored value is neither ``None`` nor a mutable
        mapping.
    """
    existing = self.config.get(key)
    if isinstance(existing, MutableMapping):
      return existing
    if existing is not None:
      raise TypeError(
          f"Expected config section {key!r} to be a mutable mapping, got"
          f" {type(existing).__name__}."
      )
    fresh: dict[str, Any] = {}
    self.config[key] = fresh
    return fresh

def _config_string(self, key: str, default: str = "") -> str:
    """Returns the string stored under ``key``, falling back to ``default``.

    A missing key or an explicit YAML null both yield ``default``; any
    other non-string value is rejected rather than coerced.

    Args:
      key: Name of the top-level config value.
      default: Value returned when the key is absent or ``None``.

    Raises:
      TypeError: If the stored value is neither ``None`` nor a string.
    """
    raw = self.config.get(key, default)
    if raw is None:
      return default
    if isinstance(raw, str):
      return raw
    raise TypeError(
        f"Expected config value {key!r} to be a string, got"
        f" {type(raw).__name__}."
    )

def _config_bool(self, key: str, default: bool = False) -> bool:
    """Returns the bool stored under ``key``, falling back to ``default``.

    A missing key or an explicit YAML null both yield ``default``; any
    other non-bool value (including truthy ints/strings) is rejected
    rather than coerced.

    Args:
      key: Name of the top-level config value.
      default: Value returned when the key is absent or ``None``.

    Raises:
      TypeError: If the stored value is neither ``None`` nor a bool.
    """
    raw = self.config.get(key, default)
    if raw is None:
      return default
    if isinstance(raw, bool):
      return raw
    raise TypeError(
        f"Expected config value {key!r} to be a bool, got"
        f" {type(raw).__name__}."
    )

# ------------------------------------------------------------------
# Mesh
# ------------------------------------------------------------------
Expand Down Expand Up @@ -263,9 +319,9 @@ def create_rollout_config(
max_prompt + total_generation_steps * context_ratio * max_turns.
Engine-specific extras (sglang_jax_config, vllm_config) are also applied.
"""
rollout_cfg = self.config["rollout_config"]
mode = self.config.get("training_mode", "grpo")
engine = self.config.get("rollout_engine", "vanilla")
rollout_cfg = self._config_mapping("rollout_config")
mode = self._config_string("training_mode", "grpo")
engine = self._config_string("rollout_engine", "vanilla")

valid_fields = {
f.name for f in dataclasses.fields(base_rollout.RolloutConfig)
Expand All @@ -280,7 +336,7 @@ def create_rollout_config(
max_response = rollout_cfg.get("total_generation_steps", 0)

if mode == "agentic_grpo":
agentic_cfg = self.config.get("agentic_grpo_config", {})
agentic_cfg = self._config_mapping("agentic_grpo_config")
max_turns = agentic_cfg.get("max_turns", 1)
context_ratio = agentic_cfg.get("context_ratio", 1)
if max_turns > 1:
Expand Down Expand Up @@ -313,10 +369,10 @@ def _agentic_engine_extra(
role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None,
) -> dict:
"""Return engine-specific RolloutConfig fields for agentic mode."""
model_id = self.config.get("actor_model_config", {}).get("model_id", "")
model_id = self._config_mapping("actor_model_config").get("model_id", "")

if engine == "sglang_jax":
sg = self.config.get("sglang_jax_config", {})
sg = self._config_mapping("sglang_jax_config")
return dict(
rollout_sglang_jax_model_version=sg.get("model_version", model_id),
rollout_sglang_jax_mem_fraction_static=sg.get(
Expand Down Expand Up @@ -345,17 +401,18 @@ def _agentic_engine_extra(
)

if engine == "vllm":
vllm = self.config.get("vllm_config", {})
vllm = self._config_mapping("vllm_config")
if role_to_mesh is None:
raise ValueError(
"role_to_mesh must be provided for vllm rollout config."
)
rollout_shape = role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape
max_num_seqs = self.config["rollout_config"].get(
rollout_cfg = self._config_mapping("rollout_config")
max_num_seqs = rollout_cfg.get(
"rollout_vllm_max_num_seqs",
vllm.get("max_num_seqs", 768),
)
max_batched_tokens = self.config["rollout_config"].get(
max_batched_tokens = rollout_cfg.get(
"rollout_vllm_max_num_batched_tokens",
vllm.get(
"max_num_batched_tokens",
Expand Down Expand Up @@ -401,8 +458,8 @@ def create_cluster_config(
rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh)
return rl_cluster_lib.ClusterConfig(
role_to_mesh=role_to_mesh,
rollout_engine=self.config["rollout_engine"],
offload_to_cpu=self.config["offload_to_cpu"],
rollout_engine=self._config_string("rollout_engine"),
offload_to_cpu=self._config_bool("offload_to_cpu"),
training_config=self.create_rl_training_config(),
rollout_config=rollout_config,
)
Expand All @@ -411,7 +468,7 @@ def create_rl_training_config(self):
base_key = "rl_training_config"
constructed_rl_training_config = self.obtain_training_config_dict(base_key)

base_config = self.config[base_key]
base_config = self._config_mapping(base_key)
if base_config.get("actor_optimizer_config"):
constructed_rl_training_config["actor_optimizer"] = self.create_optimizer(
base_key, "actor_optimizer_config"
Expand Down Expand Up @@ -470,23 +527,26 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
def create_rl_cluster(self, tokenizer):
role_to_mesh = self.create_role_to_mesh()
rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh)
reference_model_config = self._mutable_config_mapping("reference_model_config")
actor_model_config = self._mutable_config_mapping("actor_model_config")
tokenizer_config = self._config_mapping("tokenizer_config")
# Should not use LoRA for reference model.
if self.config["reference_model_config"].get("lora_config"):
if reference_model_config.get("lora_config"):
logging.warning(
"LoRA config is not supported for the reference model. Disabling"
" LoRA."
)
del self.config["reference_model_config"]["lora_config"]
del reference_model_config["lora_config"]
reference_model, tokenizer_path = model_lib.create_model(
self.config["reference_model_config"],
self.config["tokenizer_config"],
dict(reference_model_config),
tokenizer_config,
role_to_mesh[rl_cluster_lib.Role.REFERENCE],
)
if self.config["actor_model_config"].get("lora_config", None):
if actor_model_config.get("lora_config", None):
actor_model = model_lib.apply_lora_to_model(
reference_model,
role_to_mesh[rl_cluster_lib.Role.ACTOR],
self.config["actor_model_config"]["lora_config"],
actor_model_config["lora_config"],
)
else:
graph_def, params = nnx.split(reference_model)
Expand All @@ -509,9 +569,7 @@ def create_rl_cluster(self, tokenizer):
)

def compute_params(self, dataset):
rl_training_config: dict[str, Any] = self.config.get(
"rl_training_config", {}
)
rl_training_config = self._mutable_config_mapping("rl_training_config")

# Return early if max_steps is already specified.
max_steps = None
Expand All @@ -534,6 +592,7 @@ def compute_params(self, dataset):
num_batches,
batch_size,
)
self.config["num_batches"] = num_batches
num_train_epochs = self.config.get("num_train_epochs")
if not num_train_epochs:
num_train_epochs = 1
Expand All @@ -558,9 +617,12 @@ def compute_params(self, dataset):
)

rl_training_config["max_steps"] = max_steps
actor_opt: dict[str, Any] = rl_training_config.get(
"actor_optimizer_config", {}
)
actor_opt: dict[str, Any] = {}
actor_opt_value = rl_training_config.get("actor_optimizer_config")
if isinstance(actor_opt_value, dict):
actor_opt = actor_opt_value
elif actor_opt_value is not None:
raise ValueError("rl_training_config.actor_optimizer_config must be a dict.")
if actor_opt and not actor_opt.get("decay_steps"):
actor_opt["decay_steps"] = max_steps
if actor_opt and not actor_opt.get("warmup_steps"):
Expand Down Expand Up @@ -598,7 +660,7 @@ def _get_dataset(self, tokenizer):
)

if self.config.get("data_module", None):
data_module = self.config.get("data_module", None)
data_module = self._config_string("data_module")
dataset = data_lib.get_dataset_from_module(
data_module,
tokenizer,
Expand Down Expand Up @@ -641,7 +703,7 @@ def _create_agentic_grpo_config(self):
"""Build GRPOConfig (agentic) from the agentic_grpo_config YAML section."""
from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig # pylint: disable=g-import-not-at-top

cfg = dict(self.config.get("agentic_grpo_config", {}))
cfg = dict(self._config_mapping("agentic_grpo_config"))

# episode_timeout = per_turn_timeout_secs * max_turns when not explicit
if "episode_timeout" not in cfg:
Expand All @@ -652,7 +714,7 @@ def _create_agentic_grpo_config(self):

# max_response_length mirrors rollout_config.total_generation_steps
if "max_response_length" not in cfg:
cfg["max_response_length"] = self.config["rollout_config"].get(
cfg["max_response_length"] = self._config_mapping("rollout_config").get(
"total_generation_steps", 8192
)

Expand All @@ -666,7 +728,7 @@ def _create_chat_parser(self, tokenizer: Any) -> Any:
"""Instantiate a chat parser based on chat_parser_config.type."""
from tunix.rl.agentic.parser.chat_template_parser import parser as chat_parser_lib # pylint: disable=g-import-not-at-top

parser_type = (self.config.get("chat_parser_config") or {}).get(
parser_type = self._config_mapping("chat_parser_config").get(
"type", "default"
)
if parser_type == "qwen":
Expand All @@ -692,7 +754,7 @@ def _load_raw_dataset(self, tokenizer):
return dataset, batch_fn

def _setup_kubernetes(self) -> None:
k8s_cfg = self.config.get("kubernetes_config") or {}
k8s_cfg = self._config_mapping("kubernetes_config")
if not k8s_cfg:
return
os.environ["KUBECONFIG"] = k8s_cfg.get("kubeconfig", "~/.kube/config")
Expand All @@ -703,8 +765,8 @@ def _setup_kubernetes(self) -> None:
"node_selector_val", "deepswe-cpu-pool"
)
try:
from kubernetes import client as k8s_client_lib # pylint: disable=g-import-not-at-top
from kubernetes import config as k8s_config_lib # pylint: disable=g-import-not-at-top
from kubernetes import client as k8s_client_lib # type: ignore[import-untyped] # pylint: disable=g-import-not-at-top
from kubernetes import config as k8s_config_lib # type: ignore[import-untyped] # pylint: disable=g-import-not-at-top

k8s_config_lib.load_kube_config()
k8s_client_lib.CoreV1Api()
Expand Down Expand Up @@ -732,7 +794,7 @@ def _run(self, mode: str = "grpo"):
tokenizer,
batch_size=self.config.get("batch_size", 1),
num_batches=self.config.get("num_batches"),
max_prompt_length=self.config["rollout_config"].get(
max_prompt_length=self._config_mapping("rollout_config").get(
"max_prompt_length"
),
fraction=self.config.get("train_fraction", 1.0),
Expand All @@ -749,7 +811,9 @@ def _run(self, mode: str = "grpo"):
grpo_trainer = grpo_learner.GrpoLearner(
rl_cluster=rl_cluster,
reward_fns=self.obtain_reward_fn(),
algo_config=grpo_learner.GrpoConfig(**self.config["grpo_config"]),
algo_config=grpo_learner.GrpoConfig(
**self._config_mapping("grpo_config")
),
)
grpo_trainer.train(dataset)
return
Expand All @@ -772,7 +836,7 @@ def _run(self, mode: str = "grpo"):
chat_parser=chat_parser,
)

agent_class_path = self.config.get("agent_class_path")
agent_class_path = self._config_string("agent_class_path")
if agent_class_path:
learner_kwargs["agent_class"] = self._load_class_from_path(
agent_class_path
Expand All @@ -781,7 +845,7 @@ def _run(self, mode: str = "grpo"):
self.config.get("agent_kwargs") or {}
)

env_class_path = self.config.get("env_class_path")
env_class_path = self._config_string("env_class_path")
if env_class_path:
learner_kwargs["env_class"] = self._load_class_from_path(env_class_path)
learner_kwargs["env_kwargs"] = dict(self.config.get("env_kwargs") or {})
Expand All @@ -807,7 +871,7 @@ def _setup_jax_pathways(pathways_bns: str):


def _setup_pathways_on_cloud():
import pathwaysutils # pylint: disable=g-import-not-at-top
import pathwaysutils # type: ignore[import-not-found] # pytype: disable=import-error # pyright: ignore[reportMissingImports] # pylint: disable=g-import-not-at-top

pathwaysutils.initialize()

Expand Down
Loading
Loading