Skip to content
Closed
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,7 @@ notebooks/

# testing assets
**/tests/assets/*

# ides
.vscode/
.cursor/
5 changes: 3 additions & 2 deletions compose_rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
'When installing plugins, please use one of the extras depending on which version of llmfoundry you are using.',
)

from compose_rl import algorithms, data, metrics, utils
from compose_rl import algorithms, controllers, data, metrics, utils

__all__ = [
'algorithms',
'utils',
'controllers',
'data',
'metrics',
'utils',
]
56 changes: 2 additions & 54 deletions compose_rl/algorithms/online/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
add_right_padding,
compute_advantages,
dist_compute_masked_mean_and_var,
flatten,
get_decoded_sequence,
get_entropies,
get_log_probs,
Expand All @@ -64,6 +63,7 @@
masked_sum,
switch_left_to_right_padding,
)
from compose_rl.algorithms.online.callback_utils import preprocess_batches

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM]
Expand Down Expand Up @@ -666,59 +666,7 @@ def _get_next_iter_prompts(self):
self._get_single_batch_prompts() for _ in range(n_unique_batches)
]

ret_batch = {}
for key in batches[0].keys():
curr_values = []

max_len = 0
if isinstance(batches[0][key], torch.Tensor):
max_len = max([batch[key].shape[-1] for batch in batches])

padding_key = None
for batch in batches:
# Explode the batch into multiple batches for each generation
for _ in range(self.generations_per_prompt):
# For keys that do not require additional processing
if key in [
'prompt_len',
'verified_answer',
'prompt_id',
'vstar',
'messages',
]:
curr_values.append(batch[key])
continue

bs, seq_len = batch[key].shape

if key == 'prompt':
padding_key = self.pad_token_idx
if (batch[key][:, -1] == padding_key).any():
raise ValueError(
'The last token in the prompt should not be the pad token. Please double '
+
'check the dataloader and prompt and dataloader.',
)
elif key == 'prompt_attention_mask':
padding_key = False

# Compute the required padding and concatenate with the batch tensor
pad = torch.ones(
(bs, max_len - seq_len),
dtype=batch[key].dtype,
) * padding_key # type: ignore
curr_values.append(torch.cat([pad, batch[key]], dim=-1))

# For tensor fields, use torch.cat to combine the values; for string fields, just use the list
if isinstance(curr_values[0], torch.Tensor):
ret_batch[key] = torch.cat(curr_values)
else:
if key in ['verified_answer', 'vstar']:
ret_batch[key] = list(flatten(curr_values))
else:
ret_batch[key] = curr_values

return ret_batch
return preprocess_batches(batches, self.generations_per_prompt, self.pad_token_idx)

def _get_single_batch_prompts(self):
"""Gets a single batch of prompts from the dataloader."""
Expand Down
60 changes: 60 additions & 0 deletions compose_rl/algorithms/online/callback_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

import torch
from compose_rl.utils import flatten

def preprocess_batches(
    batches: list[dict],
    generations_per_prompt: int,
    pad_token_idx: int,
) -> dict:
    """Collate a list of prompt batches into a single left-padded batch.

    Every batch is repeated ``generations_per_prompt`` times (one copy per
    generation) and tensor fields are left-padded to the longest sequence
    length observed across all batches before being concatenated.

    Args:
        batches: Batch dicts from the prompt dataloader; all dicts are
            expected to share the same keys.
        generations_per_prompt: Number of generations sampled per prompt;
            each batch is duplicated this many times.
        pad_token_idx: Token id used to left-pad the ``'prompt'`` field.

    Returns:
        A single batch dict: tensor fields concatenated along the batch
        dimension; non-tensor fields collected into lists (and, for
        ``'verified_answer'`` and ``'vstar'``, flattened).

    Raises:
        ValueError: If any prompt ends with the pad token, which indicates
            the prompts were right-padded by the dataloader.
    """
    # Keys carried through as-is (no padding required).
    passthrough_keys = (
        'prompt_len',
        'verified_answer',
        'prompt_id',
        'vstar',
        'messages',
    )

    ret_batch = {}
    for key in batches[0].keys():
        curr_values = []

        # Tensor fields must all be padded to the longest sequence length.
        max_len = 0
        if isinstance(batches[0][key], torch.Tensor):
            max_len = max(batch[key].shape[-1] for batch in batches)

        padding_key = None
        for batch in batches:
            # Explode the batch into multiple batches for each generation.
            for _ in range(generations_per_prompt):
                # For keys that do not require additional processing.
                if key in passthrough_keys:
                    curr_values.append(batch[key])
                    continue

                bs, seq_len = batch[key].shape

                if key == 'prompt':
                    padding_key = pad_token_idx
                    # Prompts must be left-padded; a trailing pad token means
                    # the dataloader right-padded them instead.
                    if (batch[key][:, -1] == padding_key).any():
                        raise ValueError(
                            'The last token in the prompt should not be the pad token. '
                            + 'Please double check the prompt and the dataloader.',
                        )
                elif key == 'prompt_attention_mask':
                    padding_key = False

                # Compute the required left padding and prepend it to the
                # batch tensor.
                pad = torch.ones(
                    (bs, max_len - seq_len),
                    dtype=batch[key].dtype,
                ) * padding_key  # type: ignore
                curr_values.append(torch.cat([pad, batch[key]], dim=-1))

        # Tensor fields are concatenated; everything else stays a list.
        if isinstance(curr_values[0], torch.Tensor):
            ret_batch[key] = torch.cat(curr_values)
        else:
            if key in ('verified_answer', 'vstar'):
                # Flatten nested per-batch values into one flat list.
                ret_batch[key] = list(flatten(curr_values))
            else:
                ret_batch[key] = curr_values

    return ret_batch
2 changes: 2 additions & 0 deletions compose_rl/algorithms/online/generation_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from compose_rl.algorithms.online.generation_utils.generation_utils import (
hf_generate,
vllm_generate,
_vllm_generate,
)
from compose_rl.algorithms.online.generation_utils.vllm_utils import (
broadcast_to_vllm,
Expand All @@ -17,4 +18,5 @@
'init_process_group',
'hf_generate',
'vllm_generate',
'_vllm_generate',
]
6 changes: 6 additions & 0 deletions compose_rl/algorithms/online/single_controller_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ def iteration_start(self, state: State, logger: Logger):

# Update IFT KL
self._update_ift_kl()

def iteration_end(self, state: State, logger: Logger):
    """Run end-of-iteration bookkeeping for the RL loop.

    Logs this iteration's generations, advances the RL iteration counter,
    and resets the buffer (``Buffer.reset`` — defined elsewhere; presumably
    clears buffered rollouts for the next iteration — TODO confirm).

    Args:
        state: The composer training ``State``.
        logger: Unused; generation logging goes through the helper below.
    """
    del logger  # unused
    self._log_generations_to_logger(state)
    self._increment_rl_iter()
    self.buffer.reset()
4 changes: 4 additions & 0 deletions compose_rl/controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from compose_rl.controllers.actor import BaseDistributedGPUActor, SPMDActorGroup
from compose_rl.controllers.buffer import Buffer

__all__ = ['BaseDistributedGPUActor', 'Buffer', 'SPMDActorGroup']
180 changes: 180 additions & 0 deletions compose_rl/controllers/actor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

import os
from datetime import timedelta
from typing import Any, Callable, Optional

import ray
import torch.distributed as dist

from compose_rl.algorithms.online.generation_utils import init_process_group
from compose_rl.utils.ray_utils import (
get_free_port,
get_node_ip,
is_cuda_visible_devices_set_by_ray,
)


class BaseDistributedGPUActor:
    """Base class for a Ray worker that participates in torch.distributed.

    Each instance runs inside its own Ray actor process and sets the
    environment variables (``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``RANK``,
    ``LOCAL_RANK``, ``MASTER_ADDR``, ``MASTER_PORT``) that
    ``torch.distributed`` reads during process-group initialization.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        master_addr: Optional[str] = None,
        master_port: Optional[int] = None,
    ):
        """Initialize the distributed GPU actor for RAY.

        Args:
            rank: The rank of this process in the distributed group
            world_size: Total number of processes in the distributed group
            master_addr: Master node address. If None, will allocate dynamically for rank 0
            master_port: Master node port. If None, will allocate dynamically for rank 0
        """
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port

        # Set up basic environment variables
        os.environ['WORLD_SIZE'] = str(world_size)
        # FIXME: handle LOCAL_WORLD_SIZE for multiple nodes
        os.environ['LOCAL_WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)

        # Set LOCAL_RANK based on Ray GPU allocation
        # ray.get_gpu_ids() is empty if ray is not used.
        # NOTE(review): assumes that when Ray sets CUDA_VISIBLE_DEVICES the
        # actor sees exactly one device (local index 0); otherwise the raw
        # GPU id Ray assigned is used — confirm against ray_utils.
        if len(ray.get_gpu_ids()) > 0:
            os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set_by_ray(
            ) else str(ray.get_gpu_ids()[0])

        # If this is rank 0 and no master_addr/master_port provided, allocate them
        if rank == 0 and (master_addr is None or master_port is None):
            self._allocate_master_address()

        os.environ['MASTER_ADDR'] = self.master_addr # type: ignore
        os.environ['MASTER_PORT'] = str(self.master_port) # type: ignore

        # Populated later by subclasses / add_process_group.
        self.model = None
        self.model_update_group = None

    def _allocate_master_address(self):
        """Allocate master address and port for rank 0."""
        if self.master_addr is None:
            # Get the local IP address
            self.master_addr = get_node_ip()

        if self.master_port is None:
            # Allocate a free port
            self.master_port = get_free_port()

    def get_master_address(self) -> tuple[Optional[str], Optional[int]]:
        """Return the master address and port as a tuple."""
        return (self.master_addr, self.master_port)

    def get_free_port(self):
        """Return a free port via ``ray_utils.get_free_port`` (remote-callable)."""
        return get_free_port()

    def init_train_process_group(self):
        """Initialize the distributed process group."""
        # Initialize process group
        # Relies on the env-var rendezvous configured in __init__
        # (MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE). The short timeout makes
        # mis-wired rendezvous fail fast instead of hanging.
        dist.init_process_group(timeout=timedelta(seconds=30))

    def add_process_group(
        self,
        backend: str,
        master_addr: str,
        master_port: int,
        world_size: int,
        rank: int,
        group_name: str,
    ):
        """Initialize the process group on trainer rank 0 and vllm engines."""
        # NOTE vLLM seems to have a safer implementation of init_process_group:
        # https://github.com/vllm-project/vllm/blob/v0.9.1/examples/offline_inference/rlhf.py#L105
        # we should look into using that instead
        self.model_update_group = init_process_group(
            backend=backend,
            init_method=f'tcp://{master_addr}:{master_port}',
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )

    def execute(self, func: Callable[['BaseDistributedGPUActor'], Any]):
        """Dispatch a serializable function to this actor."""
        return func(self)


class SPMDActorGroup:
    """Manager for a group of SPMD (single-program, multiple-data) actors.

    Spawns ``num_train_actors`` Ray actors of ``actor_class``. The rank-0
    (master) actor is created first so that it can dynamically allocate the
    rendezvous address/port, which are then handed to every other rank.
    """

    def __init__(
        self,
        num_train_actors: int,
        actor_class: type[BaseDistributedGPUActor],
        num_gpus_per_actor: int = 1,
    ):
        """Create and initialize all training actors.

        Args:
            num_train_actors: Total number of actors (the world size).
            actor_class: The actor implementation to instantiate remotely.
            num_gpus_per_actor: Number of GPUs Ray reserves per actor.
        """
        self.num_train_actors = num_train_actors
        self._train_actors: list[BaseDistributedGPUActor] = []

        print('\n=== STARTING DISTRIBUTED TRAINING WITH RAY ACTORS ===')

        remote_actor_class = ray.remote(num_gpus=num_gpus_per_actor)(actor_class)

        # Create master actor first; it allocates the master address/port.
        self._master_actor = remote_actor_class.remote(
            0,
            self.num_train_actors,
        )
        self._train_actors.append(self._master_actor)

        # Get master address from rank 0 actor
        master_addr, master_port = ray.get(
            self._master_actor.get_master_address.remote(),  # type: ignore
        )
        print(f'Master address allocated: {master_addr}:{master_port}')

        # Create remaining actors with the shared master address/port.
        for rank in range(1, self.num_train_actors):
            actor = remote_actor_class.remote(
                rank,
                self.num_train_actors,
                master_addr,  # type: ignore
                master_port,
            )
            self._train_actors.append(actor)

    @property
    def train_actors(self):
        """All actor handles in the group, ordered by rank (rank 0 first)."""
        return self._train_actors

    @property
    def master_actor(self):
        """The rank-0 actor handle, which owns the master address/port."""
        return self._master_actor

    @property
    def collective_methods(self):
        """Proxy that calls a method on every actor and gathers the results."""
        return _ActorMethodProxy(self)


class _ActorMethodProxy:
    """Proxy that fans a single method call out to every actor in a group."""

    def __init__(self, actor_group: SPMDActorGroup):
        self._actor_group = actor_group

    def __getattr__(self, name: str):
        """Return a callable that invokes ``name`` on all actors and blocks
        until every remote result is available."""
        group = self._actor_group
        if not hasattr(group.master_actor, name):
            raise AttributeError(
                f"Method '{name}' not found on actor class: {group.master_actor.__class__}"
            )

        def fan_out(*args: Any, **kwargs: Any):
            # All actors share one class, so the validated method exists on
            # each handle; dispatch remotely and gather the results.
            pending = [
                getattr(worker, name).remote(*args, **kwargs)
                for worker in group.train_actors
            ]
            return ray.get(pending)

        return fan_out
14 changes: 14 additions & 0 deletions compose_rl/controllers/buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any

class Buffer:
    """Placeholder buffer for asynchronous RL (operations not yet implemented)."""

    def __init__(self, buffer_size: int = 1):
        # Configured size; presumably the buffer capacity — TODO confirm.
        self.buffer_size = buffer_size
        # Storage starts empty.
        self.buffer: list[dict[str, Any]] = []

    def put(self, struct: dict[str, Any]):
        """Insert a structure into the buffer (not yet implemented)."""
        raise NotImplementedError

    def get(self, struct: dict[str, Any]):
        """Retrieve a structure from the buffer (not yet implemented)."""
        raise NotImplementedError
2 changes: 2 additions & 0 deletions compose_rl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
split_text_to_sentences,
split_text_to_subsentences,
switch_left_to_right_padding,
print_batch_shapes,
)

__all__ = [
Expand Down Expand Up @@ -101,4 +102,5 @@
'prepare_math_prompt',
'remove_boxed',
'ray_utils',
'print_batch_shapes',
]
Loading
Loading