diff --git a/.gitignore b/.gitignore index dd3b206d..c39bdb84 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,7 @@ notebooks/ # testing assets **/tests/assets/* + +# ides +.vscode/ +.cursor/ diff --git a/compose_rl/__init__.py b/compose_rl/__init__.py index aa1ce185..0db9b2ae 100644 --- a/compose_rl/__init__.py +++ b/compose_rl/__init__.py @@ -10,11 +10,12 @@ 'When installing plugins, please use one of the extras depending on which version of llmfoundry you are using.', ) -from compose_rl import algorithms, data, metrics, utils +from compose_rl import algorithms, controllers, data, metrics, utils __all__ = [ 'algorithms', - 'utils', + 'controllers', 'data', 'metrics', + 'utils', ] diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 41d96d59..8edef646 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -55,7 +55,6 @@ add_right_padding, compute_advantages, dist_compute_masked_mean_and_var, - flatten, get_decoded_sequence, get_entropies, get_log_probs, @@ -64,6 +63,7 @@ masked_sum, switch_left_to_right_padding, ) +from compose_rl.algorithms.online.callback_utils import preprocess_batches Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM] @@ -666,59 +666,7 @@ def _get_next_iter_prompts(self): self._get_single_batch_prompts() for _ in range(n_unique_batches) ] - ret_batch = {} - for key in batches[0].keys(): - curr_values = [] - - max_len = 0 - if isinstance(batches[0][key], torch.Tensor): - max_len = max([batch[key].shape[-1] for batch in batches]) - - padding_key = None - for batch in batches: - # Explode the batch into multiple batches for each generation - for _ in range(self.generations_per_prompt): - # For keys that do not require additional processing - if key in [ - 'prompt_len', - 'verified_answer', - 'prompt_id', - 'vstar', - 'messages', - ]: - curr_values.append(batch[key]) - continue - - bs, seq_len = batch[key].shape - - if key == 'prompt': - padding_key = self.pad_token_idx - if (batch[key][:, -1] == padding_key).any(): - raise ValueError( - 'The last token in the prompt should not be the pad token. 
Please double ' - + - 'check the dataloader and prompt and dataloader.', - ) - elif key == 'prompt_attention_mask': - padding_key = False - - # Compute the required padding and concatenate with the batch tensor - pad = torch.ones( - (bs, max_len - seq_len), - dtype=batch[key].dtype, - ) * padding_key # type: ignore - curr_values.append(torch.cat([pad, batch[key]], dim=-1)) - - # For tensor fields, use torch.cat to combine the values; for string fields, just use the list - if isinstance(curr_values[0], torch.Tensor): - ret_batch[key] = torch.cat(curr_values) - else: - if key in ['verified_answer', 'vstar']: - ret_batch[key] = list(flatten(curr_values)) - else: - ret_batch[key] = curr_values - - return ret_batch + return preprocess_batches(batches, self.generations_per_prompt, self.pad_token_idx) def _get_single_batch_prompts(self): """Gets a single batch of prompts from the dataloader.""" diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py new file mode 100644 index 00000000..982cccb9 --- /dev/null +++ b/compose_rl/algorithms/online/callback_utils.py @@ -0,0 +1,60 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +import torch +from compose_rl.utils import flatten + +def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx: int): + ret_batch = {} + for key in batches[0].keys(): + curr_values = [] + + max_len = 0 + if isinstance(batches[0][key], torch.Tensor): + max_len = max([batch[key].shape[-1] for batch in batches]) + + padding_key = None + for batch in batches: + # Explode the batch into multiple batches for each generation + for _ in range(generations_per_prompt): + # For keys that do not require additional processing + if key in [ + 'prompt_len', + 'verified_answer', + 'prompt_id', + 'vstar', + 'messages', + ]: + curr_values.append(batch[key]) + continue + + bs, seq_len = batch[key].shape + + if key == 'prompt': + padding_key = pad_token_idx + if (batch[key][:, -1] == padding_key).any(): + raise ValueError( + 'The last token in the prompt should not be the pad token. 
Please double ' +
+                            'check the dataloader and prompt.',
+                        )
+                elif key == 'prompt_attention_mask':
+                    padding_key = False
+
+                # Compute the required padding and concatenate with the batch tensor
+                pad = torch.ones(
+                    (bs, max_len - seq_len),
+                    dtype=batch[key].dtype,
+                ) * padding_key  # type: ignore
+                curr_values.append(torch.cat([pad, batch[key]], dim=-1))
+
+        # For tensor fields, use torch.cat to combine the values; for string fields, just use the list
+        if isinstance(curr_values[0], torch.Tensor):
+            ret_batch[key] = torch.cat(curr_values)
+        else:
+            if key in ['verified_answer', 'vstar']:
+                ret_batch[key] = list(flatten(curr_values))
+            else:
+                ret_batch[key] = curr_values
+
+    return ret_batch
diff --git a/compose_rl/algorithms/online/generation_utils/__init__.py b/compose_rl/algorithms/online/generation_utils/__init__.py
index 054ce72a..0b4e3891 100644
--- a/compose_rl/algorithms/online/generation_utils/__init__.py
+++ b/compose_rl/algorithms/online/generation_utils/__init__.py
@@ -4,6 +4,7 @@
 from compose_rl.algorithms.online.generation_utils.generation_utils import (
     hf_generate,
     vllm_generate,
+    _vllm_generate,
 )
 from compose_rl.algorithms.online.generation_utils.vllm_utils import (
     broadcast_to_vllm,
@@ -17,4 +18,5 @@
     'init_process_group',
     'hf_generate',
     'vllm_generate',
+    '_vllm_generate',
 ]
diff --git a/compose_rl/algorithms/online/single_controller_callback.py b/compose_rl/algorithms/online/single_controller_callback.py
index b5b4cd6b..5452feb6 100644
--- a/compose_rl/algorithms/online/single_controller_callback.py
+++ b/compose_rl/algorithms/online/single_controller_callback.py
@@ -56,3 +56,9 @@
 
         # Update IFT KL
         self._update_ift_kl()
+
+    def iteration_end(self, state: State, logger: Logger):
+        del logger  # unused
+        self._log_generations_to_logger(state)
+        self._increment_rl_iter()
+        self.buffer.reset()
diff --git a/compose_rl/controllers/__init__.py b/compose_rl/controllers/__init__.py
new file mode 100644
index 00000000..9bd150f4
--- /dev/null
+++ b/compose_rl/controllers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright 2024 MosaicML ComposeRL authors
+# SPDX-License-Identifier: Apache-2.0
+
+from compose_rl.controllers.actor import BaseDistributedGPUActor, SPMDActorGroup
+from compose_rl.controllers.buffer import Buffer
+
+__all__ = ['BaseDistributedGPUActor', 'Buffer', 'SPMDActorGroup']
diff --git a/compose_rl/controllers/actor.py b/compose_rl/controllers/actor.py
new file mode 100644
index 00000000..be883bf4
--- /dev/null
+++ b/compose_rl/controllers/actor.py
@@ -0,0 +1,180 @@
+# Copyright 2024 MosaicML ComposeRL authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from datetime import timedelta
+from typing import Any, Callable, Optional
+
+import ray
+import torch.distributed as dist
+
+from compose_rl.algorithms.online.generation_utils import init_process_group
+from compose_rl.utils.ray_utils import (
+    get_free_port,
+    get_node_ip,
+    is_cuda_visible_devices_set_by_ray,
+)
+
+
+class BaseDistributedGPUActor:
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        master_addr: Optional[str] = None,
+        master_port: Optional[int] = None,
+    ):
+        """Initialize the distributed GPU actor for Ray.
+
+        Args:
+            rank: The rank of this process in the distributed group
+            world_size: Total number of processes in the distributed group
+            master_addr: Master node address. If None, will allocate dynamically for rank 0
+            master_port: Master node port. 
If None, will allocate dynamically for rank 0 + """ + self.rank = rank + self.world_size = world_size + self.master_addr = master_addr + self.master_port = master_port + + # Set up basic environment variables + os.environ['WORLD_SIZE'] = str(world_size) + # FIXME: handle LOCAL_WORLD_SIZE for multiple nodes + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + + # Set LOCAL_RANK based on Ray GPU allocation + # ray.get_gpu_ids() is empty if ray is not used. + if len(ray.get_gpu_ids()) > 0: + os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set_by_ray( + ) else str(ray.get_gpu_ids()[0]) + + # If this is rank 0 and no master_addr/master_port provided, allocate them + if rank == 0 and (master_addr is None or master_port is None): + self._allocate_master_address() + + os.environ['MASTER_ADDR'] = self.master_addr # type: ignore + os.environ['MASTER_PORT'] = str(self.master_port) # type: ignore + + self.model = None + self.model_update_group = None + + def _allocate_master_address(self): + """Allocate master address and port for rank 0.""" + if self.master_addr is None: + # Get the local IP address + self.master_addr = get_node_ip() + + if self.master_port is None: + # Allocate a free port + self.master_port = get_free_port() + + def get_master_address(self) -> tuple[Optional[str], Optional[int]]: + """Return the master address and port as a tuple.""" + return (self.master_addr, self.master_port) + + def get_free_port(self): + return get_free_port() + + def init_train_process_group(self): + """Initialize the distributed process group.""" + # Initialize process group + dist.init_process_group(timeout=timedelta(seconds=30)) + + def add_process_group( + self, + backend: str, + master_addr: str, + master_port: int, + world_size: int, + rank: int, + group_name: str, + ): + """Initialize the process group on trainer rank 0 and vllm engines.""" + # NOTE vLLM seems to have a safer implementation of init_process_group: + # https://github.com/vllm-project/vllm/blob/v0.9.1/examples/offline_inference/rlhf.py#L105 + # we should look into using that instead + self.model_update_group = init_process_group( + backend=backend, + init_method=f'tcp://{master_addr}:{master_port}', + world_size=world_size, + rank=rank, + group_name=group_name, + ) + + def execute(self, func: Callable[['BaseDistributedGPUActor'], Any]): + """Dispatch a serializable function to this actor.""" + return func(self) + + +class SPMDActorGroup: + """Group managers of SPMD actors.""" + + def __init__(self, num_train_actors: int, actor_class: type[BaseDistributedGPUActor], num_gpus_per_actor: int = 1): + self.num_train_actors = num_train_actors + self._train_actors: list[BaseDistributedGPUActor] = [] + """Create and initialize all training actors.""" + print(f'\n=== STARTING DISTRIBUTED TRAINING WITH RAY ACTORS ===') + + remote_actor_class = ray.remote(num_gpus=num_gpus_per_actor)(actor_class) + # Create master actor first + self._master_actor = remote_actor_class.remote( + 0, + self.num_train_actors, + ) + self._train_actors.append(self._master_actor) + + # Get master address from rank 0 actor + master_addr, master_port = ray.get( + self._master_actor.get_master_address.remote(), # type: ignore + ) + print(f'Master address allocated: {master_addr}:{master_port}') + + # Create remaining actors with the master address/port + for i in range(1, self.num_train_actors): + actor = remote_actor_class.remote( + i, + self.num_train_actors, + master_addr, # type: ignore + master_port, + ) + 
self._train_actors.append(actor) + + @property + def train_actors(self): + return self._train_actors + + @property + def master_actor(self): + return self._master_actor + + @property + def collective_methods(self): + """Property that provides easy access to method references. + """ + return _ActorMethodProxy(self) + + +class _ActorMethodProxy: + """Proxy class that provides easy access to actor methods. + """ + + def __init__(self, actor_group: SPMDActorGroup): + self._actor_group = actor_group + + def __getattr__(self, name: str): + """Get a method reference that will be called on all actors.""" + if not hasattr(self._actor_group.master_actor, name): + raise AttributeError( + f"Method '{name}' not found on actor class: {self._actor_group.master_actor.__class__}" + ) + + # Return a callable that will execute the method on all actors + def method_wrapper(*args: Any, **kwargs: Any): + # Since all actors are the same class, we can get the same method from each actor + # and call it remotely. No validation needed since we validated above. + refs = [getattr(actor, name).remote(*args, **kwargs) for actor in self._actor_group.train_actors] + return ray.get(refs) + + return method_wrapper diff --git a/compose_rl/controllers/buffer.py b/compose_rl/controllers/buffer.py new file mode 100644 index 00000000..98afea02 --- /dev/null +++ b/compose_rl/controllers/buffer.py @@ -0,0 +1,14 @@ +from typing import Any + +class Buffer: + """Placeholder class for Async RL""" + + def __init__(self, buffer_size: int = 1): + self.buffer_size = buffer_size + self.buffer = [] + + def put(self, struct: dict[str, Any]): + raise NotImplementedError + + def get(self, struct: dict[str, Any]): + raise NotImplementedError diff --git a/compose_rl/utils/__init__.py b/compose_rl/utils/__init__.py index 57010e57..8f30bb9a 100644 --- a/compose_rl/utils/__init__.py +++ b/compose_rl/utils/__init__.py @@ -51,6 +51,7 @@ split_text_to_sentences, split_text_to_subsentences, switch_left_to_right_padding, + print_batch_shapes, ) __all__ = [ @@ -101,4 +102,5 @@ 'prepare_math_prompt', 'remove_boxed', 'ray_utils', + 'print_batch_shapes', ] diff --git a/compose_rl/utils/ray_utils.py b/compose_rl/utils/ray_utils.py index a5e9a8da..908a12d5 100644 --- a/compose_rl/utils/ray_utils.py +++ b/compose_rl/utils/ray_utils.py @@ -1,6 +1,7 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 +from asyncio import Event import logging import os import socket @@ -80,6 +81,68 @@ def init_ray_with_torch_distributed(timeout_seconds: int = 30): return address +@ray.remote +class _Barrier: + """A barrier for synchronizing between multiple ray clients. + + NOTE: There is no timeout for this barrier. + """ + def __init__(self, num_parties: int): + """ + Initializes the barrier for a given number of parties. + """ + self._num_parties = num_parties + self._num_parties_arrived = 0 + self._event = Event() + + async def wait(self): + """ + await is blocked until all parties have called this method. + """ + print(f'Rank {ray.get_runtime_context().get_actor_id()} is waiting for the barrier') + self._num_parties_arrived += 1 + if self._num_parties_arrived == self._num_parties: + # All parties have arrived, notify them to proceed. + print(f'Rank {ray.get_runtime_context().get_actor_id()} is proceeding') + self._event.set() + else: + # Wait for the event to be set by the last arriving party. + await self._event.wait() + + def reset(self): + """ + Resets the barrier for reuse. 
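+
+        NOTE: only call this after every party has returned from wait();
+        clearing the event early would leave late arrivals blocked forever.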
+ """ + self._num_parties_arrived = 0 + self._event.clear() + + +def _barrier(world_size: int, rank: int, name: str, namespace: str = '_synchronization'): + """ + A barrier for synchronizing between multiple ray clients with unlimited timeout. + + Args: + world_size (int): The number of parties to synchronize. + rank (int): The rank of the current process. + name (str): The agreed name of the barrier actor. + namespace (str): The agreed namespace of the barrier actor. + """ + if rank == 0: + # Create the barrier actor - Ray will handle any naming conflicts appropriately + barrier = _Barrier.options(name=name, namespace=namespace).remote(world_size) + else: + while True: + try: + barrier = ray.get_actor(name, namespace=namespace) + break + except ValueError: # Actor not found + time.sleep(1) # Retry after a short delay + ray.get(barrier.wait.remote()) + if rank == 0: + # close the ray actor + barrier.__ray_terminate__.remote() + + @contextmanager def start_ray_server(): """Context manager for Ray server in a torch distributed environment. @@ -112,9 +175,7 @@ def start_ray_server(): # NOTE we have to keep all the MCT orchestrator started processes alive with this barrier # until the ray cluster is stopped, otherwise the MCT orchestrator will reclaim the resources # once the processes on a node exit - # this may time out too quick for a real world run, if so we might need to reuse the original - # SyncActor based approach - dist.barrier() + _barrier(dist.get_world_size(), dist.get_rank(), 'mcloud_barrier') finally: if dist.get_rank() == 0: ray.shutdown() @@ -161,7 +222,7 @@ def get_free_port(): return sock.getsockname()[1] -def is_cuda_visible_devices_set(): +def is_cuda_visible_devices_set_by_ray(): """Check if CUDA_VISIBLE_DEVICES environment variable is being set by Ray. Ray can automatically set the CUDA_VISIBLE_DEVICES environment variable to @@ -171,7 +232,33 @@ def is_cuda_visible_devices_set(): Returns: bool: True if Ray is setting CUDA_VISIBLE_DEVICES, False otherwise """ - return os.environ.get( + return os.environ.get('CUDA_VISIBLE_DEVICES', None) is not None and os.environ.get( 'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES', '0', ) == '0' + +# TODO: Since this uninstallation deals specifically with ray, +# added the function here instead of the regular utils.py file +# We need to investigate this further after the hackathon since +# this is a super hacky solution to support CPU workers +def uninstall_megablocks_if_exists(): + """ + Megablocks exists on the ray workers but is not supported on CPU. + We need to uninstall it to avoid errors. + + Note: Installing `llm-foundry[all-cpu]` (which doesn't have megablocks) + on the StreamingDatasetActor worker through ray runtime options + doesn't seem to actually resolve this issue even though it's supposed + to set up a new environment... + TODO: Figure out why that's the case and if there's a better way to + resolve this issue. 
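+
+    This removes the installed package from the worker's environment via pip
+    and drops any already-imported `megablocks` module from sys.modules.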
+    """
+    import sys
+    import subprocess
+
+    # First uninstall megablocks package (if it exists)
+    command = [sys.executable, "-m", "pip", "uninstall", "megablocks", "-y"]
+    subprocess.run(command, check=False, capture_output=True, text=True)
+    # Then remove from sys.modules if present
+    if 'megablocks' in sys.modules:
+        del sys.modules['megablocks']
diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py
index b51b7fc2..bfd172a4 100644
--- a/compose_rl/utils/utils.py
+++ b/compose_rl/utils/utils.py
@@ -1237,3 +1237,17 @@ def flatten(coll: Union[Iterable[Any], str]) -> Generator[Any, None, None]:
             yield subc
     else:
         yield i
+
+# TODO: Remove this function after the hackathon
+def print_batch_shapes(batch: dict[str, Any]):
+    def get_shape(value: Any):
+        if isinstance(value, torch.Tensor):
+            return value.shape
+        elif isinstance(value, list):
+            return len(value)
+        else:
+            return f"{type(value)} isn't supported"
+    shape_dict = {
+        k: get_shape(v) for k, v in batch.items()
+    }
+    print(f'Batch shapes: {shape_dict}')
diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py
new file mode 100644
index 00000000..1c63fce6
--- /dev/null
+++ b/test_single_controller_ppo.py
@@ -0,0 +1,787 @@
+# Copyright 2024 MosaicML ComposeRL authors
+# SPDX-License-Identifier: Apache-2.0
+
+
+# NOTE: This actually runs GRPO instead of PPO.
+# Usage: copy this file into the root of the repo, then:
+#   cd compose-rl
+#   composer test_single_controller_ppo.py
+# If the job is killed with Ctrl+C, check `ray status` to see whether the
+# actors are still running; if they are, run `ray stop`.
+
+import argparse
+from contextlib import contextmanager
+import logging
+import os
+import time
+import datetime
+from functools import partial
+from typing import Any, Optional
+
+from composer.loggers import MLFlowLogger
+import ray
+import torch
+import torch.distributed as dist
+from composer import Trainer
+from composer.core import get_precision_context
+from composer.optim import DecoupledAdamW
+from composer.utils import dist as composer_dist
+from llmfoundry.data import build_dataloader
+from omegaconf import OmegaConf as om
+from transformers import AutoTokenizer
+from composer.callbacks import MemoryMonitor, SpeedMonitor, LRMonitor
+
+from compose_rl.algorithms.online import (
+    ComposerHFPolicyLM,
+    ComposerHFCriticFreePolicyLM,
+    SingleControllerOnPolicyCallback,
+)
+from compose_rl.algorithms.online.generation_utils import (
+    broadcast_to_vllm,
+    create_vllm_engines,
+    _vllm_generate,
+)
+from compose_rl.utils.ray_utils import start_ray_server, uninstall_megablocks_if_exists
+from compose_rl.controllers import BaseDistributedGPUActor, SPMDActorGroup
+from compose_rl.controllers.buffer import Buffer
+from compose_rl.algorithms.online.callback_utils import preprocess_batches
+
+GLOBAL_TRAIN_BATCH_SIZE = 64
+GENERATIONS_PER_PROMPT = 8
+NUM_BATCHES_PER_UPDATE = 8
+NUM_TRAIN_ITERATIONS = 5
+
+_MAX_SEQ_LEN = 6000
+_MAX_GEN_LEN = 4000
+
+
+@contextmanager
+def time_it(name: str):
+    start_time = time.time()
+    print(f"[{name}] started at {time.strftime('%X')}")
+    yield
+    end_time = time.time()
+    print(f"[{name}] finished at {time.strftime('%X')}")
+    print(f"[{name}] took {end_time - start_time:.2f} seconds")
+
+
+class DistributedGPUActor(BaseDistributedGPUActor):
+    """Distributed GPU actor for testing."""
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        master_addr: Optional[str] = None,
+        master_port: Optional[int] = None,
+    ):
+        super().__init__(rank, world_size, master_addr, master_port)
+
+        # 
Configure Ray actor logging - this will go to Ray logs + self.logger = logging.getLogger(f"Actor-{rank}") + self.logger.setLevel(logging.INFO) + + # Create console handler that will be captured by Ray + handler = logging.StreamHandler() + formatter = logging.Formatter(f'[ACTOR-{rank}] %(asctime)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + + self.model = None + self.model_update_group = None + self.ref_path = None + self._dataloader = None + self._tokenizer = None + self.ppo_callback = None + self.ppo_trainer: Trainer = None # type: ignore + + self.pretrain_model_name = None + self.device_train_batch_size = None + self.num_batches_per_update = None + self.max_seq_len = None + self.precision = None # type: ignore + self.train_config: dict = None # type: ignore + self.model_config = None + self.global_train_batch_size = None + self.max_gen_len = None + + def build_train_config(self, pretrain_model_name: str): + self.logger.info(f"Starting build_train_config with model: {pretrain_model_name}") + self.pretrain_model_name = pretrain_model_name + + self.model_config = { + 'tokenizer': self.tokenizer, + 'pretrained_model_name_or_path': self.pretrain_model_name, + 'pretrained': True, + 'use_flash_attention_2': True, + 'allow_embedding_resizing': True, + 'name': 'hf_critic_free_lm', + # 'init_device': 'mixed', + # This throws: [rank0]: ValueError: Detected mixed initialization where some ranks have model on cpu or gpu and some ranks are on meta. Either keep all ranks on the same device or set parallelism_config['fsdp']['sync_module_states'] = True. Otherwise, some weights may be randomly initialized when loading a checkpoint. + 'loss_type': 'grpo', + 'target_kl': 0.1, + 'kl_estimator': 'k3', + 'kl_clip_range': 40, + 'use_auth_token': True, + 'compute_kl_loss': False, + 'policy_clip_ratio': 0.2, + 'normalize_advantage': True, + 'length_normalize_policy_loss': True, + 'attn_implementation': 'flash_attention_2' + } + self.global_train_batch_size = GLOBAL_TRAIN_BATCH_SIZE + self.device_train_batch_size = self.global_train_batch_size // self.world_size + self.num_batches_per_update = NUM_BATCHES_PER_UPDATE + self.max_seq_len = _MAX_SEQ_LEN + self.max_gen_len = _MAX_GEN_LEN + self.precision = 'amp_bf16' + + ref_model_config = { + 'name': 'hf_causal_lm', + 'pretrained': self.model_config['pretrained'], + 'pretrained_model_name_or_path': self.pretrain_model_name, + 'use_auth_token': self.model_config['use_auth_token'], + 'use_flash_attention_2': self.model_config['use_flash_attention_2'], + } + + variables = { + 'gamma': 1, + 'lambda_gae': 1, + 'epoch_per_iteration': 1, + 'num_batches_per_update': self.num_batches_per_update, + 'generations_per_prompt': GENERATIONS_PER_PROMPT, + 'device_generate_batch_size': 1, + 'vllm_enable_prefix_caching': True, + 'generation_kwargs': { + 'top_p': 1.0, + 'use_cache': True, + 'do_sample': False, + 'temperature': 1.0, + }, + 'eos_token_ids': [ + 128001, + 128008, + 128009, + ], + 'buffer': { + 'name': 'MinibatchRolloutBuffer', + 'max_buffer_size': self.num_batches_per_update, + }, + 'max_gen_len': self.max_gen_len, + 'kl_controller': { + 'init_kl_coef': 0.0, # no KL penalty + 'kl_ctl_type': 'fixed', + }, + 'reference_model': { + 'model_config': ref_model_config, + 'precision': self.precision, + 'load_path': self.ref_path, + }, + 'non_train_fsdp_config': self.fsdp_config, + 'rewards': { + 'math_verifier': { + 'reward_type': 'math_verifier', + 'reward': 4, + }, + 'bad_generation_end': { + 'reward': -1, + 
'eos_penalty': True, + 'reward_type': 'bad_generation_end' + }, + 'math_format_verifier': { + 'reward': 1, + 'reward_type': 'math_format_verifier' + }, + 'penalize_extra_short_responses': { + 'reward': -1, + 'reward_type': 'short_response_reward', + 'len_threshold': 10 + }, + } + } + algorithm_config = { + 'gradient_clipping': { + 'clipping_type': 'norm', + 'clipping_threshold': 1.0 + } + } + self.train_config = { + 'seed': 17, + 'model': self.model_config, + 'fsdp_config': self.fsdp_config, + 'precision': self.precision, + 'variables': variables, + 'algorithms': algorithm_config, + 'global_train_batch_size': self.device_train_batch_size * self.world_size, + 'device_train_batch_size': self.device_train_batch_size, + 'device_train_microbatch_size': self.device_train_batch_size, + 'save_folder': './checkpoints/grpo_single_controller', + 'log_config': True, + 'max_seq_len': self.max_seq_len, + 'python_log_level': 'debug', + 'console_log_interval': '1ba', + } + self.logger.info("Finished build_train_config") + + def build_tokenizer(self): + # TODO (algo): decide if we should use tokens or messages given + # we may need token level log prob + # TODO (infra): use the tokenizer/texts for prompt dataloader but + # token (ids) for the experience buffer/manager + kwargs = { + 'padding': 'longest', + 'pad_token': '<|finetune_right_pad_id|>', + 'truncation': True, + 'padding_side': 'left', + 'model_max_length': self.max_seq_len, + 'trust_remote_code': True, + } + tokenizer = AutoTokenizer.from_pretrained(self.pretrain_model_name, **kwargs) + return tokenizer + + @property + def tokenizer(self): + if self._tokenizer is None: + self._tokenizer = self.build_tokenizer() + return self._tokenizer + + @property + def fsdp_config(self): + # TODO (infra): use actual fsdp1 config + return {} + + def init_composer_dist(self): + composer_dist.initialize_dist('gpu') + + def build_ppo_trainer(self): + name = self.model_config.pop('name') + + self.logger.info(f"Model type: {name}") + if name == 'hf_ppo_lm': + self.logger.info("Creating ComposerHFPolicyLM") + model = ComposerHFPolicyLM(**self.model_config) + elif name == 'hf_critic_free_lm': + self.logger.info("Creating ComposerHFCriticFreePolicyLM") + model = ComposerHFCriticFreePolicyLM(**self.model_config) + self.logger.info("Model created successfully") + + optimizer = DecoupledAdamW(model.parameters(), lr=1e-6) + + # TODO (infra): pull the rest of the training logic from the callback + # to this class, e.g, how to interact with env, calculate rewards etc + # NOTE: SingleControllerOnPolicyCallback is currently over-writing the iteration_start method + self.ppo_callback = SingleControllerOnPolicyCallback( + train_config=self.train_config, + ) + + # Create a dummy dataloader to make sure trainer can call .fit() with + # the dataloader that exists at ITERATION_START. This dataloader + # will NOT be used for training. 
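+        # The real training data is injected out-of-band: add_rollouts() hands each
+        # actor its rollout partition, so this dummy dataset is never consumed.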
+        dummy_dataset = torch.utils.data.TensorDataset(torch.randn(16, 1))
+        dummy_distributed_sampler = torch.utils.data.distributed.DistributedSampler(dummy_dataset)
+        dummy_dataloader = torch.utils.data.DataLoader(dummy_dataset, sampler=dummy_distributed_sampler)
+
+        mlflow_logger = MLFlowLogger(
+            experiment_name='test_single_controller_ppo',
+            run_name='test_single_controller_ppo',
+            tracking_uri='databricks',
+        )
+
+        self.ppo_trainer = Trainer(
+            model=model,
+            optimizers=optimizer,
+            callbacks=[
+                self.ppo_callback,
+                # callbacks for scheduled garbage collection
+                # this helps improve throughput by garbage collecting
+                # at regular intervals on all training processes
+                # ScheduledGarbageCollector(
+                #     batch_interval='1000',
+                # ),  # TODO: Add it back once we resolve the error caused by the dummy dataloader
+                # callbacks for monitoring other metrics
+                LRMonitor(),
+                MemoryMonitor(),
+                SpeedMonitor(window_size=10),
+            ],
+            train_dataloader=dummy_dataloader,
+            precision=self.precision,
+            parallelism_config={'fsdp': self.fsdp_config},
+            max_duration='5iter',
+            loggers=[mlflow_logger],
+            device_train_microbatch_size=1,
+            load_path=self.ref_path,
+        )
+
+    def close_trainer(self):
+        self.ppo_trainer.close()
+
+    def add_rollouts(self, current_rank_rollouts: dict[str, Any]):
+        """Adds the current rank's rollouts to the callback."""
+        for k, v in current_rank_rollouts.items():
+            assert isinstance(v, torch.Tensor) or isinstance(v, list), f"Expected a tensor or list, got {type(v)}"
+            if isinstance(v, torch.Tensor):
+                current_rank_rollouts[k] = v.to(torch.device('cuda'))
+        self.ppo_callback.batch_rollouts = current_rank_rollouts
+
+    def train_1_iter(self):
+        # TODO (algo): implement the top level PPO algo here instead of the
+        # callback. Algorithmic researchers are expected to implement this
+        # function along with the above policy/value/reward/ref trainers or models
+        # TODO (infra): try multiple fit() calls to see if the (mlflow) logger, etc. hold up
+        # TODO (infra): fault tolerance at iteration level first
+        # TODO (infra): enable batch level control
+
+        # NOTE: Trainer has a train-microbatches function that should be used here to get low level control.
+        # fit() checks whether there is an existing checkpoint, makes a full forward pass, and runs the eval and save passes.
+        
# We potentially want to run this https://github.com/mosaicml/composer/blob/dev/composer/trainer/trainer.py#L2826
+        # fit() can also potentially overwrite the mlflow
+        self.ppo_trainer.fit(duration='1iter')
+        self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}")
+
+
+def setup_process_groups(
+    master_actor: Any,
+    vllm_engines: list[Any],
+    vllm_tensor_parallel_size: int,
+):
+    """Initialize process groups for vLLM engines and master actor."""
+    # Get a new port for the weight-update process group
+    master_addr, _ = ray.get(
+        master_actor.get_master_address.remote(),
+    )  # type: ignore
+    new_port = ray.get(master_actor.get_free_port.remote())  # type: ignore
+    print(f'new_port: {new_port}')
+
+    world_size = dist.get_world_size()
+
+    # Initialize process groups for vLLM engines
+    refs = [
+        engine.init_process_group.remote(
+            master_addr,
+            new_port,
+            i * vllm_tensor_parallel_size + 1,
+            world_size // 2 + 1,
+            'weight-update',
+            backend='nccl',
+        ) for i, engine in enumerate(vllm_engines)
+    ]
+
+    # Add master actor to the process group
+    refs.append(
+        master_actor.add_process_group.remote(
+            backend='nccl',
+            master_addr=master_addr,
+            master_port=new_port,
+            world_size=world_size // 2 + 1,
+            rank=0,
+            group_name='weight-update',
+        ),
+    )
+
+    # Wait for all process groups to be initialized
+    print(ray.get(refs))
+
+
+class TrainActorGroup(SPMDActorGroup):
+    """Group of training actors for PPO."""
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+
+    def build_models(self, pretrain_model_name: str):
+        """Build reference models and PPO trainers for all actors."""
+        self.collective_methods.build_train_config(pretrain_model_name)
+        self.collective_methods.init_composer_dist()
+
+        # Build PPO trainers
+        self.collective_methods.build_ppo_trainer()
+        print('build ppo trainer done')
+
+    def _partition_rollouts_across_ranks(self, rollouts: dict[str, Any]):
+        """Partition the rollouts across all actors."""
+        partitioned_rollouts = []
+        per_rank_data_size = rollouts['prompt'].shape[0] // self.num_train_actors
+        for i in range(self.num_train_actors):
+            current_rank_start = i * per_rank_data_size
+            current_rank_end = (i + 1) * per_rank_data_size
+            current_rank_rollouts = {}
+            for k, v in rollouts.items():
+                assert isinstance(v, torch.Tensor) or isinstance(v, list), f"Expected a tensor or list, got {type(v)}"
+                current_rank_rollouts[k] = v[current_rank_start:current_rank_end]
+            partitioned_rollouts.append(current_rank_rollouts)
+        return partitioned_rollouts
+
+    def add_latest_rollouts_from_buffer(self, experience_buffer: "ExperienceBuffer"):
+        assert experience_buffer is not None, "Experience buffer is not set"
+        assert len(experience_buffer) > 0, "Experience buffer is empty"
+        latest_rollouts = experience_buffer.popleft()
+        partitioned_rollouts = self._partition_rollouts_across_ranks(latest_rollouts)
+        assert len(partitioned_rollouts) == self.num_train_actors, "Number of partitioned rollouts should be equal to the number of train actors"
+        ray.get([train_actor.add_rollouts.remote(partition) for train_actor, partition in zip(self.train_actors, partitioned_rollouts)])
+
+    def train_1_iter(self):
+        # This method times the collective training step; timing each rank
+        # individually makes the print/logging messy to read.
+        with time_it("training"):
+            self.collective_methods.train_1_iter()
+
+
+class InferenceServer:
+    """Inference server with vLLM engines."""
+
+    def __init__(self, num_vllm_engines: int, 
vllm_tensor_parallel_size: int, pretrain_model_name: str):
+        self.num_vllm_engines = num_vllm_engines
+        self.vllm_tensor_parallel_size = vllm_tensor_parallel_size
+        self.vllm_engines = create_vllm_engines(
+            num_engines=num_vllm_engines,
+            tensor_parallel_size=vllm_tensor_parallel_size,
+            enforce_eager=True,
+            pretrain=pretrain_model_name,
+            revision=None,
+            seed=1,
+            enable_prefix_caching=False,
+            max_model_len=_MAX_GEN_LEN,
+            device_bundle={
+                'GPU': 1,
+                'CPU': 1,
+                'worker_node': 0,
+            },
+        )
+
+    @property
+    def engines(self):
+        return self.vllm_engines
+
+
+class RolloutAgent:
+    """Rollout agent for generating sequences from the inference server."""
+
+    def __init__(
+        self,
+        inference_server: InferenceServer,
+        streaming_dataset_actor: "StreamingDatasetActor",
+    ):
+        self.inference_server = inference_server
+        self.streaming_dataset_actor = streaming_dataset_actor
+        self.generation_kwargs = {
+            'top_p': 1.0,
+            'use_cache': True,
+            'do_sample': False,
+            'temperature': 1.0,
+        }
+        self.precision = 'amp_bf16'
+        self.tokenizer_pad_token_id = ray.get(self.streaming_dataset_actor.get_tokenizer_pad_token_id.remote())
+        self.prompt_handler_config = ray.get(self.streaming_dataset_actor.get_prompt_handler_config.remote())
+        self.max_gen_len = self.prompt_handler_config['max_gen_len']
+
+    def get_next_iter_rollouts(self):
+        """
+        Gets the next rollouts from the inference server.
+
+        Since all ranks should see different data, we need to get the rollouts for each rank.
+        """
+        iter_data = ray.get(self.streaming_dataset_actor.get_next_iter_prompts.remote())
+        all_prompts = iter_data['prompt']
+        # TODO: Since this functionality is (somewhat) shared across the OnPolicyCallback and the RolloutAgent,
+        # we should move this to a separate util file.
+        with get_precision_context(self.precision), torch.no_grad(), time_it("batch_inference"):
+            sequences = _vllm_generate(
+                vllm_engines=self.inference_server.engines,
+                max_gen_len=self.max_gen_len,
+                generation_kwargs=self.generation_kwargs,
+                pad_token_id=self.tokenizer_pad_token_id,
+                all_prompts=all_prompts,
+                batch_sizes=[len(all_prompts)],
+            )
+
+        sequences = sequences[0]
+        max_vllm_generated_len = max([len(response) for response in sequences])
+        padded_responses = []
+        for sequence in sequences:
+            sequence = list(sequence)
+            if len(sequence) < max_vllm_generated_len:
+                sequence = sequence + [self.tokenizer_pad_token_id] * (max_vllm_generated_len - len(sequence))
+            padded_responses.append(sequence)
+
+        padded_responses = torch.tensor(
+            padded_responses,
+            dtype=all_prompts.dtype,
+            device=torch.device('cpu'),
+        )
+
+        processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1)
+        iter_data['sequences'] = processed_sequences
+
+        return iter_data
+
+
+class ParameterBuffer(Buffer):
+    """Buffer for updating the inference model."""
+
+    def update_inference_model(self, actor: DistributedGPUActor, inference_server: InferenceServer):
+        start_time = time.time()
+        print('Before broadcast to vLLM')
+        # TODO (infra) instead of directly broadcasting to vllm, we should
+        # push the model parameters to a parameter buffer manager and have
+        # the buffer manager initiate broadcast of parameters to vllm engines
+        broadcast_to_vllm(
+            actor.ppo_callback.actor_critic,
+            inference_server.engines,
+            actor.model_update_group,
+            device=torch.device('cuda'),
+            loss_type=actor.ppo_callback.actor_critic.loss_type,  # type: ignore
+        )
+        print('Finished broadcasting to vLLM')
+        print(f'Took: {time.time() - start_time} to broadcast to vllm.')
+        dist.barrier()
+
+    def 
put(self, struct: dict[str, Any]):
+        # We prefer to implement the model-update logic in the Buffer class: the buffer
+        # is the bridge between the trainer actors and the inference server, so it knows
+        # the best way to transfer the model parameters. The trainer just needs to put
+        # the necessary struct through this API.
+        struct['actor_group'].collective_methods.execute(partial(self.update_inference_model, inference_server=struct['inference_server']))
+
+
+# TODO: Move this experience buffer earlier so that we can avoid
+# using "ExperienceBuffer" (with quotes) as a type hint.
+class ExperienceBuffer(Buffer):
+    """Buffer for storing experiences."""
+
+    def put(self, struct: dict[str, Any]):
+        self.buffer.append(struct)
+
+    def get(self, struct: Optional[dict[str, Any]] = None):
+        return self.buffer[0]
+
+    def popleft(self, struct: Optional[dict[str, Any]] = None):
+        return self.buffer.pop(0)
+
+    def __len__(self):
+        return len(self.buffer)
+
+
+class StreamingDatasetActor(BaseDistributedGPUActor):
+    """Streaming actor for loading prompts onto the experience buffer."""
+
+    def __init__(self):
+        # Setting up the distributed environment (WORLD_SIZE = 1)
+        super().__init__(
+            rank=0,
+            world_size=1,
+            master_addr=None,
+            master_port=None,
+        )
+
+        # Setting up all of the configs
+        # TODO: We should move these to dataclasses
+        # TODO: In a future PR, create all configs in the main function and populate
+        # the correct configs across all entities (e.g. DistributedGPUActor, StreamingDatasetActor, etc)
+        self.pretrain_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
+        self.prompt_handler_config = {
+            "global_train_batch_size": GLOBAL_TRAIN_BATCH_SIZE,
+            "generations_per_prompt": GENERATIONS_PER_PROMPT,
+            "num_batches_per_update": NUM_BATCHES_PER_UPDATE,
+            "max_seq_len": _MAX_SEQ_LEN,
+            "max_gen_len": _MAX_GEN_LEN,
+        }
+        self.tokenizer_config = {
+            'padding': 'longest',
+            'pad_token': '<|finetune_right_pad_id|>',
+            'truncation': True,
+            'padding_side': 'left',
+            'model_max_length': self.prompt_handler_config['max_seq_len'],
+            'trust_remote_code': True,
+        }
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        temp_dataset_dir = f"/tmp/dataset/prompt_{timestamp}/"
+        self.dataloader_config = {
+            'name': 'prompt',
+            'dataset': {
+                'local': temp_dataset_dir,
+                'split': 'train',
+                'remote': 'dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/math_lighteval/llama3_8b_math_prompts/',
+                'shuffle': True,
+                'max_gen_len': self.prompt_handler_config['max_gen_len'],
+                'max_seq_len': self.prompt_handler_config['max_seq_len'],
+                'shuffle_seed': 17,
+                'download_timeout': 1800
+            },
+            'drop_last': True,
+            'num_workers': 1,
+        }
+
+        # Key variables
+        global_train_batch_size = self.prompt_handler_config['global_train_batch_size']
+        self.generations_per_prompt = self.prompt_handler_config['generations_per_prompt']
+        num_batches_per_update = self.prompt_handler_config['num_batches_per_update']
+        total_num_generations = global_train_batch_size * num_batches_per_update
+        self.num_prompts_per_iteration = total_num_generations // self.generations_per_prompt
+
+        # Validate that the total number of generations is divisible by the number of generations per prompt
+        assert total_num_generations % self.generations_per_prompt == 0, "total_num_generations must be divisible by generations_per_prompt"
+
+        # Creating main entities
+        self.tokenizer = self._build_tokenizer()
+        self.dataloader = self._build_dataloader()
+        self.dataloader_iter = iter(self.dataloader)
+
+    def _build_dataloader(self):
+        foundry_dataspec = build_dataloader(
+            cfg = 
self.dataloader_config,
+            tokenizer = self.tokenizer,
+            device_batch_size = self.num_prompts_per_iteration,
+        )
+        return foundry_dataspec.dataloader
+
+    def _build_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.pretrain_model_name, **self.tokenizer_config)
+        return tokenizer
+
+    def get_prompt_handler_config(self):
+        return self.prompt_handler_config
+
+    def get_tokenizer_pad_token_id(self):
+        return self.tokenizer.pad_token_id
+
+    def _get_single_iter_prompts(self):
+        """Gets a single iteration's prompts from the dataloader."""
+        try:
+            return next(self.dataloader_iter)
+        except StopIteration:
+            self.dataloader_iter = iter(self.dataloader)
+            return next(self.dataloader_iter)
+
+    def get_next_iter_prompts(self):
+        """Gets the next iteration's prompts across all ranks and prepares them for the rollout agent."""
+        batches = [self._get_single_iter_prompts()]
+
+        return preprocess_batches(batches, self.generations_per_prompt, self.tokenizer.pad_token_id)
+
+
+class PPOController:
+    """PPO controller for training the policy and value networks."""
+
+    def __init__(
+        self,
+        train_actor: TrainActorGroup,
+        inference_server: InferenceServer,
+        rollout_agent: RolloutAgent,
+        parameter_buffer: ParameterBuffer,
+        experience_buffer: ExperienceBuffer,
+        pretrain_model_name: str,
+    ):
+        self.train_actor = train_actor
+        self.inference_server = inference_server
+        self.rollout_agent = rollout_agent
+        self.parameter_buffer = parameter_buffer
+        self.experience_buffer = experience_buffer
+        self.train_actor.build_models(pretrain_model_name)
+        setup_process_groups(
+            self.train_actor.master_actor,
+            inference_server.engines,
+            inference_server.vllm_tensor_parallel_size,
+        )
+
+    def train(self):
+        for _ in range(NUM_TRAIN_ITERATIONS):  # Example: train for NUM_TRAIN_ITERATIONS iterations
+            # NOTE: this loop represents the logic happening in the current `iteration_start` of the OnPolicyCallback
+            self.parameter_buffer.put({'actor_group': self.train_actor, 'inference_server': self.inference_server})
+            # Simple example of adding elements to the experience buffer
+            self.experience_buffer.put(self.rollout_agent.get_next_iter_rollouts())
+            # Populate the train actor group with the rollouts and then train
+            self.train_actor.add_latest_rollouts_from_buffer(self.experience_buffer)
+            self.train_actor.train_1_iter()
+
+        self.train_actor.collective_methods.close_trainer()
+
+
+def _run_single_controller_ppo(
+    config: Any,
+):
+    """Shared function for running single controller PPO.
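+
+    Only rank 0 drives the controller logic; the remaining torch.distributed
+    ranks just keep the Ray cluster alive (see start_ray_server).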
+ + Args: + config: OmegaConf configuration object containing all parameters + """ + # Set vLLM attention backend to FLASH_ATTN otherwise FlashInfer backend + # takes too long to jit compile + os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' + + with start_ray_server() as _address: + # only rank 0 is the master controller + if dist.get_rank() == 0: + world_size = getattr(config, "world_size", 0) + if world_size == 0: + world_size = dist.get_world_size() + + # Create buffers for the parameter and experience buffers + # first since they don't have external dependencies + parameter_buffer = ParameterBuffer() + experience_buffer = ExperienceBuffer() + + # create SPMD training actors of the system + num_train_actors = world_size // 2 + train_actor = TrainActorGroup(num_train_actors, DistributedGPUActor) + + # Create vLLM engines (or inference actors) + vllm_tensor_parallel_size = world_size - num_train_actors + num_vllm_engines = ( + world_size - num_train_actors + ) // vllm_tensor_parallel_size + # TODO: Encapsulate this into a inference server manager class + pretrain_model_name = config.pretrain_model_name + inference_server = InferenceServer( + num_vllm_engines=num_vllm_engines, + vllm_tensor_parallel_size=vllm_tensor_parallel_size, + pretrain_model_name=pretrain_model_name, + ) + + # We are using a CPU worker for the StreamingActor + # and this involves a super hacky workaround by + # uninstalling megablocks if it exists. Better solutions + # would include: + # 1) decouple StreamingActor from llm-foundry altogether + # 2) don't broadly import llm-foundry in compose-rl (only + # import it into codepaths/files that will only be used by + # GPUActors as opposed to CPUActors) + # 3) Setting up ray actors with correct environments (which + # would involve creating a BaseDistributedActor instead of a + # BaseDistributedGPUActor so that we can use CPUs) + # We uninstall megablocks after the Train Actors have been + # created so that those actors still have megablocks functionality. + uninstall_megablocks_if_exists() + streaming_dataset_actor = ray.remote(num_gpus=0)(StreamingDatasetActor).remote() + rollout_agent = RolloutAgent(inference_server, streaming_dataset_actor) + + ppo_controller = PPOController( + train_actor, + inference_server, + rollout_agent, + parameter_buffer, + experience_buffer, + pretrain_model_name, + ) + ppo_controller.train() + + +if __name__ == '__main__': + # Parse command line arguments + parser = argparse.ArgumentParser(description='Run single controller PPO with configuration file') + parser.add_argument('--file_path', type=str, required=False, default=None, + help='Path to the OmegaConf YAML configuration file') + args = parser.parse_args() + + # Load configuration using OmegaConf + if args.file_path is not None: + config = om.load(args.file_path) + else: + config = om.create({ + 'pretrain_model_name': 'meta-llama/Llama-3.1-8B-Instruct', + }) + + # This is an example of how to move the controller logic from PPO Callback + # to a separate trainer actor above and this main single controller + # function. 
+ _run_single_controller_ppo(config) diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 9d71832c..a815cb38 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -1,7 +1,6 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -from tests.common.actor import BaseDistributedGPUActor from tests.common.datasets import ( FineGrainedPreference, PairwisePreference, @@ -12,7 +11,6 @@ from tests.common.markers import device, world_size __all__ = [ - 'BaseDistributedGPUActor', 'PairwisePreference', 'FineGrainedPreference', 'PromptDataset', diff --git a/tests/common/actor.py b/tests/common/actor.py deleted file mode 100644 index a2eab75f..00000000 --- a/tests/common/actor.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 MosaicML ComposeRL authors -# SPDX-License-Identifier: Apache-2.0 - -import os -from datetime import timedelta -from typing import Optional - -import ray -import torch.distributed as dist - -from compose_rl.algorithms.online.generation_utils import init_process_group -from compose_rl.utils.ray_utils import ( - get_free_port, - get_node_ip, - is_cuda_visible_devices_set, -) - - -class BaseDistributedGPUActor: - - def __init__( - self, - rank: int, - world_size: int, - master_addr: Optional[str] = None, - master_port: Optional[int] = None, - ): - """Initialize the distributed GPU actor for RAY. - - Args: - rank: The rank of this process in the distributed group - world_size: Total number of processes in the distributed group - master_addr: Master node address. If None, will allocate dynamically for rank 0 - master_port: Master node port. If None, will allocate dynamically for rank 0 - """ - self.rank = rank - self.world_size = world_size - self.master_addr = master_addr - self.master_port = master_port - - # Set up basic environment variables - # TODO: may need to handle LOCAL_WORLD_SIZE as used in callback.py - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['RANK'] = str(rank) - - # Set LOCAL_RANK based on Ray GPU allocation - os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set( - ) else str(ray.get_gpu_ids()[0]) - - # If this is rank 0 and no master_addr/master_port provided, allocate them - if rank == 0 and (master_addr is None or master_port is None): - self._allocate_master_address() - - os.environ['MASTER_ADDR'] = self.master_addr # type: ignore - os.environ['MASTER_PORT'] = str(self.master_port) # type: ignore - - self.model = None - self.model_update_group = None - - def _allocate_master_address(self): - """Allocate master address and port for rank 0.""" - if self.master_addr is None: - # Get the local IP address - self.master_addr = get_node_ip() - - if self.master_port is None: - # Allocate a free port - self.master_port = get_free_port() - - def get_master_address(self) -> tuple[Optional[str], Optional[int]]: - """Return the master address and port as a tuple.""" - return (self.master_addr, self.master_port) - - def get_free_port(self): - return get_free_port() - - def init_train_process_group(self): - """Initialize the distributed process group.""" - # Initialize process group - dist.init_process_group(timeout=timedelta(seconds=30)) - - def add_process_group( - self, - backend: str, - master_addr: str, - master_port: int, - world_size: int, - rank: int, - group_name: str, - ): - """Initialize the process group on trainer rank 0 and vllm engines.""" - # NOTE vLLM seems to have a safer implementation of init_process_group: - # 
https://github.com/vllm-project/vllm/blob/v0.9.1/examples/offline_inference/rlhf.py#L105 - # we should look into using that instead - self.model_update_group = init_process_group( - backend=backend, - init_method=f'tcp://{master_addr}:{master_port}', - world_size=world_size, - rank=rank, - group_name=group_name, - ) diff --git a/tests/test_single_controller.py b/tests/test_single_controller.py index 1df09c74..156b4f29 100644 --- a/tests/test_single_controller.py +++ b/tests/test_single_controller.py @@ -18,8 +18,10 @@ from compose_rl.algorithms.online.generation_utils import ( create_vllm_engines, ) +from compose_rl.controllers import BaseDistributedGPUActor from compose_rl.utils.ray_utils import start_ray_server -from tests.common import BaseDistributedGPUActor, world_size +from tests.common import world_size + # Set up logging logger = logging.getLogger(__name__) diff --git a/tests/test_single_controller_ppo.py b/tests/test_single_controller_ppo.py deleted file mode 100644 index 401683cf..00000000 --- a/tests/test_single_controller_ppo.py +++ /dev/null @@ -1,613 +0,0 @@ -# Copyright 2024 MosaicML ComposeRL authors -# SPDX-License-Identifier: Apache-2.0 - -# run cmd: `cd compose-rl && cp tests/test_single_controller_ppo.py . -# && composer test_single_controller_ppo.py` - -import os -import pathlib -import time -from functools import partial -from typing import Any, Optional - -import pytest -import ray -import torch -import torch.distributed as dist -from composer import Trainer -from composer.core import get_precision_context -from composer.optim import DecoupledAdamW -from composer.utils import dist as composer_dist -from llmfoundry.models import ComposerHFCausalLM -from torch.utils.data import DataLoader -from transformers import ( - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizerBase, -) - -from compose_rl.algorithms.online import ( - ComposerHFPolicyLM, - SingleControllerOnPolicyCallback, -) -from compose_rl.algorithms.online.generation_utils import ( - broadcast_to_vllm, - create_vllm_engines, - vllm_generate, -) -from compose_rl.data import prompt_dataset_collate_fn -from compose_rl.utils.ray_utils import start_ray_server -from tests.common import ( - BaseDistributedGPUActor, - VerifiablePromptDataset, - world_size, -) - - -@ray.remote(num_gpus=1) -class DistributedGPUActor(BaseDistributedGPUActor): - """Distributed GPU actor for testing.""" - - def __init__( - self, - rank: int, - world_size: int, - master_addr: Optional[str] = None, - master_port: Optional[int] = None, - ): - super().__init__(rank, world_size, master_addr, master_port) - self.model = None - self.model_update_group = None - self.ref_path = None - self._dataloader = None - self._tokenizer = None - self.ppo_callback = None - self.ppo_trainer: Trainer = None # type: ignore - - self.pretrain_model_name = None - self.device_train_batch_size = None - self.num_batches_per_update = None - self.max_seq_len = None - self.precision: str = None # type: ignore - self.train_config: dict = None # type: ignore - - def build_train_config(self, pretrain_model_name: str): - self.pretrain_model_name = pretrain_model_name - self.device_train_batch_size = 4 - self.num_batches_per_update = 2 - self.max_seq_len = 32 - self.precision = 'amp_bf16' - - ref_model_config = {**self.model_config, 'name': 'hf_causal_lm'} - - variables = { - 'buffer': { - 'name': 'MinibatchRolloutBuffer', - 'max_buffer_size': self.num_batches_per_update, - }, - 'max_gen_len': 8, - 'gamma': 0.99, - 'lambda_gae': 0.95, - 'generation_kwargs': { - 
'use_cache': True, - 'do_sample': False, - 'temperature': 1.0, - }, - 'kl_controller': { - 'init_kl_coef': 0.2, - 'target': 0.01, - 'horizon': 12800, - 'kl_ctl_type': 'adaptive', - }, - 'reference_model': { - 'model_config': ref_model_config, - 'precision': self.precision, - 'load_path': self.ref_path, - 'non_train_fsdp_config': self.fsdp_config, - }, - 'epoch_per_iteration': 1, - 'num_batches_per_update': self.num_batches_per_update, - 'rewards': { - 'output_length': { - 'reward_type': 'output_length', - 'max_gen_len': 10, - }, - }, - } - self.train_config = { - 'model': { - **self.model_config, - 'kl_estimator': 'k1', - 'kl_clip_range': 40.0, - }, - 'fsdp_config': - self.fsdp_config, - 'seed': - 17, - 'precision': - self.precision, - 'variables': - variables, - 'max_seq_len': - self.max_seq_len, - 'global_train_batch_size': - self.device_train_batch_size * self.world_size, - 'device_train_batch_size': - self.device_train_batch_size, - 'device_train_microbatch_size': - self.device_train_batch_size, - } - - def build_dataloader(self): - # TODO (infra): build prompt dataloader with rollout agent instead of - # trainer actor - max_seq_len = 32 - prompt_len = 10 - - dataset = VerifiablePromptDataset(prompt_len=prompt_len) - dataloader = DataLoader( - dataset, - collate_fn=partial( - prompt_dataset_collate_fn, - self.tokenizer, - max_seq_len, - ), - sampler=composer_dist.get_sampler(dataset), - batch_size=self.device_train_batch_size, - ) - # We need to mock this method, since our dataset isn't a - # StreamingDataset - dataloader.state_dict = lambda: {} - dataloader.load_state_dict = lambda x: None - return dataloader - - @property - def dataloader(self): - if self._dataloader is None: - self._dataloader = self.build_dataloader() - return self._dataloader - - def build_tokenizer(self): - # TODO (algo): decide if we should use tokens or messages given - # we may need token level log prob - # TODO (infra): use the tokenizer/texts for prompt dataloader but - # token (ids) for the experience buffer/manager - tokenizer = AutoTokenizer.from_pretrained(self.pretrain_model_name) - tokenizer.add_special_tokens({'pad_token': '[PAD]'}) - return tokenizer - - @property - def tokenizer(self): - if self._tokenizer is None: - self._tokenizer = self.build_tokenizer() - return self._tokenizer - - @property - def model_config(self): - return { - 'tokenizer': self.tokenizer, - 'pretrained_model_name_or_path': self.pretrain_model_name, - 'pretrained': True, - 'use_flash_attention_2': True, - 'allow_embedding_resizing': True, - } - - @property - def fsdp_config(self): - # TODO (infra): use actual fsdp1 config - return {} - - def init_composer_dist(self): - composer_dist.initialize_dist('gpu') - - def build_ref_model(self): - # pre-train a reference model for the PPO training - # The key observation here is that we should construct model - # training pipeline in the actor instead of the callback - # e.g., we can build ref/reward/policy/value model and create/colocate - # multiple trainers all in this class - tmp_ref_path = str('./ref_checkpoints') - ref_path = os.path.join(tmp_ref_path, 'latest-rank0.pt') - if os.path.exists(ref_path): - self.ref_path = ref_path - return - - tmp_model = ComposerHFCausalLM( - **self.model_config, - use_auth_token=True, - ) - - tmp_optimizer = DecoupledAdamW(tmp_model.parameters(), lr=1e-6) - - temp_dataloader = [{ - 'input_ids': torch.ones((2, 15)).to(dtype=torch.int64), - 'attention_mask': torch.ones((2, 15)), - 'labels': torch.ones((2, 15)).to(dtype=torch.int64), - }] - - 
temp_trainer = Trainer( - model=tmp_model, - train_dataloader=temp_dataloader, - optimizers=tmp_optimizer, - max_duration='1ba', - parallelism_config={'fsdp': self.fsdp_config}, - save_folder=tmp_ref_path, - save_weights_only=True, - device_train_microbatch_size=self. - device_train_microbatch_size, # type: ignore - ) - - temp_trainer.fit() - self.ref_path = ref_path - - def build_ppo_trainer(self): - composer_dist.initialize_dist('gpu') - - model = ComposerHFPolicyLM(**self.model_config, use_auth_token=True) - - optimizer = DecoupledAdamW(model.parameters(), lr=1e-8) - - # TODO (infra): pull the rest of the training logic from the callback - # to this class, e.g, how to interact with env, calculate rewards etc - self.ppo_callback = SingleControllerOnPolicyCallback( - train_config=self.train_config, - ) - self.ppo_trainer = Trainer( - model=model, - optimizers=optimizer, - callbacks=self.ppo_callback, - train_dataloader=self.dataloader, - precision=self.precision, - parallelism_config={'fsdp': self.fsdp_config}, - max_duration='3iter', - device_train_microbatch_size=1, - load_path=self.ref_path, - ) - - def train_1_iter(self): - # TODO (algo): implement the top level PPO algo here instead of the - # callback. Algorithmic researchers are expected to implement this - # function along with above policy/value/reward/ref trainers or models - # TODO (infra): try multiple fit to see if the (mlflow) logger, etc - # TODO (infra): fault tolerance at iteration level first - # TODO (infra): enable batch level control - self.ppo_trainer.fit(duration='1iter') - # This is the KL assert that must be true if we are truly loading - # from the same model. This is only true on the first iteration - assert torch.allclose( - self.ppo_trainer.state.loss['kl/ift_kl'], # pyright: ignore - torch.tensor(0.0), - atol=5e-5, - ) - - def update_inference_model(self, vllm_engines: list[Any]): - start_time = time.time() - print('Before broadcast to vLLM') - # TODO (infra) instead of direcly broadcasting to vllm, we should - # push the model parameters to a parameter buffer manager and have - # the buffer manager initiate broadcast of parameters to vllm engines - broadcast_to_vllm( - self.ppo_callback.actor_critic, - vllm_engines, - self.model_update_group, - device=torch.device('cuda'), - loss_type=self.ppo_callback.actor_critic.loss_type, # type: ignore - ) - print('Finished broadcasting to vLLM') - print(f'Took: {time.time() - start_time} to broadcast to vllm.') - dist.barrier() - - def query_inference_engines(self, vllm_engines: list[Any]): - """Round trip to inference engines. - - Args: - vllm_engines (list[Any]): The vllm engines to round trip to. - """ - # TODO (infra): we should use the rollout agent to generate sequences - # instead of the trainer actor, e.g,. 
-        batch = self.ppo_trainer.state.device.batch_to_device(
-            self.ppo_callback._get_next_iter_prompts(),
-        )
-        max_gen_len = self.train_config['variables']['max_gen_len']
-        generation_kwargs = self.train_config['variables']['generation_kwargs']
-        with get_precision_context(self.precision), torch.no_grad():
-            # TODO (infra): refactor this code to isolate the gather of
-            # prompts on the trainer actor and the gather/scatter of sequences
-            # on the trainer actor; the first half is useless while
-            # the second half should be managed through an experience manager
-            sequences = vllm_generate(
-                vllm_engines=vllm_engines,
-                batch=batch,
-                max_gen_len=max_gen_len,
-                generation_kwargs=generation_kwargs,
-                tokenizer=self.tokenizer,  # type: ignore
-                vllm_generate_function='generate',
-            )
-            # Add the prepared sequences to the batch again
-            batch['sequences'] = sequences
-            self.ppo_callback.batch_rollouts = batch  # type: ignore
-
-
-def setup_process_groups(
-    master_actor: Any,
-    vllm_engines: list[Any],
-    vllm_tensor_parallel_size: int,
-):
-    """Initialize process groups for vLLM engines and the master actor."""
-    # Get a new port for the weight-update process group
-    master_addr, _ = ray.get(
-        master_actor.get_master_address.remote(),
-    )  # type: ignore
-    new_port = ray.get(master_actor.get_free_port.remote())  # type: ignore
-    print(f'new_port: {new_port}')
-
-    world_size = dist.get_world_size()
-
-    # Initialize process groups for the vLLM engines. The weight-update group
-    # has world_size // 2 + 1 ranks: rank 0 is the master training actor, and
-    # engine i's workers start at rank i * vllm_tensor_parallel_size + 1.
-    refs = [
-        engine.init_process_group.remote(
-            master_addr,
-            new_port,
-            i * vllm_tensor_parallel_size + 1,
-            world_size // 2 + 1,
-            'weight-update',
-            backend='nccl',
-        ) for i, engine in enumerate(vllm_engines)
-    ]
-
-    # Add the master actor to the process group
-    refs.append(
-        master_actor.add_process_group.remote(
-            backend='nccl',
-            master_addr=master_addr,
-            master_port=new_port,
-            world_size=world_size // 2 + 1,
-            rank=0,
-            group_name='weight-update',
-        ),
-    )
-
-    # Wait for all process groups to be initialized
-    print(ray.get(refs))
-
-
-class SPMDActorGroup:
-    # TODO (infra): refactor this into a proper base class
-
-    def __init__(self, num_train_actors: int):
-        """Create and initialize all training actors."""
-        self.num_train_actors = num_train_actors
-        self._train_actors = []
-
-        print('\n=== STARTING DISTRIBUTED TRAINING WITH RAY ACTORS ===')
-
-        # Create the master actor first
-        self._master_actor = DistributedGPUActor.remote(
-            0,
-            self.num_train_actors,
-        )
-        self._train_actors.append(self._master_actor)
-
-        # Get the master address from the rank 0 actor
-        master_addr, master_port = ray.get(
-            self._master_actor.get_master_address.remote(),  # type: ignore
-        )
-        print(f'Master address allocated: {master_addr}:{master_port}')
-
-        # Create the remaining actors with the master address/port
-        for i in range(1, self.num_train_actors):
-            actor = DistributedGPUActor.remote(
-                i,
-                self.num_train_actors,
-                master_addr,  # type: ignore
-                master_port,
-            )
-            self._train_actors.append(actor)
-
-    @property
-    def train_actors(self):
-        return self._train_actors
-
-    @property
-    def master_actor(self):
-        return self._master_actor
-
-
-class TrainActorGroup(SPMDActorGroup):
-    # TODO: this class is mainly a pass-through gang scheduler;
-    # we should refactor it to be more generic and reusable
-
-    def build_models(self, pretrain_model_name: str):
-        """Build reference models and PPO trainers for all actors."""
-        build_train_config_tasks = [
-            actor.build_train_config.remote(pretrain_model_name)
-            for actor in self._train_actors
-        ]
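-        # Fan-out/fan-in: the .remote() calls start concurrently on every
-        # actor; ray.get blocks until the whole group finishes each phase.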
-        ray.get(build_train_config_tasks)
-
-        init_tasks = [
-            actor.init_composer_dist.remote() for actor in self._train_actors
-        ]
-        ray.get(init_tasks)
-
-        # Build reference models
-        build_ref_model_tasks = [
-            actor.build_ref_model.remote() for actor in self._train_actors
-        ]
-        ray.get(build_ref_model_tasks)
-        print('build ref model done')
-
-        # Build PPO trainers
-        build_ppo_trainer_tasks = [
-            actor.build_ppo_trainer.remote() for actor in self._train_actors
-        ]
-        ray.get(build_ppo_trainer_tasks)
-        print('build ppo trainer done')
-
-    def update_inference_model(self, vllm_engines: list[Any]):
-        refs = [
-            actor.update_inference_model.remote(vllm_engines)
-            for actor in self._train_actors
-        ]
-        ray.get(refs)
-        print('update inference model done')
-
-    def query_inference_engines(self, vllm_engines: list[Any]):
-        refs = [
-            actor.query_inference_engines.remote(vllm_engines)
-            for actor in self._train_actors
-        ]
-        ray.get(refs)
-        print('query inference engines done')
-
-    def train_iteration(self):
-        """Run one training iteration on all actors."""
-        refs = [actor.train_1_iter.remote() for actor in self._train_actors]
-        ray.get(refs)
-        print('train 1 iter done')
-
-
-class RolloutAgent:
-
-    def __init__(self, vllm_engines: list, vllm_tensor_parallel_size: int):
-        self.vllm_engines = vllm_engines
-        self.vllm_tensor_parallel_size = vllm_tensor_parallel_size
-
-    @property
-    def num_vllm_engines(self):
-        return len(self.vllm_engines)
-
-    def generate(self, prompts: list[str]):
-        # TODO (infra): try integrating this with the multi-turn rollout repo
-        ref = self.vllm_engines[0].generate.remote(prompts)
-        gen_results = ray.get(ref)
-        for output in gen_results:
-            prompt = output.prompt
-            generated_text = output.outputs[0].text
-            print(f'Prompt: {prompt!r}, Generated text: {generated_text!r}')
-
-
-# TODO (infra): implement a parameter buffer manager and an experience manager
-class PPOController:
-
-    def __init__(
-        self,
-        train_actor: TrainActorGroup,
-        inference_client: RolloutAgent,
-        pretrain_model_name: str,
-    ):
-        self.train_actor = train_actor
-        self.inference_client = inference_client
-
-        self.train_actor.build_models(pretrain_model_name)
-        setup_process_groups(
-            self.train_actor.master_actor,
-            self.inference_client.vllm_engines,
-            self.inference_client.vllm_tensor_parallel_size,
-        )
-
-    def train(self):
-        self.train_actor.update_inference_model(
-            self.inference_client.vllm_engines,
-        )
-        self.train_actor.query_inference_engines(
-            self.inference_client.vllm_engines,
-        )
-        self.train_actor.train_iteration()
-
-
-def _run_single_controller_ppo(
-    pretrain_model_path: str,
-    world_size: int = 0,
-):
-    """Shared function for running single-controller PPO.
-
-    Args:
-        pretrain_model_path: Path to the pretrained model
-        world_size: Number of distributed processes
-    """
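-    # GPU split: half of world_size hosts the training actors; the remaining
-    # GPUs back a single vLLM engine via tensor parallelism (see the
-    # num_train_actors / vllm_tensor_parallel_size math below).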
-    # Set the vLLM attention backend to FLASH_ATTN; otherwise the FlashInfer
-    # backend takes too long to JIT compile
-    os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN'
-
-    prompts = [
-        'what is RAY?',
-        'what is vLLM?',
-    ]
-
-    with start_ray_server() as _address:
-        if dist.get_rank() == 0:
-            # Only rank 0 is the master controller
-
-            # Create the SPMD training actors of the system
-            if world_size == 0:
-                world_size = dist.get_world_size()
-            num_train_actors = world_size // 2
-            train_actor = TrainActorGroup(num_train_actors)
-
-            # Create vLLM engines (or inference actors)
-            vllm_tensor_parallel_size = world_size - num_train_actors
-            num_vllm_engines = (
-                world_size - num_train_actors
-            ) // vllm_tensor_parallel_size
-            # TODO: encapsulate this in an inference server manager class
-            vllm_engines = create_vllm_engines(
-                num_engines=num_vllm_engines,
-                tensor_parallel_size=vllm_tensor_parallel_size,
-                enforce_eager=True,
-                pretrain=pretrain_model_path,
-                revision=None,
-                seed=1,
-                enable_prefix_caching=False,
-                max_model_len=512,
-                device_bundle={
-                    'GPU': 1,
-                    'CPU': 1,
-                    'worker_node': 0,
-                },
-            )
-            inference_client = RolloutAgent(
-                vllm_engines,
-                vllm_tensor_parallel_size,
-            )
-
-            ppo_controller = PPOController(
-                train_actor,
-                inference_client,
-                pretrain_model_path,
-            )
-            ppo_controller.train()
-
-            inference_client.generate(prompts)
-
-
-@pytest.mark.gpu
-@world_size(4)  # TODO: change this to 2 for CI testing (hit a fatal Python error)
-def test_single_controller_ppo(
-    world_size: int,
-    tiny_llama_model: PreTrainedModel,
-    tiny_gpt2_tokenizer: PreTrainedTokenizerBase,
-    tmp_path: pathlib.Path,
-):
-    """Test single-controller PPO with Ray actors and vLLM engines."""
-    # Save the model and tokenizer to a temporary directory
-    local_save_path = str(tmp_path / 'llama_model')
-    tiny_llama_model.save_pretrained(local_save_path)
-    tiny_gpt2_tokenizer.save_pretrained(local_save_path)
-
-    _run_single_controller_ppo(
-        pretrain_model_path=local_save_path,
-        world_size=world_size,
-    )
-
-
-if __name__ == '__main__':
-    # This is an example of how to move the controller logic from the PPO
-    # callback into the separate trainer actor above and this main
-    # single-controller function.
-    _run_single_controller_ppo(
-        pretrain_model_path='meta-llama/Llama-3.2-1B-Instruct',
-    )
diff --git a/yamls/distributed_ppo_test.yaml b/yamls/distributed_ppo_test.yaml
new file mode 100644
index 00000000..c49706da
--- /dev/null
+++ b/yamls/distributed_ppo_test.yaml
@@ -0,0 +1,24 @@
+name: compose-rl-distributed-ppo-test
+
+scheduling:
+  priority: low
+  preemptible: true
+
+compute:
+  gpus: 16
+  cluster: r5z2p3
+
+integrations:
+- integration_type: git_repo
+  path: /workspace/compose-rl
+  git_repo: databricks/compose-rl
+  git_branch: single-controller-hackathon
+
+image: mosaicml/dle:nightly-latest
+
+command: |-
+  cd /workspace/compose-rl
+  composer test_single_controller_ppo.py --file_path /mnt/config/parameters.yaml
+
+parameters:
+  pretrain_model_name: Qwen/Qwen2.5-3B-Instruct