Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,7 @@ notebooks/

# testing assets
**/tests/assets/*

# ides
.vscode/
.cursor/
5 changes: 3 additions & 2 deletions compose_rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
'When installing plugins, please use one of the extras depending on which version of llmfoundry you are using.',
)

from compose_rl import algorithms, data, metrics, utils
from compose_rl import algorithms, controllers, data, metrics, utils

__all__ = [
'algorithms',
'utils',
'controllers',
'data',
'metrics',
'utils',
]
56 changes: 2 additions & 54 deletions compose_rl/algorithms/online/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
add_right_padding,
compute_advantages,
dist_compute_masked_mean_and_var,
flatten,
get_decoded_sequence,
get_entropies,
get_log_probs,
Expand All @@ -64,6 +63,7 @@
masked_sum,
switch_left_to_right_padding,
)
from compose_rl.algorithms.online.callback_utils import preprocess_batches

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM]
Expand Down Expand Up @@ -666,59 +666,7 @@ def _get_next_iter_prompts(self):
self._get_single_batch_prompts() for _ in range(n_unique_batches)
]

ret_batch = {}
for key in batches[0].keys():
curr_values = []

max_len = 0
if isinstance(batches[0][key], torch.Tensor):
max_len = max([batch[key].shape[-1] for batch in batches])

padding_key = None
for batch in batches:
# Explode the batch into multiple batches for each generation
for _ in range(self.generations_per_prompt):
# For keys that do not require additional processing
if key in [
'prompt_len',
'verified_answer',
'prompt_id',
'vstar',
'messages',
]:
curr_values.append(batch[key])
continue

bs, seq_len = batch[key].shape

if key == 'prompt':
padding_key = self.pad_token_idx
if (batch[key][:, -1] == padding_key).any():
raise ValueError(
'The last token in the prompt should not be the pad token. Please double '
+
'check the dataloader and prompt and dataloader.',
)
elif key == 'prompt_attention_mask':
padding_key = False

# Compute the required padding and concatenate with the batch tensor
pad = torch.ones(
(bs, max_len - seq_len),
dtype=batch[key].dtype,
) * padding_key # type: ignore
curr_values.append(torch.cat([pad, batch[key]], dim=-1))

# For tensor fields, use torch.cat to combine the values; for string fields, just use the list
if isinstance(curr_values[0], torch.Tensor):
ret_batch[key] = torch.cat(curr_values)
else:
if key in ['verified_answer', 'vstar']:
ret_batch[key] = list(flatten(curr_values))
else:
ret_batch[key] = curr_values

return ret_batch
return preprocess_batches(batches, self.generations_per_prompt, self.pad_token_idx)

def _get_single_batch_prompts(self):
"""Gets a single batch of prompts from the dataloader."""
Expand Down
60 changes: 60 additions & 0 deletions compose_rl/algorithms/online/callback_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

import torch
from compose_rl.utils import flatten

def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx: int):
    """Collate a list of prompt batches into a single left-padded batch.

    Each input batch is repeated ``generations_per_prompt`` times so that
    every prompt yields that many generations downstream. Tensor fields are
    left-padded to the longest sequence length seen across ``batches`` and
    concatenated along the batch dimension; non-tensor fields are collected
    into lists (``verified_answer``/``vstar`` are additionally flattened).

    Args:
        batches: Batches sharing the same keys. Tensor values are assumed to
            be left-padded 2-D ``(batch, seq_len)`` tensors — TODO confirm
            against the prompt dataloader.
        generations_per_prompt: Number of generations to sample per prompt.
        pad_token_idx: Token id used to left-pad ``prompt`` tensors.

    Returns:
        A single batch dict with tensors concatenated along the batch dim.

    Raises:
        ValueError: If any prompt ends in the pad token, which indicates a
            right-padded (malformed) prompt from the dataloader.
    """
    ret_batch = {}
    for key in batches[0].keys():
        curr_values = []

        # Longest sequence length across batches for this tensor field.
        max_len = 0
        if isinstance(batches[0][key], torch.Tensor):
            max_len = max(batch[key].shape[-1] for batch in batches)

        padding_key = None
        for batch in batches:
            # Explode the batch into multiple batches for each generation
            for _ in range(generations_per_prompt):
                # For keys that do not require additional processing
                if key in (
                    'prompt_len',
                    'verified_answer',
                    'prompt_id',
                    'vstar',
                    'messages',
                ):
                    curr_values.append(batch[key])
                    continue

                bs, seq_len = batch[key].shape

                if key == 'prompt':
                    padding_key = pad_token_idx
                    if (batch[key][:, -1] == padding_key).any():
                        raise ValueError(
                            'The last token in the prompt should not be the pad token. '
                            'Please double check the prompt and the dataloader.',
                        )
                elif key == 'prompt_attention_mask':
                    # Padded positions are masked out (False).
                    padding_key = False

                # Compute the required padding and concatenate with the batch tensor
                pad = torch.ones(
                    (bs, max_len - seq_len),
                    dtype=batch[key].dtype,
                ) * padding_key  # type: ignore
                curr_values.append(torch.cat([pad, batch[key]], dim=-1))

        # For tensor fields, use torch.cat to combine the values; for string
        # fields, just use the list.
        if isinstance(curr_values[0], torch.Tensor):
            ret_batch[key] = torch.cat(curr_values)
        else:
            if key in ('verified_answer', 'vstar'):
                ret_batch[key] = list(flatten(curr_values))
            else:
                ret_batch[key] = curr_values

    return ret_batch
2 changes: 2 additions & 0 deletions compose_rl/algorithms/online/generation_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from compose_rl.algorithms.online.generation_utils.generation_utils import (
hf_generate,
vllm_generate,
_vllm_generate,
)
from compose_rl.algorithms.online.generation_utils.vllm_utils import (
broadcast_to_vllm,
Expand All @@ -17,4 +18,5 @@
'init_process_group',
'hf_generate',
'vllm_generate',
'_vllm_generate',
]
6 changes: 6 additions & 0 deletions compose_rl/algorithms/online/single_controller_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ def iteration_start(self, state: State, logger: Logger):

# Update IFT KL
self._update_ift_kl()

def iteration_end(self, state: State, logger: Logger):
    """Run end-of-iteration bookkeeping for the single-controller RL loop.

    Logs the iteration's generations, advances the RL iteration counter,
    and resets the rollout buffer.
    """
    del logger # unused
    self._log_generations_to_logger(state)  # emit this iteration's generations to the loggers attached to state
    self._increment_rl_iter()  # advance the RL iteration counter
    # NOTE(review): presumably drops buffered rollouts so the next
    # iteration starts from an empty buffer — confirm Buffer.reset semantics.
    self.buffer.reset()
4 changes: 4 additions & 0 deletions compose_rl/controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from compose_rl.controllers.actor import BaseDistributedGPUActor, SPMDActorGroup
from compose_rl.controllers.buffer import Buffer

__all__ = ['BaseDistributedGPUActor', 'Buffer', 'SPMDActorGroup']
81 changes: 79 additions & 2 deletions tests/common/actor.py → compose_rl/controllers/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os
from datetime import timedelta
from typing import Optional
from typing import Any, Callable, Optional

import ray
import torch.distributed as dist
Expand Down Expand Up @@ -39,8 +39,9 @@ def __init__(
self.master_port = master_port

# Set up basic environment variables
# TODO: may need to handle LOCAL_WORLD_SIZE as used in callback.py
os.environ['WORLD_SIZE'] = str(world_size)
# FIXME: handle LOCAL_WORLD_SIZE for multiple nodes
os.environ['LOCAL_WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(rank)

# Set LOCAL_RANK based on Ray GPU allocation
Expand Down Expand Up @@ -99,3 +100,79 @@ def add_process_group(
rank=rank,
group_name=group_name,
)

def execute(self, func: Callable[['BaseDistributedGPUActor'], Any]):
    """Dispatch a serializable function to this actor.

    Args:
        func: Callable invoked with this actor as its only argument. It
            must be serializable, since it is typically shipped to the
            actor process via Ray.

    Returns:
        Whatever ``func`` returns when called with this actor.
    """
    return func(self)


class SPMDActorGroup:
    """Manages a group of SPMD (single-program, multiple-data) Ray actors.

    Rank 0 is created first so the rendezvous address/port it allocates can
    be handed to the remaining ranks, mirroring a torch.distributed-style
    bootstrap.
    """

    def __init__(
        self,
        num_train_actors: int,
        actor_class: type[BaseDistributedGPUActor],
        num_gpus_per_actor: int = 1,
    ):
        """Create and initialize all training actors.

        Args:
            num_train_actors: Total number of actors (the world size).
            actor_class: Actor implementation; wrapped with ``ray.remote``.
            num_gpus_per_actor: GPUs reserved for each actor.
        """
        self.num_train_actors = num_train_actors
        self._train_actors: list[BaseDistributedGPUActor] = []
        print('\n=== STARTING DISTRIBUTED TRAINING WITH RAY ACTORS ===')

        remote_actor_class = ray.remote(num_gpus=num_gpus_per_actor)(actor_class)

        # Create the rank-0 (master) actor first; it allocates the
        # address/port the other ranks rendezvous with.
        self._master_actor = remote_actor_class.remote(
            0,
            self.num_train_actors,
        )
        self._train_actors.append(self._master_actor)

        # Block until rank 0 reports the address it bound.
        master_addr, master_port = ray.get(
            self._master_actor.get_master_address.remote(),  # type: ignore
        )
        print(f'Master address allocated: {master_addr}:{master_port}')

        # Create remaining actors pointed at the master address/port.
        for rank in range(1, self.num_train_actors):
            actor = remote_actor_class.remote(
                rank,
                self.num_train_actors,
                master_addr,  # type: ignore
                master_port,
            )
            self._train_actors.append(actor)

    @property
    def train_actors(self):
        """All actor handles, master (rank 0) first."""
        return self._train_actors

    @property
    def master_actor(self):
        """Handle of the rank-0 (master) actor."""
        return self._master_actor

    @property
    def collective_methods(self):
        """Proxy whose attribute access fans a method call out to all actors."""
        return _ActorMethodProxy(self)


class _ActorMethodProxy:
    """Proxy that fans a single method call out to every actor in a group."""

    def __init__(self, actor_group: SPMDActorGroup):
        # Keep a handle to the group so lookups always see its current actors.
        self._actor_group = actor_group

    def __getattr__(self, name: str):
        """Return a callable that invokes ``name`` on all actors and gathers results."""
        group = self._actor_group
        if not hasattr(group.master_actor, name):
            raise AttributeError(
                f"Method '{name}' not found on actor class: {group.master_actor.__class__}"
            )

        def fan_out(*args: Any, **kwargs: Any):
            # Every actor is the same class, so the attribute validated on the
            # master exists on all handles; dispatch remotely to each actor
            # and block until every result is back.
            pending = [
                getattr(handle, name).remote(*args, **kwargs)
                for handle in group.train_actors
            ]
            return ray.get(pending)

        return fan_out
14 changes: 14 additions & 0 deletions compose_rl/controllers/buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any

class Buffer:
    """Placeholder class for Async RL"""

    def __init__(self, buffer_size: int = 1):
        """Initialize an empty buffer.

        Args:
            buffer_size: Intended capacity; unused until put/get are implemented.
        """
        self.buffer_size = buffer_size
        # Backing store stays empty until the async-RL implementation lands.
        self.buffer = []

    def put(self, struct: dict[str, Any]):
        """Add ``struct`` to the buffer. Not implemented yet."""
        raise NotImplementedError

    def get(self, struct: dict[str, Any]):
        """Retrieve an item matching ``struct``. Not implemented yet."""
        raise NotImplementedError
2 changes: 2 additions & 0 deletions compose_rl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
split_text_to_sentences,
split_text_to_subsentences,
switch_left_to_right_padding,
print_batch_shapes,
)

__all__ = [
Expand Down Expand Up @@ -101,4 +102,5 @@
'prepare_math_prompt',
'remove_boxed',
'ray_utils',
'print_batch_shapes',
]
Loading
Loading