diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py
index b68d98117..dfaf58a83 100644
--- a/tests/cli/grpo_main_test.py
+++ b/tests/cli/grpo_main_test.py
@@ -20,6 +20,8 @@
 import os
 import pathlib
 import tempfile
+from typing import Any
+from typing import cast
 from unittest import mock
 
 from absl.testing import absltest
@@ -578,6 +580,26 @@ def test_standard_grpo_kv_cache(self):
     self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)
 
 
+class ComputeParamsTest(absltest.TestCase):
+
+  def test_compute_params_persists_dynamic_num_batches(self):
+    pipeline = _make_pipeline("")
+    pipeline.config["batch_size"] = 8
+    pipeline.config["num_batches"] = 0
+    pipeline.config["num_train_epochs"] = 1
+    pipeline.config["train_fraction"] = 0.8
+    rl_training_config = cast(dict[str, Any], pipeline.config["rl_training_config"])
+    rl_training_config["max_steps"] = 0
+
+    raw_dataset = mock.Mock()
+    raw_dataset.__len__ = mock.Mock(return_value=7473)
+
+    pipeline.compute_params(raw_dataset)
+
+    self.assertEqual(pipeline.config["num_batches"], 934)
+    self.assertEqual(rl_training_config["max_steps"], 747)
+
+
 # ---------------------------------------------------------------------------
 # GRPOConfig construction
 # ---------------------------------------------------------------------------
diff --git a/tests/cli/utils/data_test.py b/tests/cli/utils/data_test.py
index 5b5382451..a1fedfac0 100644
--- a/tests/cli/utils/data_test.py
+++ b/tests/cli/utils/data_test.py
@@ -205,6 +205,40 @@ def test_filters_by_prompt_length(self):
     self.assertLen(batches, 1)
     self.assertEqual(batches[0], [{"prompts": "short", "answer": 1}])
 
+  def test_raises_when_prompt_length_filter_removes_all_examples(self):
+    tokenizer = _FakeTokenizer()
+    dataset = _BaseDataset([
+        {"prompts": "this is too long", "answer": 1},
+        {"prompts": "also too long", "answer": 2},
+    ])
+
+    with self.assertRaisesRegex(
+        ValueError, "empty after post_init_dataset filtering"
+    ):
+      data_lib.post_init_dataset(
+          dataset,
+          tokenizer=tokenizer,  # pytype: disable=wrong-arg-types
+          batch_size=2,
+          num_batches=None,
+          max_prompt_length=2,
+      )
+
+  def test_raises_when_fraction_makes_training_split_empty(self):
+    tokenizer = _FakeTokenizer()
+    dataset = _BaseDataset([
+        {"prompts": "short", "answer": 1},
+    ])
+
+    with self.assertRaisesRegex(ValueError, "empty after post_init_dataset split"):
+      data_lib.post_init_dataset(
+          dataset,
+          tokenizer=tokenizer,  # pytype: disable=wrong-arg-types
+          batch_size=1,
+          num_batches=None,
+          max_prompt_length=None,
+          fraction=0.5,
+      )
+
   def test_limits_num_batches(self):
     tokenizer = _FakeTokenizer()
     dataset = _BaseDataset(
diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py
index 3491b885b..96597b9b1 100644
--- a/tunix/cli/grpo_main.py
+++ b/tunix/cli/grpo_main.py
@@ -30,7 +30,8 @@
 python -m tunix.cli.grpo_main examples/deepswe/configs/qwen3_32b.yaml
 """
 
-import collections
+from collections.abc import Mapping
+from collections.abc import MutableMapping
 import dataclasses
 import importlib
 import os
@@ -89,6 +90,66 @@ def __init__(self, argv: list[str], **kwargs):
     self.data_module: ModuleType | None = None
     super().__init__(argv, **kwargs)
 
+  def _config_mapping(self, key: str) -> dict[str, Any]:
+    """Returns a config section as a plain dictionary.
+
+    Pytype can otherwise infer broad unions for nested config sections because
+    the top-level YAML values may be scalars or mappings. This helper narrows
+    the type once and keeps call sites simple.
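+
+    A typical call site (illustrative):
+
+      rollout_cfg = self._config_mapping("rollout_config")
+      max_prompt = rollout_cfg.get("max_prompt_length", 0)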
+ """ + value = self.config.get(key) + if value is None: + return {} + if not isinstance(value, Mapping): + raise TypeError( + f"Expected config section {key!r} to be a mapping, got" + f" {type(value).__name__}." + ) + return dict(value) + + def _mutable_config_mapping(self, key: str) -> MutableMapping[str, Any]: + """Returns a mutable config section for in-place updates.""" + value = self.config.get(key) + if value is None: + section: dict[str, Any] = {} + self.config[key] = section + return section + if not isinstance(value, MutableMapping): + raise TypeError( + f"Expected config section {key!r} to be a mutable mapping, got" + f" {type(value).__name__}." + ) + return value + + def _config_string(self, key: str, default: str = "") -> str: + """Returns a string config value with validation.""" + value = self.config.get(key, default) + if value is None: + return default + if not isinstance(value, str): + raise TypeError( + f"Expected config value {key!r} to be a string, got" + f" {type(value).__name__}." + ) + return value + + def _config_bool(self, key: str, default: bool = False) -> bool: + """Returns a bool config value with validation.""" + value = self.config.get(key, default) + if value is None: + return default + if not isinstance(value, bool): + raise TypeError( + f"Expected config value {key!r} to be a bool, got" + f" {type(value).__name__}." + ) + return value + # ------------------------------------------------------------------ # Mesh # ------------------------------------------------------------------ @@ -263,9 +319,9 @@ def create_rollout_config( max_prompt + total_generation_steps * context_ratio * max_turns. Engine-specific extras (sglang_jax_config, vllm_config) are also applied. """ - rollout_cfg = self.config["rollout_config"] - mode = self.config.get("training_mode", "grpo") - engine = self.config.get("rollout_engine", "vanilla") + rollout_cfg = self._config_mapping("rollout_config") + mode = self._config_string("training_mode", "grpo") + engine = self._config_string("rollout_engine", "vanilla") valid_fields = { f.name for f in dataclasses.fields(base_rollout.RolloutConfig) @@ -280,7 +336,7 @@ def create_rollout_config( max_response = rollout_cfg.get("total_generation_steps", 0) if mode == "agentic_grpo": - agentic_cfg = self.config.get("agentic_grpo_config", {}) + agentic_cfg = self._config_mapping("agentic_grpo_config") max_turns = agentic_cfg.get("max_turns", 1) context_ratio = agentic_cfg.get("context_ratio", 1) if max_turns > 1: @@ -313,10 +369,10 @@ def _agentic_engine_extra( role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None, ) -> dict: """Return engine-specific RolloutConfig fields for agentic mode.""" - model_id = self.config.get("actor_model_config", {}).get("model_id", "") + model_id = self._config_mapping("actor_model_config").get("model_id", "") if engine == "sglang_jax": - sg = self.config.get("sglang_jax_config", {}) + sg = self._config_mapping("sglang_jax_config") return dict( rollout_sglang_jax_model_version=sg.get("model_version", model_id), rollout_sglang_jax_mem_fraction_static=sg.get( @@ -345,17 +401,18 @@ def _agentic_engine_extra( ) if engine == "vllm": - vllm = self.config.get("vllm_config", {}) + vllm = self._config_mapping("vllm_config") if role_to_mesh is None: raise ValueError( "role_to_mesh must be provided for vllm rollout config." 
         )
       rollout_shape = role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape
-      max_num_seqs = self.config["rollout_config"].get(
+      rollout_cfg = self._config_mapping("rollout_config")
+      max_num_seqs = rollout_cfg.get(
           "rollout_vllm_max_num_seqs",
           vllm.get("max_num_seqs", 768),
       )
-      max_batched_tokens = self.config["rollout_config"].get(
+      max_batched_tokens = rollout_cfg.get(
          "rollout_vllm_max_num_batched_tokens",
          vllm.get(
              "max_num_batched_tokens",
@@ -401,8 +463,8 @@ def create_cluster_config(
     rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh)
     return rl_cluster_lib.ClusterConfig(
         role_to_mesh=role_to_mesh,
-        rollout_engine=self.config["rollout_engine"],
-        offload_to_cpu=self.config["offload_to_cpu"],
+        rollout_engine=self._config_string("rollout_engine"),
+        offload_to_cpu=self._config_bool("offload_to_cpu"),
         training_config=self.create_rl_training_config(),
         rollout_config=rollout_config,
     )
@@ -411,7 +473,7 @@ def create_rl_training_config(self):
     base_key = "rl_training_config"
     constructed_rl_training_config = self.obtain_training_config_dict(base_key)
-    base_config = self.config[base_key]
+    base_config = self._config_mapping(base_key)
     if base_config.get("actor_optimizer_config"):
       constructed_rl_training_config["actor_optimizer"] = self.create_optimizer(
           base_key, "actor_optimizer_config"
       )
@@ -470,23 +532,26 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig):
   def create_rl_cluster(self, tokenizer):
     role_to_mesh = self.create_role_to_mesh()
     rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh)
+    reference_model_config = self._mutable_config_mapping("reference_model_config")
+    actor_model_config = self._mutable_config_mapping("actor_model_config")
+    tokenizer_config = self._config_mapping("tokenizer_config")
     # Should not use LoRA for reference model.
-    if self.config["reference_model_config"].get("lora_config"):
+    if reference_model_config.get("lora_config"):
       logging.warning(
           "LoRA config is not supported for the reference model. Disabling"
          " LoRA."
      )
-      del self.config["reference_model_config"]["lora_config"]
+      del reference_model_config["lora_config"]
     reference_model, tokenizer_path = model_lib.create_model(
-        self.config["reference_model_config"],
-        self.config["tokenizer_config"],
+        dict(reference_model_config),
+        tokenizer_config,
         role_to_mesh[rl_cluster_lib.Role.REFERENCE],
     )
-    if self.config["actor_model_config"].get("lora_config", None):
+    if actor_model_config.get("lora_config", None):
       actor_model = model_lib.apply_lora_to_model(
           reference_model,
           role_to_mesh[rl_cluster_lib.Role.ACTOR],
-          self.config["actor_model_config"]["lora_config"],
+          actor_model_config["lora_config"],
       )
     else:
       graph_def, params = nnx.split(reference_model)
@@ -509,9 +574,9 @@ def create_rl_cluster(self, tokenizer):
     )
 
   def compute_params(self, dataset):
-    rl_training_config: dict[str, Any] = self.config.get(
-        "rl_training_config", {}
-    )
+    rl_training_config = self._mutable_config_mapping("rl_training_config")
 
     # Return early if max_steps is already specified.
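+    # Otherwise num_batches and max_steps are derived below from the dataset
+    # size, batch_size, num_train_epochs, and train_fraction.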
     max_steps = None
@@ -534,6 +599,7 @@
         num_batches,
         batch_size,
     )
+    self.config["num_batches"] = num_batches
     num_train_epochs = self.config.get("num_train_epochs")
     if not num_train_epochs:
       num_train_epochs = 1
@@ -558,9 +624,14 @@
     )
     rl_training_config["max_steps"] = max_steps
 
-    actor_opt: dict[str, Any] = rl_training_config.get(
-        "actor_optimizer_config", {}
-    )
+    actor_opt: dict[str, Any] = {}
+    actor_opt_value = rl_training_config.get("actor_optimizer_config")
+    if isinstance(actor_opt_value, dict):
+      actor_opt = actor_opt_value
+    elif actor_opt_value is not None:
+      raise ValueError(
+          "rl_training_config.actor_optimizer_config must be a dict."
+      )
     if actor_opt and not actor_opt.get("decay_steps"):
       actor_opt["decay_steps"] = max_steps
     if actor_opt and not actor_opt.get("warmup_steps"):
@@ -598,7 +669,7 @@ def _get_dataset(self, tokenizer):
     )
 
     if self.config.get("data_module", None):
-      data_module = self.config.get("data_module", None)
+      data_module = self._config_string("data_module")
       dataset = data_lib.get_dataset_from_module(
           data_module,
           tokenizer,
@@ -641,7 +712,7 @@ def _create_agentic_grpo_config(self):
     """Build GRPOConfig (agentic) from the agentic_grpo_config YAML section."""
     from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig  # pylint: disable=g-import-not-at-top
 
-    cfg = dict(self.config.get("agentic_grpo_config", {}))
+    cfg = dict(self._config_mapping("agentic_grpo_config"))
 
     # episode_timeout = per_turn_timeout_secs * max_turns when not explicit
     if "episode_timeout" not in cfg:
@@ -652,7 +723,7 @@
 
     # max_response_length mirrors rollout_config.total_generation_steps
     if "max_response_length" not in cfg:
-      cfg["max_response_length"] = self.config["rollout_config"].get(
+      cfg["max_response_length"] = self._config_mapping("rollout_config").get(
          "total_generation_steps", 8192
      )
 
@@ -666,7 +737,7 @@ def _create_chat_parser(self, tokenizer: Any) -> Any:
     """Instantiate a chat parser based on chat_parser_config.type."""
     from tunix.rl.agentic.parser.chat_template_parser import parser as chat_parser_lib  # pylint: disable=g-import-not-at-top
 
-    parser_type = (self.config.get("chat_parser_config") or {}).get(
+    parser_type = self._config_mapping("chat_parser_config").get(
         "type", "default"
     )
     if parser_type == "qwen":
@@ -692,7 +763,7 @@ def _load_raw_dataset(self, tokenizer):
     return dataset, batch_fn
 
   def _setup_kubernetes(self) -> None:
-    k8s_cfg = self.config.get("kubernetes_config") or {}
+    k8s_cfg = self._config_mapping("kubernetes_config")
     if not k8s_cfg:
       return
     os.environ["KUBECONFIG"] = k8s_cfg.get("kubeconfig", "~/.kube/config")
@@ -703,8 +774,8 @@ def _setup_kubernetes(self) -> None:
         "node_selector_val", "deepswe-cpu-pool"
     )
     try:
-      from kubernetes import client as k8s_client_lib  # pylint: disable=g-import-not-at-top
-      from kubernetes import config as k8s_config_lib  # pylint: disable=g-import-not-at-top
+      from kubernetes import client as k8s_client_lib  # type: ignore[import-untyped]  # pylint: disable=g-import-not-at-top
+      from kubernetes import config as k8s_config_lib  # type: ignore[import-untyped]  # pylint: disable=g-import-not-at-top
 
       k8s_config_lib.load_kube_config()
       k8s_client_lib.CoreV1Api()
@@ -732,7 +803,7 @@ def _run(self, mode: str = "grpo"):
         tokenizer,
         batch_size=self.config.get("batch_size", 1),
         num_batches=self.config.get("num_batches"),
-        max_prompt_length=self.config["rollout_config"].get(
+        max_prompt_length=self._config_mapping("rollout_config").get(
"max_prompt_length" ), fraction=self.config.get("train_fraction", 1.0), @@ -749,7 +811,9 @@ def _run(self, mode: str = "grpo"): grpo_trainer = grpo_learner.GrpoLearner( rl_cluster=rl_cluster, reward_fns=self.obtain_reward_fn(), - algo_config=grpo_learner.GrpoConfig(**self.config["grpo_config"]), + algo_config=grpo_learner.GrpoConfig( + **self._config_mapping("grpo_config") + ), ) grpo_trainer.train(dataset) return @@ -772,7 +836,7 @@ def _run(self, mode: str = "grpo"): chat_parser=chat_parser, ) - agent_class_path = self.config.get("agent_class_path") + agent_class_path = self._config_string("agent_class_path") if agent_class_path: learner_kwargs["agent_class"] = self._load_class_from_path( agent_class_path @@ -781,7 +845,7 @@ def _run(self, mode: str = "grpo"): self.config.get("agent_kwargs") or {} ) - env_class_path = self.config.get("env_class_path") + env_class_path = self._config_string("env_class_path") if env_class_path: learner_kwargs["env_class"] = self._load_class_from_path(env_class_path) learner_kwargs["env_kwargs"] = dict(self.config.get("env_kwargs") or {}) @@ -807,7 +871,7 @@ def _setup_jax_pathways(pathways_bns: str): def _setup_pathways_on_cloud(): - import pathwaysutils # pylint: disable=g-import-not-at-top + import pathwaysutils # type: ignore[import-not-found] # pytype: disable=import-error # pyright: ignore[reportMissingImports] # pylint: disable=g-import-not-at-top pathwaysutils.initialize() diff --git a/tunix/cli/utils/data.py b/tunix/cli/utils/data.py index 22d222566..3856fda7b 100644 --- a/tunix/cli/utils/data.py +++ b/tunix/cli/utils/data.py @@ -193,6 +193,8 @@ def post_init_dataset( Returns: The processed dataset. """ + original_size = len(dataset) + if prompt_key != "prompts": source_prompt_key = prompt_key @@ -210,6 +212,15 @@ def prompt_length_filter(x): dataset = dataset.filter(prompt_length_filter) + filtered_size = len(dataset) + if filtered_size == 0: + raise ValueError( + "Training dataset is empty after post_init_dataset filtering. " + f"original_size={original_size}, max_prompt_length={max_prompt_length}, " + f"prompt_key={prompt_key!r}. Consider increasing max_prompt_length " + "or adjusting the prompt template/filter settings." + ) + if num_batches is not None: target_size = min(num_batches * batch_size, len(dataset)) dataset = dataset[:target_size] @@ -222,6 +233,13 @@ def prompt_length_filter(x): first_segment_dataset = dataset second_segment_dataset = None + if len(first_segment_dataset) == 0: + raise ValueError( + "Training dataset is empty after post_init_dataset split. " + f"filtered_size={filtered_size}, fraction={fraction}, " + f"num_batches={num_batches}, batch_size={batch_size}." + ) + first_segment_dataset = ( first_segment_dataset.repeat(num_epochs) .to_iter_dataset()