Merged
110 commits
8c1516c
runs but not limited master_port
bowenyang008 Jun 25, 2025
b0b57d9
hack local rank
bowenyang008 Jun 25, 2025
9e70dbb
composer launch works; torchrun somehow only allows use the same init…
bowenyang008 Jun 25, 2025
bc72f3e
clean up
bowenyang008 Jun 25, 2025
461c8c4
timeout to 30s
bowenyang008 Jun 25, 2025
e7803bb
update script
bowenyang008 Jun 25, 2025
feb62d1
None evn
bowenyang008 Jun 25, 2025
237373c
change internval
bowenyang008 Jun 25, 2025
4193ee3
try break and nodes instead
bowenyang008 Jun 25, 2025
fd63dcb
condition
bowenyang008 Jun 25, 2025
5c3f61d
rm resources
bowenyang008 Jun 25, 2025
ff227c4
sleep it
bowenyang008 Jun 25, 2025
172c676
use dist barrier to block
bowenyang008 Jun 26, 2025
f417e5b
context manager
bowenyang008 Jun 26, 2025
b29ba44
try to not release port
bowenyang008 Jun 26, 2025
9a30028
use ray actor
bowenyang008 Jun 26, 2025
12fbcef
another way to get ip address
bowenyang008 Jun 26, 2025
4ad2e65
ray init
bowenyang008 Jun 26, 2025
8df20ec
ray remote
bowenyang008 Jun 26, 2025
76c73a3
barrier
bowenyang008 Jun 26, 2025
8f51968
half subprocess
bowenyang008 Jun 26, 2025
a00369c
try w/o port
bowenyang008 Jun 26, 2025
a777c54
claude fix; questionable
bowenyang008 Jun 26, 2025
dd97cf2
do not raise
bowenyang008 Jun 26, 2025
09c92ed
revert back to subprocess
bowenyang008 Jun 27, 2025
cef2b9e
mix init
bowenyang008 Jun 27, 2025
9d93001
manual stop
bowenyang008 Jun 27, 2025
9dc7a74
two gpus runs
bowenyang008 Jun 27, 2025
c3db295
tensor parallel size 8
bowenyang008 Jun 30, 2025
824f086
update world size
bowenyang008 Jun 30, 2025
443d639
try all gpus for train again
bowenyang008 Jun 30, 2025
fde2960
use old port assignment
bowenyang008 Jun 30, 2025
8b625aa
get freeport again
bowenyang008 Jul 1, 2025
52bb8aa
change order again
bowenyang008 Jul 1, 2025
934ef7a
rank 0 again
bowenyang008 Jul 1, 2025
04217ab
half trainers
bowenyang008 Jul 1, 2025
f68c703
import; new method
bowenyang008 Jul 1, 2025
1fda4e3
no import
bowenyang008 Jul 1, 2025
6daa40c
import seems the killer
bowenyang008 Jul 1, 2025
42ab521
import vllm
bowenyang008 Jul 1, 2025
c86b2df
change rend time to 30s
bowenyang008 Jul 1, 2025
160d861
def method
bowenyang008 Jul 1, 2025
8eafdc6
add back vllm init
bowenyang008 Jul 1, 2025
0dc9f95
rm dup code
bowenyang008 Jul 1, 2025
d3c1a7c
comment out worker node
bowenyang008 Jul 1, 2025
8eb0116
running generations
bowenyang008 Jul 1, 2025
231674d
weight update donw
bowenyang008 Jul 2, 2025
d296b4e
not sync weight
bowenyang008 Jul 2, 2025
fe582a6
revert back
bowenyang008 Jul 2, 2025
f44f469
call master only
bowenyang008 Jul 2, 2025
01862b9
update pyproject for pyright filtering
bowenyang008 Jul 4, 2025
b7ebe7d
temp single controller
bowenyang008 Jul 4, 2025
1916fcd
separate out roundtrip code
bowenyang008 Jul 11, 2025
d8d4816
relocate to use test
bowenyang008 Jul 11, 2025
9d9850b
ref built
bowenyang008 Jul 11, 2025
8352a08
build ppo trainer
bowenyang008 Jul 11, 2025
4f093a4
weight update name mismatch; inference and prompts/gen exchange and t…
bowenyang008 Jul 11, 2025
88758c9
update ref model build
bowenyang008 Jul 12, 2025
1871b12
trains e2e!
bowenyang008 Jul 12, 2025
d2eb8f4
clean up imports
bowenyang008 Jul 15, 2025
567e37b
recover assert
bowenyang008 Jul 17, 2025
eb5b5c6
Merge remote-tracking branch 'origin/main' into boweny/playground
bowenyang008 Jul 17, 2025
f06dcfa
add test files
bowenyang008 Jul 18, 2025
133dac2
Merge branch 'main' into boweny/single-controller-composer
bowenyang008 Jul 22, 2025
da377bd
clean up dir
bowenyang008 Jul 22, 2025
3240715
refactor base actor
bowenyang008 Jul 22, 2025
861768b
rel method
bowenyang008 Jul 22, 2025
d0b5a44
callback fails
bowenyang008 Jul 23, 2025
ec7e550
revive ppo training again
bowenyang008 Jul 23, 2025
f7952b4
clean up
bowenyang008 Jul 23, 2025
95b96ef
first round code reduction done
bowenyang008 Jul 23, 2025
dee910b
another rm
bowenyang008 Jul 23, 2025
1e0b07a
train again; trim down
bowenyang008 Jul 23, 2025
b23b80e
recover weight update
bowenyang008 Jul 23, 2025
2f8c946
share method
bowenyang008 Jul 23, 2025
2966ccf
Revert "share method"
bowenyang008 Jul 23, 2025
fb98225
rm files
bowenyang008 Jul 23, 2025
be08a75
docs
bowenyang008 Jul 24, 2025
7dac060
doc
bowenyang008 Jul 24, 2025
d4a76e1
actor group
bowenyang008 Jul 25, 2025
24d2508
no more explicit master addr
bowenyang008 Jul 25, 2025
e37d2ec
all class ready
bowenyang008 Jul 25, 2025
7feb36c
run cmd
bowenyang008 Jul 25, 2025
c25bebc
use device direclty
bowenyang008 Jul 25, 2025
7195e4e
clean up to last method
bowenyang008 Jul 25, 2025
f28b8c8
clean up
bowenyang008 Jul 25, 2025
d180533
mv vllm engines out
bowenyang008 Jul 25, 2025
5c18c99
yeah works
bowenyang008 Jul 25, 2025
369e58e
relocate file and pytest works
bowenyang008 Jul 25, 2025
5a69a15
rm file
bowenyang008 Jul 25, 2025
5e2ebc2
format
bowenyang008 Jul 26, 2025
9e21f3f
Merge remote-tracking branch 'origin/main' into boweny/single-control…
bowenyang008 Jul 26, 2025
470bef9
format
bowenyang008 Jul 26, 2025
44b6609
type ignore
bowenyang008 Jul 28, 2025
0662810
todo
bowenyang008 Jul 28, 2025
ab4faad
different type fix
bowenyang008 Jul 28, 2025
7ff5867
todos
bowenyang008 Jul 28, 2025
092da43
doc fix
bowenyang008 Jul 28, 2025
6c5f216
format
bowenyang008 Jul 28, 2025
00f32f3
change gpu test to 2
bowenyang008 Jul 28, 2025
7094033
format
bowenyang008 Jul 28, 2025
59305a6
revert 4 gpu for now, regression does not like this test
bowenyang008 Jul 28, 2025
47d9b5f
revert 2 gpu for now
bowenyang008 Jul 28, 2025
9cb6e98
todo
bowenyang008 Jul 28, 2025
0b5d911
use 3.11 and update doc formatter
bowenyang008 Jul 28, 2025
e268626
revert change
bowenyang008 Jul 28, 2025
e9e9b3c
try diff
bowenyang008 Jul 29, 2025
6a513e4
yapf again
bowenyang008 Jul 29, 2025
81fce8c
todo
bowenyang008 Jul 29, 2025
eb889f1
revert
bowenyang008 Jul 29, 2025
3 changes: 3 additions & 0 deletions compose_rl/algorithms/online/__init__.py
@@ -19,6 +19,8 @@
     HFPolicyConfig,
     MPTPolicyConfig,
 )
+from compose_rl.algorithms.online.single_controller_callback import \
+    SingleControllerOnPolicyCallback
 from compose_rl.registry import kl_controllers
 
 kl_controllers.register('adaptive', func=AdaptiveKLController)
@@ -28,6 +30,7 @@
 
 __all__ = [
     'OnPolicyCallback',
+    'SingleControllerOnPolicyCallback',
     'ComposerMPTPolicyLM',
     'ComposerHFPolicyLM',
     'ComposerHFCriticFreePolicyLM',
20 changes: 16 additions & 4 deletions compose_rl/algorithms/online/callback.py
@@ -764,7 +764,6 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
         # When we hit this function, we should already have all the prompts we need per iteration.
         num_gen_calls = bs // self.device_generate_batch_size
 
-        gen_batch_partial_outputs = []
         all_sequences = []
         for i in range(num_gen_calls):
             gen_batch = self._extract_minibatch(
@@ -796,6 +795,15 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
         # Add the prepared sequences to the batch again
         batch['sequences'] = sequences
 
+        # Compute rewards and populate buffer
+        self._get_reward(batch)
+
+    def _get_reward(self, batch: dict[str, torch.Tensor]):
+        """Compute rewards for a batch of generated sequences.
+
+        Args:
+            batch (dict): The batch containing generated sequences to compute rewards for.
+        """
         env_outputs, prompts_and_gens, ref_outputs, all_rewards_dict = env_reward(
             actor_critic=self.actor_critic,  # pyright: ignore
             reward_manager=self.reward_manager,
@@ -825,7 +833,9 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
             del resolved_outputs[key]
 
         # We need to split the resolved outputs into minibatches
-        for idx in range(bs // self.device_train_batch_size):
+        for idx in range(
+            batch['prompt_id'].shape[0] // self.device_train_batch_size,
+        ):
             minibatch = self._extract_minibatch(
                 resolved_outputs,
                 idx,
@@ -834,7 +844,9 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
             self.buffer.add(minibatch)
 
         # Making sure we correctly parsed the minibatches
-        assert len(self.buffer) == self.num_batches_per_update
+        assert len(
+            self.buffer,
+        ) == self.num_batches_per_update, f'{len(self.buffer)} != {self.num_batches_per_update}'
 
         self.actor_critic.train()

@@ -1149,7 +1161,7 @@ def _update_inference_model(self, batch: dict[str, torch.Tensor]):
             model=self.actor_critic,
             vllm_engines=self.vllm_engines,
             model_update_group=self.model_update_group,
-            batch=batch,
+            device=batch['prompt'].device,
             loss_type=self.actor_critic.loss_type,  # type: ignore
             enable_prefix_caching=self.vllm_enable_prefix_caching,
         )
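Because `_get_reward` now runs outside `_interact_with_env` (where the local `bs` lived), the minibatch count is derived from `batch['prompt_id'].shape[0]` instead. A minimal sketch of that slicing, with hypothetical shapes and a stand-in for `_extract_minibatch`:

```python
import torch

device_train_batch_size = 2
batch = {'prompt_id': torch.arange(8)}  # hypothetical: 8 rollouts on this device

# One minibatch per device_train_batch_size-sized slice of the rollouts.
num_minibatches = batch['prompt_id'].shape[0] // device_train_batch_size
for idx in range(num_minibatches):
    start = idx * device_train_batch_size
    # Stand-in for self._extract_minibatch(resolved_outputs, idx, ...)
    minibatch = {
        k: v[start:start + device_train_batch_size] for k, v in batch.items()
    }

assert num_minibatches == 4  # matches len(buffer) after all adds
```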
19 changes: 8 additions & 11 deletions compose_rl/algorithms/online/generation_utils/vllm_utils.py
@@ -381,7 +381,7 @@ def broadcast_to_vllm(
     model: nn.Module,
     vllm_engines: list,
     model_update_group: Optional[torch.distributed.ProcessGroup],
-    batch: dict[str, torch.Tensor],
+    device: torch.device,
     loss_type: OnPolicyEnum = OnPolicyEnum.PPO,
     enable_prefix_caching: bool = False,
 ):
@@ -391,7 +391,7 @@
         model (nn.Module): The model to broadcast
         vllm_engines (list): List of vllm engines
         model_update_group (torch.distributed.ProcessGroup): The process group for model updates
-        batch (dict[str, torch.Tensor]): The batch to use for the forward pass
+        device (torch.device): The device to use for the forward pass
         loss_type (str): The loss type which decides whether to use critic-free or not. Defaults to `ppo`.
         enable_prefix_caching (bool): Whether to enable prefix caching. Defaults to `False`.
     """
@@ -419,9 +419,6 @@
         engine.reset_prefix_cache.remote() for engine in vllm_engines
     ]
 
-    # This is needed to get the correct model device
-    cur_device = batch['prompt'].device
-
     # These apply to llama modules, it might change for other modules
     valid_non_leaf_module_names = [
         'model.embed_tokens.weight',
@@ -438,17 +435,17 @@
     # We need this otherwise FSDP throws an error during a standard forward pass.
     dummy_batch = {
         'obs':
-            torch.tensor([[0]], dtype=torch.long, device=cur_device),
+            torch.tensor([[0]], dtype=torch.long, device=device),
         'right_padded_attn_mask':
-            torch.tensor([[1]], dtype=torch.bool, device=cur_device),
+            torch.tensor([[1]], dtype=torch.bool, device=device),
         'actions':
-            torch.tensor([[0]], dtype=torch.long, device=cur_device),
+            torch.tensor([[0]], dtype=torch.long, device=device),
         'prompt_len':
-            torch.tensor([1], device=cur_device),
+            torch.tensor([1], device=device),
         'max_gen_len':
-            torch.tensor([1], device=cur_device),
+            torch.tensor([1], device=device),
         'action_mask':
-            torch.tensor([[0]], dtype=torch.long, device=cur_device),
+            torch.tensor([[0]], dtype=torch.long, device=device),
     }
     model(dummy_batch)
     start_time = time.time()
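The old signature took the whole batch only to read a device off one tensor; narrowing it to `device` makes the real dependency explicit. A call-site sketch mirroring the change, assuming `batch`, `actor_critic`, and the vLLM handles are in scope as in callback.py:

```python
device = batch['prompt'].device  # previously the whole batch was passed in

broadcast_to_vllm(
    model=actor_critic,
    vllm_engines=vllm_engines,
    model_update_group=model_update_group,
    device=device,
    loss_type=actor_critic.loss_type,
    enable_prefix_caching=vllm_enable_prefix_caching,
)
```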
14 changes: 12 additions & 2 deletions compose_rl/algorithms/online/model.py
@@ -100,17 +100,22 @@ def eval_forward(self, batch: MutableMapping, outputs: MutableMapping):
         )
 
     def loss(self, outputs: MutableMapping, batch: MutableMapping):
+        # Get beta from config if available, otherwise use default
+        additional_kwargs = {}
+        if hasattr(self.config, 'beta'):
+            additional_kwargs['beta'] = self.config.beta
+
         return_dict = online_rl_loss(
             outputs=outputs,
             batch=batch,
             loss_type=OnPolicyEnum.PPO,
             value_clip_range=self.config.value_clip_range,
             value_loss_weight=self.config.value_loss_weight,
             policy_clip_ratio=self.config.policy_clip_ratio,
-            beta=self.config.beta,
             add_direct_kl_loss=self.config.compute_kl_loss,
             kl_estimator=self.config.kl_estimator,
             kl_clip_range=self.config.kl_clip_range,
+            **additional_kwargs,
         )
 
         self.policy_kl.append(return_dict['kl/policy_kl'])
@@ -217,17 +222,22 @@ def eval_forward(self, batch: MutableMapping, outputs: MutableMapping):
         )
 
     def loss(self, outputs: MutableMapping, batch: MutableMapping):
+        # Get beta from config if available, otherwise use default
+        additional_kwargs = {}
+        if hasattr(self.config, 'beta'):
+            additional_kwargs['beta'] = self.config.beta
+
         return_dict = online_rl_loss(
             outputs=outputs,
             batch=batch,
             loss_type=self.loss_type,  # pyright: ignore
             value_clip_range=self.config.value_clip_range,
             value_loss_weight=self.config.value_loss_weight,
             policy_clip_ratio=self.config.policy_clip_ratio,
-            beta = self.config.beta,
             add_direct_kl_loss=self.config.compute_kl_loss,
             kl_estimator=self.config.kl_estimator,
             kl_clip_range=self.config.kl_clip_range,
+            **additional_kwargs,
         )
 
         self.policy_kl.append(return_dict['kl/policy_kl'])
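The `hasattr` guard forwards `beta` only when the config defines it, so the callee's own default applies otherwise instead of raising `AttributeError`. A self-contained illustration of the pattern (the `0.1` default is an assumed stand-in, not the library's value):

```python
class Config:
    pass

def online_rl_loss_stub(*, beta: float = 0.1) -> float:
    # Stand-in for online_rl_loss; returns beta so the fallback is observable.
    return beta

config = Config()

additional_kwargs = {}
if hasattr(config, 'beta'):
    additional_kwargs['beta'] = config.beta
assert online_rl_loss_stub(**additional_kwargs) == 0.1  # callee default wins

config.beta = 0.5  # now the config defines beta
additional_kwargs = {}
if hasattr(config, 'beta'):
    additional_kwargs['beta'] = config.beta
assert online_rl_loss_stub(**additional_kwargs) == 0.5  # config value forwarded
```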
58 changes: 58 additions & 0 deletions compose_rl/algorithms/online/single_controller_callback.py
@@ -0,0 +1,58 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""Online On-Policy RL callback."""

from __future__ import annotations

import logging
from typing import Union

from composer.core import State
from composer.loggers import Logger
from composer.trainer.trainer import _get_initial_device_train_microbatch_size
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Import the base class
from compose_rl.algorithms.online.callback import OnPolicyCallback
from compose_rl.algorithms.online.model import (
    ComposerHFPolicyLM,
    ComposerMPTPolicyLM,
)

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM]

__all__ = ['SingleControllerOnPolicyCallback']

log = logging.getLogger(__name__)


class SingleControllerOnPolicyCallback(OnPolicyCallback):
    """Callback for managing on-policy training in an RLHF loop.

    Ideally, all the overridden methods below would be implemented in the
    trainer actor instead of the callback; we keep them here for now to avoid
    a drastic refactor of the PPO callback code.
    """

    def iteration_start(self, state: State, logger: Logger):
        del logger  # unused

        self._get_reward(self.batch_rollouts)  # type: ignore

        # Reset and initialize state train dataloader
        log.warning(
            'trainer._train_data_spec should be updated whenever the dataloader is updated',
        )
        # Train Dataloader
        state.set_dataloader(self.buffer, 'ep')
        state.train_dataloader = state.dataloader
        state.device_train_microbatch_size = _get_initial_device_train_microbatch_size(
            state.device_train_microbatch_size,
            state.auto_microbatching,
            state.train_dataloader,
        )

        # Update IFT KL
        self._update_ift_kl()
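The override moves rollout and reward work to the iteration boundary so an external controller can stage rollouts onto the callback between iterations. A hypothetical wiring sketch — the callback's constructor arguments and the full Trainer config are not shown in this PR, so the elided pieces are placeholders:

```python
from composer import Trainer

callback = SingleControllerOnPolicyCallback()  # real config omitted

trainer = Trainer(
    model=policy_model,  # placeholder: a ComposerHFPolicyLM / ComposerMPTPolicyLM
    callbacks=[callback],
)

# Per iteration, a single controller would then:
#   1. run generation on the vLLM engines,
#   2. stash the rollouts on the callback (read back here as self.batch_rollouts),
#   3. advance training, so iteration_start computes rewards and installs the
#      replay buffer as the train dataloader.
```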
2 changes: 2 additions & 0 deletions tests/common/__init__.py
@@ -1,6 +1,7 @@
 # Copyright 2024 MosaicML ComposeRL authors
 # SPDX-License-Identifier: Apache-2.0
 
+from tests.common.actor import BaseDistributedGPUActor
 from tests.common.datasets import (
     FineGrainedPreference,
     PairwisePreference,
@@ -11,6 +12,7 @@
 from tests.common.markers import device, world_size
 
 __all__ = [
+    'BaseDistributedGPUActor',
     'PairwisePreference',
     'FineGrainedPreference',
     'PromptDataset',
101 changes: 101 additions & 0 deletions tests/common/actor.py
@@ -0,0 +1,101 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

import os
from datetime import timedelta
from typing import Optional

import ray
import torch.distributed as dist

from compose_rl.algorithms.online.generation_utils import init_process_group
from compose_rl.utils.ray_utils import (
    get_free_port,
    get_node_ip,
    is_cuda_visible_devices_set,
)


class BaseDistributedGPUActor:

    def __init__(
        self,
        rank: int,
        world_size: int,
        master_addr: Optional[str] = None,
        master_port: Optional[int] = None,
    ):
        """Initialize the distributed GPU actor for Ray.

        Args:
            rank: The rank of this process in the distributed group
            world_size: Total number of processes in the distributed group
            master_addr: Master node address. If None, it will be allocated dynamically for rank 0
            master_port: Master node port. If None, it will be allocated dynamically for rank 0
        """
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port

        # Set up basic environment variables
        # TODO: may need to handle LOCAL_WORLD_SIZE as used in callback.py
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)

        # Set LOCAL_RANK based on Ray GPU allocation
        os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set(
        ) else str(ray.get_gpu_ids()[0])

        # If this is rank 0 and no master_addr/master_port provided, allocate them
        if rank == 0 and (master_addr is None or master_port is None):
            self._allocate_master_address()

        os.environ['MASTER_ADDR'] = self.master_addr  # type: ignore
        os.environ['MASTER_PORT'] = str(self.master_port)  # type: ignore

        self.model = None
        self.model_update_group = None

    def _allocate_master_address(self):
        """Allocate master address and port for rank 0."""
        if self.master_addr is None:
            # Get the local IP address
            self.master_addr = get_node_ip()

        if self.master_port is None:
            # Allocate a free port
            self.master_port = get_free_port()

    def get_master_address(self) -> tuple[Optional[str], Optional[int]]:
        """Return the master address and port as a tuple."""
        return (self.master_addr, self.master_port)

    def get_free_port(self):
        return get_free_port()

    def init_train_process_group(self):
        """Initialize the distributed process group."""
        # Initialize process group
        dist.init_process_group(timeout=timedelta(seconds=30))

    def add_process_group(
        self,
        backend: str,
        master_addr: str,
        master_port: int,
        world_size: int,
        rank: int,
        group_name: str,
    ):
        """Initialize the process group on trainer rank 0 and vLLM engines."""
        # NOTE: vLLM seems to have a safer implementation of init_process_group:
        # https://github.com/vllm-project/vllm/blob/v0.9.1/examples/offline_inference/rlhf.py#L105
        # we should look into using that instead
        self.model_update_group = init_process_group(
            backend=backend,
            init_method=f'tcp://{master_addr}:{master_port}',
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )
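A minimal usage sketch under the assumption of a Ray cluster with at least two GPUs: rank 0 allocates the rendezvous address and port in its constructor, and the remaining ranks join with them before all ranks enter the process group together.

```python
import ray

from tests.common import BaseDistributedGPUActor

ray.init()
RemoteActor = ray.remote(num_gpus=1)(BaseDistributedGPUActor)

world_size = 2
# Rank 0 allocates MASTER_ADDR / MASTER_PORT dynamically in its constructor.
rank0 = RemoteActor.remote(rank=0, world_size=world_size)
master_addr, master_port = ray.get(rank0.get_master_address.remote())

# Remaining ranks join using rank 0's rendezvous address.
others = [
    RemoteActor.remote(
        rank=r,
        world_size=world_size,
        master_addr=master_addr,
        master_port=master_port,
    ) for r in range(1, world_size)
]

# All ranks must call init_process_group collectively.
ray.get([
    actor.init_train_process_group.remote() for actor in [rank0, *others]
])
```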
2 changes: 2 additions & 0 deletions tests/common/datasets.py
@@ -71,6 +71,7 @@ def __getitem__(self, index: int):
         return {
             'prompt': torch.ones((self.prompt_len,)).int(),
             'prompt_len': torch.Tensor([self.prompt_len]).to(torch.int64),
+            'prompt_id': torch.Tensor([index]).int(),
         }
 
 
@@ -87,6 +88,7 @@ def __getitem__(self, index: int):
         return {
             'prompt': torch.ones((self.prompt_len,)).int(),
             'prompt_len': torch.Tensor([self.prompt_len]).to(torch.int64),
+            'prompt_id': torch.Tensor([index]).int(),
             'verified_answer': '1',
         }
 