5 changes: 3 additions & 2 deletions .github/workflows/docker/docker-compose.yaml
@@ -1,6 +1,6 @@
services:
trinity-node-1:
image: trinity-rft-unittest:20260205
image: trinity-rft-unittest:20260211
cap_add:
- SYS_PTRACE
pull_policy: never
@@ -15,6 +15,7 @@ services:
- TRINITY_MODEL_PATH=/mnt/models/Qwen3-0.6B
- TRINITY_API_MODEL_PATH=/mnt/models/Qwen3-1.7B
- TRINITY_VLM_MODEL_PATH=/mnt/models/Qwen2.5-VL-3B
- TRINITY_ALTERNATIVE_VLM_MODEL_PATH=/mnt/models/Qwen3-VL-2B-Instruct
- VIRTUAL_ENV=/opt/venv
working_dir: /workspace
networks:
@@ -32,7 +33,7 @@ services:
capabilities: [gpu]

trinity-node-2:
image: trinity-rft-unittest:20260205
image: trinity-rft-unittest:20260211
cap_add:
- SYS_PTRACE
pull_policy: never
12 changes: 10 additions & 2 deletions examples/grpo_vlm/README.md
@@ -8,8 +8,8 @@ This example shows the usage of GRPO with Qwen2.5-VL-3B-Instruct on the [geometr
The specific requirements are:

```yaml
vllm>=0.9.1,<0.10.0
transformers<4.53.0
vllm>=0.10.2 # Qwen3 VL requires vllm>=0.11.0; it is recommended to use version >= 0.13.0
transformers>=4.54.0
qwen_vl_utils
```

@@ -18,3 +18,11 @@ For other detailed information, please refer to the [documentation](../../docs/s
The config file is located in [`vlm.yaml`](vlm.yaml), and the curve is shown below.

![vlm](../../docs/sphinx_doc/assets/geometry3k_qwen25_vl_3b_reward.png)

## Supported Model Architectures

The following vision-language model series are currently supported:

1. Qwen2.5-VL series
2. Qwen3-VL series
3. Kimi-VL-A3B-Thinking series
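
If it is unclear whether a local checkpoint belongs to one of these families, its `config.json` can be inspected. The sketch below is a minimal, illustrative helper and is not part of this repository; the prefixes in `SUPPORTED_ARCH_PREFIXES` are assumptions based on the usual Hugging Face architecture names for these model families.

```python
# Minimal sketch: guess whether a checkpoint belongs to a supported VLM family
# by reading the "architectures" field of its Hugging Face config.json.
# The prefixes below are assumptions, not constants defined in Trinity-RFT.
import json
from pathlib import Path

SUPPORTED_ARCH_PREFIXES = ("Qwen2_5_VL", "Qwen3VL", "KimiVL")


def looks_like_supported_vlm(model_dir: str) -> bool:
    config = json.loads((Path(model_dir) / "config.json").read_text())
    architectures = config.get("architectures", [])
    return any(
        arch.startswith(prefix)
        for arch in architectures
        for prefix in SUPPORTED_ARCH_PREFIXES
    )
```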
12 changes: 10 additions & 2 deletions examples/mix_vlm/README.md
@@ -8,8 +8,8 @@ This is an example of using the [MIX](../../docs/sphinx_doc/source/tutorial/exam
The specific requirements are:

```yaml
vllm>=0.9.1,<0.10.0
transformers<4.53.0
vllm>=0.10.2 # Qwen3 VL requires vllm>=0.11.0; it is recommended to use version >= 0.13.0
transformers>=4.54.0
qwen_vl_utils
```

@@ -34,3 +34,11 @@ trinity run --config examples/mix_vlm/mix_vlm.yaml

The reward curve is shown below:
![](../../docs/sphinx_doc/assets/mix_vlm_reward.png)

## Supported Model Architectures

The following vision-language model series are currently supported:

1. Qwen2.5-VL series
2. Qwen3-VL series
3. Kimi-VL-A3B-Thinking series
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -88,7 +88,7 @@ megatron = [
# "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612",
]
tinker = [
"tinker; python_version >= '3.11'",
"tinker>=0.10.0; python_version >= '3.11'",
]

doc = [
@@ -103,6 +103,8 @@ doc = [

mm = [
"qwen-vl-utils",
"transformers>=4.54.0",
"blobfile",
]

flash_attn = [
@@ -143,6 +145,9 @@ known_third_party = ["wandb"]
[tool.uv.extra-build-dependencies]
flash-attn = ["torch", "numpy"]

[project.entry-points."vllm.general_plugins"]
vllm_patch = "trinity.common.models.vllm_patch:vllm_patch"

[project.urls]
"Homepage" = "https://github.com/agentscope-ai/Trinity-RFT"
"Documentation" = "https://agentscope-ai.github.io/Trinity-RFT/"
2 changes: 2 additions & 0 deletions tests/template/data/gsm8k/test.jsonl
@@ -0,0 +1,2 @@
{"question": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", "answer": "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\u2019s market.\n#### 18"}
{"question": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?", "answer": "It takes 2/2=<<2/2=1>>1 bolt of white fiber\nSo the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric\n#### 3"}
24 changes: 24 additions & 0 deletions tests/tools.py
@@ -24,6 +24,7 @@

API_MODEL_PATH_ENV_VAR = "TRINITY_API_MODEL_PATH"
VLM_MODEL_PATH_ENV_VAR = "TRINITY_VLM_MODEL_PATH"
ALTERNATIVE_VLM_MODEL_PATH_ENV_VAR = "TRINITY_ALTERNATIVE_VLM_MODEL_PATH"
SFT_DATASET_PATH_ENV_VAR = "TRINITY_SFT_DATASET_PATH"


@@ -134,6 +135,15 @@ def get_vision_language_model_path() -> str:
return path


def get_alternative_vision_language_model_path() -> str:
path = os.environ.get(ALTERNATIVE_VLM_MODEL_PATH_ENV_VAR)
if not path:
raise EnvironmentError(
f"Please set `export {ALTERNATIVE_VLM_MODEL_PATH_ENV_VAR}=<your_model_dir>` before running this test."
)
return path


def get_lora_config() -> LoRAConfig:
return LoRAConfig(name="lora", lora_rank=16, lora_alpha=16)

@@ -248,6 +258,20 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
default_workflow_type="simple_mm_workflow",
default_reward_fn_type="math_boxed_reward",
)
elif dataset_name == "geometry_sft":
# Multi-modal geometry dataset for sft with 8 samples
return ExperienceBufferConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "geometry"),
split="train",
storage_type=StorageType.FILE.value,
format=FormatConfig(
prompt_type=PromptType.PLAINTEXT,
prompt_key="problem",
response_key="answer",
image_key="images",
),
)
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")

14 changes: 6 additions & 8 deletions tests/trainer/trainer_test.py
@@ -21,6 +21,7 @@
RayUnittestBase,
RayUnittestBaseAsync,
TensorBoardParser,
get_alternative_vision_language_model_path,
get_checkpoint_path,
get_lora_config,
get_model_path,
@@ -350,7 +351,7 @@ def test_trainer(self, mock_load):
mock_load.return_value = deepcopy(self.config)

with self.assertRaises(Exception):
run(config_path="dummy.yaml")
run(config="dummy.yaml")
ray.shutdown(_exiting_interpreter=True)

stage_configs = [cfg.check_and_update() for cfg in deepcopy(self.config)]
Expand All @@ -375,7 +376,7 @@ def test_trainer(self, mock_load):
self.config.stages[1].buffer.explorer_input.taskset.path = old_taskset_path
mock_load.return_value = deepcopy(self.config)
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
run(config_path="dummy.yaml")
run(config="dummy.yaml")

# grpo stage
grpo_config = stage_configs[1]
@@ -1205,13 +1206,12 @@ def tearDown(self):


class TestMultiModalGRPO(BaseTrainerCase):
@unittest.skip("Require specific vllm/transformers version")
def test_trainer(self):
"""Test both mode with multi-modal data."""
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config(
"geometry"
) # Total 8 tasks
self.config.model.model_path = get_vision_language_model_path()
self.config.model.model_path = get_alternative_vision_language_model_path()
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.kl_loss_fn = "none"
@@ -1246,12 +1246,11 @@ def tearDown(self):


class TestMultiModalSFT(BaseTrainerCase):
@unittest.skip("Require specific vllm/transformers version")
def test_trainer(self):
"""Test SFT mode with multi-modal data."""
self.config.mode = "train"
self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config(
"geometry"
"geometry_sft"
) # Total 8 tasks
self.config.model.model_path = get_vision_language_model_path()
self.config.algorithm.algorithm_type = "sft"
Expand Down Expand Up @@ -1522,7 +1521,6 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)


@unittest.skip("Require agentscope >= 1.0.12")
class AgentScopeTunerTest(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
ray.init(ignore_reinit_error=True)
@@ -1622,7 +1620,7 @@ async def judge_func(
model_path=get_model_path(),
max_model_len=8192,
max_tokens=2048,
inference_engine_num=2,
inference_engine_num=1,
)
}

46 changes: 19 additions & 27 deletions trinity/algorithm/policy_loss_fn/chord_policy_loss.py
@@ -199,35 +199,27 @@ def __call__( # type: ignore
per_micro_batch_weight_usual = self.gradient_accumulation / self.train_batch_size_usual # type: ignore
per_micro_batch_weight_expert = self.gradient_accumulation / self.train_batch_size_expert # type: ignore

if n_usual_exp > 0:
grpo_loss, grpo_metrics = self.grpo_loss_fn(
logprob[~expert_mask],
old_logprob[~expert_mask],
action_mask[~expert_mask],
advantages[~expert_mask],
**kwargs,
)
grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
grpo_metrics = {
k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
}
else:
grpo_loss = torch.tensor(0.0, device=logprob.device)
grpo_metrics = {}
grpo_loss, grpo_metrics = self.grpo_loss_fn(
logprob[~expert_mask],
old_logprob[~expert_mask],
action_mask[~expert_mask],
advantages[~expert_mask],
**kwargs,
)
grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
grpo_metrics = {
k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
}

# SFT Loss (expert)
if n_expert_exp > 0:
sft_loss, sft_metrics = self.sft_loss_fn(
logprob[expert_mask],
action_mask[expert_mask],
)
sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
sft_metrics = {
k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
}
else:
sft_loss = torch.tensor(0.0, device=logprob.device)
sft_metrics = {}
sft_loss, sft_metrics = self.sft_loss_fn(
logprob[expert_mask],
action_mask[expert_mask],
)
sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
sft_metrics = {
k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
}

mu = mu_schedule_function(
current_step, self.mu_warmup_steps, self.mu_decay_steps, self.mu_peak, self.mu_valley
Expand Down