diff --git a/compose_rl/algorithms/online/__init__.py b/compose_rl/algorithms/online/__init__.py index 84efe4ce..3857e92f 100644 --- a/compose_rl/algorithms/online/__init__.py +++ b/compose_rl/algorithms/online/__init__.py @@ -19,6 +19,8 @@ HFPolicyConfig, MPTPolicyConfig, ) +from compose_rl.algorithms.online.single_controller_callback import \ + SingleControllerOnPolicyCallback from compose_rl.registry import kl_controllers kl_controllers.register('adaptive', func=AdaptiveKLController) @@ -28,6 +30,7 @@ __all__ = [ 'OnPolicyCallback', + 'SingleControllerOnPolicyCallback', 'ComposerMPTPolicyLM', 'ComposerHFPolicyLM', 'ComposerHFCriticFreePolicyLM', diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index d32b5795..41d96d59 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -764,7 +764,6 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): # When we hit this function, we should already have all the prompts we need per iteration. num_gen_calls = bs // self.device_generate_batch_size - gen_batch_partial_outputs = [] all_sequences = [] for i in range(num_gen_calls): gen_batch = self._extract_minibatch( @@ -796,6 +795,15 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): # Add the prepared sequences to the batch again batch['sequences'] = sequences + # Compute rewards and populate buffer + self._get_reward(batch) + + def _get_reward(self, batch: dict[str, torch.Tensor]): + """Compute rewards for a batch of generated sequences. + + Args: + batch (dict): The batch containing generated sequences to compute rewards for. 
+ """ env_outputs, prompts_and_gens, ref_outputs, all_rewards_dict = env_reward( actor_critic=self.actor_critic, # pyright: ignore reward_manager=self.reward_manager, @@ -825,7 +833,9 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): del resolved_outputs[key] # We need to split the resolved outputs into minibatches - for idx in range(bs // self.device_train_batch_size): + for idx in range( + batch['prompt_id'].shape[0] // self.device_train_batch_size, + ): minibatch = self._extract_minibatch( resolved_outputs, idx, @@ -834,7 +844,9 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): self.buffer.add(minibatch) # Making sure we correctly parsed the minibatches - assert len(self.buffer) == self.num_batches_per_update + assert len( + self.buffer, + ) == self.num_batches_per_update, f'{len(self.buffer)} != {self.num_batches_per_update}' self.actor_critic.train() @@ -1149,7 +1161,7 @@ def _update_inference_model(self, batch: dict[str, torch.Tensor]): model=self.actor_critic, vllm_engines=self.vllm_engines, model_update_group=self.model_update_group, - batch=batch, + device=batch['prompt'].device, loss_type=self.actor_critic.loss_type, # type: ignore enable_prefix_caching=self.vllm_enable_prefix_caching, ) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 62c86d41..4da62390 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -381,7 +381,7 @@ def broadcast_to_vllm( model: nn.Module, vllm_engines: list, model_update_group: Optional[torch.distributed.ProcessGroup], - batch: dict[str, torch.Tensor], + device: torch.device, loss_type: OnPolicyEnum = OnPolicyEnum.PPO, enable_prefix_caching: bool = False, ): @@ -391,7 +391,7 @@ def broadcast_to_vllm( model (nn.Module): The model to broadcast vllm_engines (list): List of vllm engines model_update_group 
(torch.distributed.ProcessGroup): The process group for model updates - batch (dict[str, torch.Tensor]): The batch to use for the forward pass + device (torch.device): The device to use for the forward pass loss_type (str): The loss type which decides whether to use critic-free or not. Defaults to `ppo`. enable_prefix_caching (bool): Whether to enable prefix caching. Defaults to `False`. """ @@ -419,9 +419,6 @@ def broadcast_to_vllm( engine.reset_prefix_cache.remote() for engine in vllm_engines ] - # This is needed to get the correct model device - cur_device = batch['prompt'].device - # These apply to llama modules, it might change for other modules valid_non_leaf_module_names = [ 'model.embed_tokens.weight', @@ -438,17 +435,17 @@ def broadcast_to_vllm( # We need this otherwise FSDP throws an error during a standard forward pass. dummy_batch = { 'obs': - torch.tensor([[0]], dtype=torch.long, device=cur_device), + torch.tensor([[0]], dtype=torch.long, device=device), 'right_padded_attn_mask': - torch.tensor([[1]], dtype=torch.bool, device=cur_device), + torch.tensor([[1]], dtype=torch.bool, device=device), 'actions': - torch.tensor([[0]], dtype=torch.long, device=cur_device), + torch.tensor([[0]], dtype=torch.long, device=device), 'prompt_len': - torch.tensor([1], device=cur_device), + torch.tensor([1], device=device), 'max_gen_len': - torch.tensor([1], device=cur_device), + torch.tensor([1], device=device), 'action_mask': - torch.tensor([[0]], dtype=torch.long, device=cur_device), + torch.tensor([[0]], dtype=torch.long, device=device), } model(dummy_batch) start_time = time.time() diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index efe5eeda..842fc36d 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -100,6 +100,11 @@ def eval_forward(self, batch: MutableMapping, outputs: MutableMapping): ) def loss(self, outputs: MutableMapping, batch: MutableMapping): + # Get beta from 
config if available, otherwise use default + additional_kwargs = {} + if hasattr(self.config, 'beta'): + additional_kwargs['beta'] = self.config.beta + return_dict = online_rl_loss( outputs=outputs, batch=batch, @@ -107,10 +112,10 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): value_clip_range=self.config.value_clip_range, value_loss_weight=self.config.value_loss_weight, policy_clip_ratio=self.config.policy_clip_ratio, - beta=self.config.beta, add_direct_kl_loss=self.config.compute_kl_loss, kl_estimator=self.config.kl_estimator, kl_clip_range=self.config.kl_clip_range, + **additional_kwargs, ) self.policy_kl.append(return_dict['kl/policy_kl']) @@ -217,6 +222,11 @@ def eval_forward(self, batch: MutableMapping, outputs: MutableMapping): ) def loss(self, outputs: MutableMapping, batch: MutableMapping): + # Get beta from config if available, otherwise use default + additional_kwargs = {} + if hasattr(self.config, 'beta'): + additional_kwargs['beta'] = self.config.beta + return_dict = online_rl_loss( outputs=outputs, batch=batch, @@ -224,10 +234,10 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): value_clip_range=self.config.value_clip_range, value_loss_weight=self.config.value_loss_weight, policy_clip_ratio=self.config.policy_clip_ratio, - beta = self.config.beta, add_direct_kl_loss=self.config.compute_kl_loss, kl_estimator=self.config.kl_estimator, kl_clip_range=self.config.kl_clip_range, + **additional_kwargs, ) self.policy_kl.append(return_dict['kl/policy_kl']) diff --git a/compose_rl/algorithms/online/single_controller_callback.py b/compose_rl/algorithms/online/single_controller_callback.py new file mode 100644 index 00000000..b5b4cd6b --- /dev/null +++ b/compose_rl/algorithms/online/single_controller_callback.py @@ -0,0 +1,58 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""Online On-Policy RL callback.""" + +from __future__ import annotations + +import logging +from typing import 
Union + +from composer.core import State +from composer.loggers import Logger +from composer.trainer.trainer import _get_initial_device_train_microbatch_size +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +# Import the base class +from compose_rl.algorithms.online.callback import OnPolicyCallback +from compose_rl.algorithms.online.model import ( + ComposerHFPolicyLM, + ComposerMPTPolicyLM, +) + +Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM] + +__all__ = ['SingleControllerOnPolicyCallback'] + +log = logging.getLogger(__name__) + + +class SingleControllerOnPolicyCallback(OnPolicyCallback): + """Callback for managing on-policy training in an RLHF loop. + + Ideally all the overwritten methods below should be implemented in the + trainer actor instead of the callback, we kept them here for now to minimize + a drastic refactor to PPO Callback code + """ + + def iteration_start(self, state: State, logger: Logger): + del logger # unused + + self._get_reward(self.batch_rollouts) # type: ignore + + # Reset and initialize state train dataloader + log.warning( + 'trainer._train_data_spec should be updated whenever the dataloader is updated', + ) + # Train Dataloader + state.set_dataloader(self.buffer, 'ep') + state.train_dataloader = state.dataloader + state.device_train_microbatch_size = _get_initial_device_train_microbatch_size( + state.device_train_microbatch_size, + state.auto_microbatching, + state.train_dataloader, + ) + + # Update IFT KL + self._update_ift_kl() diff --git a/tests/common/__init__.py b/tests/common/__init__.py index a815cb38..9d71832c 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -1,6 +1,7 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 +from tests.common.actor import BaseDistributedGPUActor from tests.common.datasets import ( FineGrainedPreference, PairwisePreference, @@ -11,6 +12,7 @@ from 
tests.common.markers import device, world_size __all__ = [ + 'BaseDistributedGPUActor', 'PairwisePreference', 'FineGrainedPreference', 'PromptDataset', diff --git a/tests/common/actor.py b/tests/common/actor.py new file mode 100644 index 00000000..a2eab75f --- /dev/null +++ b/tests/common/actor.py @@ -0,0 +1,101 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from datetime import timedelta +from typing import Optional + +import ray +import torch.distributed as dist + +from compose_rl.algorithms.online.generation_utils import init_process_group +from compose_rl.utils.ray_utils import ( + get_free_port, + get_node_ip, + is_cuda_visible_devices_set, +) + + +class BaseDistributedGPUActor: + + def __init__( + self, + rank: int, + world_size: int, + master_addr: Optional[str] = None, + master_port: Optional[int] = None, + ): + """Initialize the distributed GPU actor for RAY. + + Args: + rank: The rank of this process in the distributed group + world_size: Total number of processes in the distributed group + master_addr: Master node address. If None, will allocate dynamically for rank 0 + master_port: Master node port. 
If None, will allocate dynamically for rank 0 + """ + self.rank = rank + self.world_size = world_size + self.master_addr = master_addr + self.master_port = master_port + + # Set up basic environment variables + # TODO: may need to handle LOCAL_WORLD_SIZE as used in callback.py + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + + # Set LOCAL_RANK based on Ray GPU allocation + os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set( + ) else str(ray.get_gpu_ids()[0]) + + # If this is rank 0 and no master_addr/master_port provided, allocate them + if rank == 0 and (master_addr is None or master_port is None): + self._allocate_master_address() + + os.environ['MASTER_ADDR'] = self.master_addr # type: ignore + os.environ['MASTER_PORT'] = str(self.master_port) # type: ignore + + self.model = None + self.model_update_group = None + + def _allocate_master_address(self): + """Allocate master address and port for rank 0.""" + if self.master_addr is None: + # Get the local IP address + self.master_addr = get_node_ip() + + if self.master_port is None: + # Allocate a free port + self.master_port = get_free_port() + + def get_master_address(self) -> tuple[Optional[str], Optional[int]]: + """Return the master address and port as a tuple.""" + return (self.master_addr, self.master_port) + + def get_free_port(self): + return get_free_port() + + def init_train_process_group(self): + """Initialize the distributed process group.""" + # Initialize process group + dist.init_process_group(timeout=timedelta(seconds=30)) + + def add_process_group( + self, + backend: str, + master_addr: str, + master_port: int, + world_size: int, + rank: int, + group_name: str, + ): + """Initialize the process group on trainer rank 0 and vllm engines.""" + # NOTE vLLM seems to have a safer implementation of init_process_group: + # https://github.com/vllm-project/vllm/blob/v0.9.1/examples/offline_inference/rlhf.py#L105 + # we should look into using that instead + 
self.model_update_group = init_process_group( + backend=backend, + init_method=f'tcp://{master_addr}:{master_port}', + world_size=world_size, + rank=rank, + group_name=group_name, + ) diff --git a/tests/common/datasets.py b/tests/common/datasets.py index 09a49804..795d264c 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -71,6 +71,7 @@ def __getitem__(self, index: int): return { 'prompt': torch.ones((self.prompt_len,)).int(), 'prompt_len': torch.Tensor([self.prompt_len]).to(torch.int64), + 'prompt_id': torch.Tensor([index]).int(), } @@ -87,6 +88,7 @@ def __getitem__(self, index: int): return { 'prompt': torch.ones((self.prompt_len,)).int(), 'prompt_len': torch.Tensor([self.prompt_len]).to(torch.int64), + 'prompt_id': torch.Tensor([index]).int(), 'verified_answer': '1', } diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 6271d071..409203cf 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -58,6 +58,33 @@ def tiny_gpt2_config_helper(): return config_object +def tiny_llama_config_helper(): + pytest.importorskip('transformers') + from transformers.models.llama.configuration_llama import LlamaConfig + config_dict = { + 'architectures': ['LlamaForCausalLM'], + 'bos_token_id': 1, + 'eos_token_id': 2, + 'hidden_act': 'silu', + 'hidden_size': 128, + 'intermediate_size': 256, + 'max_position_embeddings': 2048, + 'model_type': 'llama', + 'num_attention_heads': 4, + 'num_hidden_layers': 2, + 'num_key_value_heads': 4, + 'rms_norm_eps': 1e-06, + 'rope_theta': 10000.0, + 'use_cache': True, + 'vocab_size': 50258, # Match GPT-2 tokenizer vocabulary size + } + + config_object = LlamaConfig( + **config_dict, + ) + return config_object + + def assets_path(): rank = os.environ.get('RANK', '0') folder_name = 'tokenizers' + (f'_{rank}' if rank != '0' else '') @@ -144,12 +171,22 @@ def _session_tiny_gpt2_model(_session_tiny_gpt2_config): # type: ignore return causal_lm_model_helper(_session_tiny_gpt2_config) 
+@pytest.fixture(scope='session') +def _session_tiny_llama_model(_session_tiny_llama_config): # type: ignore + return causal_lm_model_helper(_session_tiny_llama_config) + + ## SESSION CONFIGS ## @pytest.fixture(scope='session') def _session_tiny_gpt2_config(): # type: ignore return tiny_gpt2_config_helper() +@pytest.fixture(scope='session') +def _session_tiny_llama_config(): # type: ignore + return tiny_llama_config_helper() + + ## SESSION TOKENIZERS ## @pytest.fixture(scope='session') def _session_tiny_gpt2_tokenizer(tokenizers_assets): # type: ignore @@ -164,6 +201,11 @@ def tiny_gpt2_model(_session_tiny_gpt2_model): # type: ignore return copy.deepcopy(_session_tiny_gpt2_model) +@pytest.fixture +def tiny_llama_model(_session_tiny_llama_model): # type: ignore + return copy.deepcopy(_session_tiny_llama_model) + + ## TOKENIZER FIXTURES ## @pytest.fixture def tiny_gpt2_tokenizer(_session_tiny_gpt2_tokenizer): # type: ignore diff --git a/tests/test_single_controller.py b/tests/test_single_controller.py index efb2d41e..1df09c74 100644 --- a/tests/test_single_controller.py +++ b/tests/test_single_controller.py @@ -4,8 +4,6 @@ import logging import os import pathlib -from datetime import timedelta -from typing import Optional import pytest import ray @@ -19,93 +17,17 @@ from compose_rl.algorithms.online.generation_utils import ( create_vllm_engines, - init_process_group, ) -from compose_rl.utils.ray_utils import ( - get_free_port, - get_node_ip, - is_cuda_visible_devices_set, - start_ray_server, -) -from tests.common import world_size +from compose_rl.utils.ray_utils import start_ray_server +from tests.common import BaseDistributedGPUActor, world_size # Set up logging logger = logging.getLogger(__name__) @ray.remote(num_gpus=1) -class DistributedGPUActor: - - def __init__( - self, - rank: int, - world_size: int, - master_addr: Optional[str] = None, - master_port: Optional[int] = None, - ): - """Initialize the distributed GPU actor. 
- - Args: - rank: The rank of this process in the distributed group - world_size: Total number of processes in the distributed group - master_addr: Master node address. If None, will allocate dynamically for rank 0 - master_port: Master node port. If None, will allocate dynamically for rank 0 - """ - self.rank = rank - self.world_size = world_size - self.master_addr = master_addr - self.master_port = master_port - - # Set up basic environment variables - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['RANK'] = str(rank) - - # Set LOCAL_RANK based on Ray GPU allocation - os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set( - ) else str(ray.get_gpu_ids()[0]) - - # If this is rank 0 and no master_addr/master_port provided, allocate them - if rank == 0 and (master_addr is None or master_port is None): - self._allocate_master_address() - - os.environ['MASTER_ADDR'] = self.master_addr # type: ignore - os.environ['MASTER_PORT'] = str(self.master_port) # type: ignore - - self.model = None - self.model_update_group = None - - def _allocate_master_address(self): - """Allocate master address and port for rank 0.""" - if self.master_addr is None: - # Get the local IP address - self.master_addr = get_node_ip() - - if self.master_port is None: - # Allocate a free port - self.master_port = get_free_port() - - def get_master_address(self) -> tuple[Optional[str], Optional[int]]: - """Return the master address and port as a tuple.""" - return (self.master_addr, self.master_port) - - def get_free_port(self): - return get_free_port() - - def init_train_process_group(self): - """Initialize the distributed process group.""" - # Initialize process group - dist.init_process_group(timeout=timedelta(seconds=30)) - logger.info(f'is distributed initialized: {dist.is_initialized()}') - # Print debug information - num_visible_devices = torch.cuda.device_count() - logger.info(f'num_visible_devices: {num_visible_devices}') - logger.info('Ray actor init envs:') - 
logger.info(f'rank: {dist.get_rank()}') - logger.info(f'node_rank: {dist.get_rank() // 8}') - logger.info(f'world_size: {dist.get_world_size()}') - logger.info(f'local_rank: {dist.get_rank() % 8}') - logger.info(f'master_addr: {self.master_addr}') - logger.info(f'master_port: {self.master_port}') +class DistributedGPUActor(BaseDistributedGPUActor): + """Distributed GPU actor for testing.""" def init_model(self, model_name: str): """Initialize the model.""" @@ -132,7 +54,7 @@ def sync_weights(self, vllm_engines: list): dist.broadcast(p, src=0, group=self.model_update_group) ray.get(refs) - def tensor_all_reduce(self) -> float: + def test_tensor_all_reduce(self) -> float: """Perform a simple tensor all_reduce operation.""" # Create a tensor on the GPU and perform all_reduce device = torch.device('cuda') @@ -141,27 +63,6 @@ def tensor_all_reduce(self) -> float: return x.item() - def init_vllm_process_group( - self, - backend: str, - master_addr: str, - master_port: int, - world_size: int, - rank: int, - group_name: str, - ): - """Initialize the process group on trainer rank 0 and vllm engines.""" - # NOTE vLLM seems to have a safer implementation of init_process_group: - # https://github.com/vllm-project/vllm/blob/v0.9.1/examples/offline_inference/rlhf.py#L105 - # we should look into using that instead - self.model_update_group = init_process_group( - backend=backend, - init_method=f'tcp://{master_addr}:{master_port}', - world_size=world_size, - rank=rank, - group_name=group_name, - ) - @pytest.mark.gpu @world_size(4) @@ -228,7 +129,7 @@ def test_distributed_ray_actors( # Perform tensor all_reduce on all actors reduce_tasks = [ - actor.tensor_all_reduce.remote() # type: ignore + actor.test_tensor_all_reduce.remote() # type: ignore for actor in train_actors ] results = ray.get(reduce_tasks) @@ -275,7 +176,7 @@ def test_distributed_ray_actors( ) for i, engine in enumerate(vllm_engines) ] refs.append( - master_actor.init_vllm_process_group.remote( # type: ignore + 
master_actor.add_process_group.remote( # type: ignore backend='nccl', master_addr=master_addr, master_port=new_port, diff --git a/tests/test_single_controller_ppo.py b/tests/test_single_controller_ppo.py new file mode 100644 index 00000000..401683cf --- /dev/null +++ b/tests/test_single_controller_ppo.py @@ -0,0 +1,613 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +# run cmd: `cd compose-rl && cp tests/test_single_controller_ppo.py . +# && composer test_single_controller_ppo.py` + +import os +import pathlib +import time +from functools import partial +from typing import Any, Optional + +import pytest +import ray +import torch +import torch.distributed as dist +from composer import Trainer +from composer.core import get_precision_context +from composer.optim import DecoupledAdamW +from composer.utils import dist as composer_dist +from llmfoundry.models import ComposerHFCausalLM +from torch.utils.data import DataLoader +from transformers import ( + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, +) + +from compose_rl.algorithms.online import ( + ComposerHFPolicyLM, + SingleControllerOnPolicyCallback, +) +from compose_rl.algorithms.online.generation_utils import ( + broadcast_to_vllm, + create_vllm_engines, + vllm_generate, +) +from compose_rl.data import prompt_dataset_collate_fn +from compose_rl.utils.ray_utils import start_ray_server +from tests.common import ( + BaseDistributedGPUActor, + VerifiablePromptDataset, + world_size, +) + + +@ray.remote(num_gpus=1) +class DistributedGPUActor(BaseDistributedGPUActor): + """Distributed GPU actor for testing.""" + + def __init__( + self, + rank: int, + world_size: int, + master_addr: Optional[str] = None, + master_port: Optional[int] = None, + ): + super().__init__(rank, world_size, master_addr, master_port) + self.model = None + self.model_update_group = None + self.ref_path = None + self._dataloader = None + self._tokenizer = None + self.ppo_callback = None + 
self.ppo_trainer: Trainer = None # type: ignore + + self.pretrain_model_name = None + self.device_train_batch_size = None + self.num_batches_per_update = None + self.max_seq_len = None + self.precision: str = None # type: ignore + self.train_config: dict = None # type: ignore + + def build_train_config(self, pretrain_model_name: str): + self.pretrain_model_name = pretrain_model_name + self.device_train_batch_size = 4 + self.num_batches_per_update = 2 + self.max_seq_len = 32 + self.precision = 'amp_bf16' + + ref_model_config = {**self.model_config, 'name': 'hf_causal_lm'} + + variables = { + 'buffer': { + 'name': 'MinibatchRolloutBuffer', + 'max_buffer_size': self.num_batches_per_update, + }, + 'max_gen_len': 8, + 'gamma': 0.99, + 'lambda_gae': 0.95, + 'generation_kwargs': { + 'use_cache': True, + 'do_sample': False, + 'temperature': 1.0, + }, + 'kl_controller': { + 'init_kl_coef': 0.2, + 'target': 0.01, + 'horizon': 12800, + 'kl_ctl_type': 'adaptive', + }, + 'reference_model': { + 'model_config': ref_model_config, + 'precision': self.precision, + 'load_path': self.ref_path, + 'non_train_fsdp_config': self.fsdp_config, + }, + 'epoch_per_iteration': 1, + 'num_batches_per_update': self.num_batches_per_update, + 'rewards': { + 'output_length': { + 'reward_type': 'output_length', + 'max_gen_len': 10, + }, + }, + } + self.train_config = { + 'model': { + **self.model_config, + 'kl_estimator': 'k1', + 'kl_clip_range': 40.0, + }, + 'fsdp_config': + self.fsdp_config, + 'seed': + 17, + 'precision': + self.precision, + 'variables': + variables, + 'max_seq_len': + self.max_seq_len, + 'global_train_batch_size': + self.device_train_batch_size * self.world_size, + 'device_train_batch_size': + self.device_train_batch_size, + 'device_train_microbatch_size': + self.device_train_batch_size, + } + + def build_dataloader(self): + # TODO (infra): build prompt dataloader with rollout agent instead of + # trainer actor + max_seq_len = 32 + prompt_len = 10 + + dataset = 
VerifiablePromptDataset(prompt_len=prompt_len) + dataloader = DataLoader( + dataset, + collate_fn=partial( + prompt_dataset_collate_fn, + self.tokenizer, + max_seq_len, + ), + sampler=composer_dist.get_sampler(dataset), + batch_size=self.device_train_batch_size, + ) + # We need to mock this method, since our dataset isn't a + # StreamingDataset + dataloader.state_dict = lambda: {} + dataloader.load_state_dict = lambda x: None + return dataloader + + @property + def dataloader(self): + if self._dataloader is None: + self._dataloader = self.build_dataloader() + return self._dataloader + + def build_tokenizer(self): + # TODO (algo): decide if we should use tokens or messages given + # we may need token level log prob + # TODO (infra): use the tokenizer/texts for prompt dataloader but + # token (ids) for the experience buffer/manager + tokenizer = AutoTokenizer.from_pretrained(self.pretrain_model_name) + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + return tokenizer + + @property + def tokenizer(self): + if self._tokenizer is None: + self._tokenizer = self.build_tokenizer() + return self._tokenizer + + @property + def model_config(self): + return { + 'tokenizer': self.tokenizer, + 'pretrained_model_name_or_path': self.pretrain_model_name, + 'pretrained': True, + 'use_flash_attention_2': True, + 'allow_embedding_resizing': True, + } + + @property + def fsdp_config(self): + # TODO (infra): use actual fsdp1 config + return {} + + def init_composer_dist(self): + composer_dist.initialize_dist('gpu') + + def build_ref_model(self): + # pre-train a reference model for the PPO training + # The key observation here is that we should construct model + # training pipeline in the actor instead of the callback + # e.g., we can build ref/reward/policy/value model and create/colocate + # multiple trainers all in this class + tmp_ref_path = str('./ref_checkpoints') + ref_path = os.path.join(tmp_ref_path, 'latest-rank0.pt') + if os.path.exists(ref_path): + self.ref_path = 
ref_path + return + + tmp_model = ComposerHFCausalLM( + **self.model_config, + use_auth_token=True, + ) + + tmp_optimizer = DecoupledAdamW(tmp_model.parameters(), lr=1e-6) + + temp_dataloader = [{ + 'input_ids': torch.ones((2, 15)).to(dtype=torch.int64), + 'attention_mask': torch.ones((2, 15)), + 'labels': torch.ones((2, 15)).to(dtype=torch.int64), + }] + + temp_trainer = Trainer( + model=tmp_model, + train_dataloader=temp_dataloader, + optimizers=tmp_optimizer, + max_duration='1ba', + parallelism_config={'fsdp': self.fsdp_config}, + save_folder=tmp_ref_path, + save_weights_only=True, + device_train_microbatch_size=self. + device_train_microbatch_size, # type: ignore + ) + + temp_trainer.fit() + self.ref_path = ref_path + + def build_ppo_trainer(self): + composer_dist.initialize_dist('gpu') + + model = ComposerHFPolicyLM(**self.model_config, use_auth_token=True) + + optimizer = DecoupledAdamW(model.parameters(), lr=1e-8) + + # TODO (infra): pull the rest of the training logic from the callback + # to this class, e.g, how to interact with env, calculate rewards etc + self.ppo_callback = SingleControllerOnPolicyCallback( + train_config=self.train_config, + ) + self.ppo_trainer = Trainer( + model=model, + optimizers=optimizer, + callbacks=self.ppo_callback, + train_dataloader=self.dataloader, + precision=self.precision, + parallelism_config={'fsdp': self.fsdp_config}, + max_duration='3iter', + device_train_microbatch_size=1, + load_path=self.ref_path, + ) + + def train_1_iter(self): + # TODO (algo): implement the top level PPO algo here instead of the + # callback. 
Algorithmic researchers are expected to implement this + # function along with above policy/value/reward/ref trainers or models + # TODO (infra): try multiple fit to see if the (mlflow) logger, etc + # TODO (infra): fault tolerance at iteration level first + # TODO (infra): enable batch level control + self.ppo_trainer.fit(duration='1iter') + # This is the KL assert that must be true if we are truly loading + # from the same model. This is only true on the first iteration + assert torch.allclose( + self.ppo_trainer.state.loss['kl/ift_kl'], # pyright: ignore + torch.tensor(0.0), + atol=5e-5, + ) + + def update_inference_model(self, vllm_engines: list[Any]): + start_time = time.time() + print('Before broadcast to vLLM') + # TODO (infra) instead of directly broadcasting to vllm, we should + # push the model parameters to a parameter buffer manager and have + # the buffer manager initiate broadcast of parameters to vllm engines + broadcast_to_vllm( + self.ppo_callback.actor_critic, + vllm_engines, + self.model_update_group, + device=torch.device('cuda'), + loss_type=self.ppo_callback.actor_critic.loss_type, # type: ignore + ) + print('Finished broadcasting to vLLM') + print(f'Took: {time.time() - start_time} to broadcast to vllm.') + dist.barrier() + + def query_inference_engines(self, vllm_engines: list[Any]): + """Round trip to inference engines. + + Args: + vllm_engines (list[Any]): The vllm engines to round trip to. + """ + # TODO (infra): we should use the rollout agent to generate sequences + # instead of the trainer actor, e.g., 
reimplement _get_next_iter_prompts + # in the rollout agent + batch = self.ppo_trainer.state.device.batch_to_device( + self.ppo_callback._get_next_iter_prompts(), + ) + max_gen_len = self.train_config['variables']['max_gen_len'] + generation_kwargs = self.train_config['variables']['generation_kwargs'] + with get_precision_context(self.precision), torch.no_grad(): + # TODO (infra): refactor this code to isolate gather of + # prompts on the trainer actor and gather/scatter of sequences + # on the trainer actor, the first half is useless while + # the second half should be managed through an experience manager + sequences = vllm_generate( + vllm_engines=vllm_engines, + batch=batch, + max_gen_len=max_gen_len, + generation_kwargs=generation_kwargs, + tokenizer=self.tokenizer, # type: ignore + vllm_generate_function='generate', + ) + # Add the prepared sequences to the batch again + batch['sequences'] = sequences + self.ppo_callback.batch_rollouts = batch # type: ignore + + +def setup_process_groups( + master_actor: Any, + vllm_engines: list[Any], + vllm_tensor_parallel_size: int, +): + """Initialize process groups for vLLM engines and master actor.""" + # Get a new port for the weight-update process group + master_addr, _ = ray.get( + master_actor.get_master_address.remote(), + ) # type: ignore + new_port = ray.get(master_actor.get_free_port.remote()) # type: ignore + print(f'new_port: {new_port}') + + world_size = dist.get_world_size() + + # Initialize process groups for vLLM engines + refs = [ + engine.init_process_group.remote( + master_addr, + new_port, + i * vllm_tensor_parallel_size + 1, + world_size // 2 + 1, + 'weight-update', + backend='nccl', + ) for i, engine in enumerate(vllm_engines) + ] + + # Add master actor to the process group + refs.append( + master_actor.add_process_group.remote( + backend='nccl', + master_addr=master_addr, + master_port=new_port, + world_size=world_size // 2 + 1, + rank=0, + group_name='weight-update', + ), + ) + + # Wait for all 
process groups to be initialized + print(ray.get(refs)) + + +class SPMDActorGroup: + # TODO (infra): refactor this to a proper base class + + def __init__(self, num_train_actors: int): + self.num_train_actors = num_train_actors + + self._train_actors = [] + """Create and initialize all training actors.""" + print(f'\n=== STARTING DISTRIBUTED TRAINING WITH RAY ACTORS ===') + + # Create master actor first + self._master_actor = DistributedGPUActor.remote( + 0, + self.num_train_actors, + ) + self._train_actors.append(self._master_actor) + + # Get master address from rank 0 actor + master_addr, master_port = ray.get( + self._master_actor.get_master_address.remote(), # type: ignore + ) + print(f'Master address allocated: {master_addr}:{master_port}') + + # Create remaining actors with the master address/port + for i in range(1, self.num_train_actors): + actor = DistributedGPUActor.remote( + i, + self.num_train_actors, + master_addr, # type: ignore + master_port, + ) + self._train_actors.append(actor) + + @property + def train_actors(self): + return self._train_actors + + @property + def master_actor(self): + return self._master_actor + + +class TrainActorGroup(SPMDActorGroup): + # TODO: this class is mainly pass through gang scheduler, + # we should refactor this class to be more generic and reusable + + def build_models(self, pretrain_model_name: str): + """Build reference models and PPO trainers for all actors.""" + build_train_config_tasks = [ + actor.build_train_config.remote(pretrain_model_name) + for actor in self._train_actors + ] + ray.get(build_train_config_tasks) + + init_task = [ + actor.init_composer_dist.remote() for actor in self._train_actors + ] + ray.get(init_task) + + # Build reference models + build_ref_model_tasks = [ + actor.build_ref_model.remote() for actor in self._train_actors + ] + ray.get(build_ref_model_tasks) + print('build ref model done') + + # Build PPO trainers + build_ppo_trainer_tasks = [ + actor.build_ppo_trainer.remote() for actor in 
self._train_actors + ] + ray.get(build_ppo_trainer_tasks) + print('build ppo trainer done') + + def update_inference_model(self, vllm_engines: list[Any]): + refs = [ + actor.update_inference_model.remote(vllm_engines) + for actor in self._train_actors + ] + ray.get(refs) + print('update inference model done') + + def query_inference_engines(self, vllm_engines: list[Any]): + refs = [ + actor.query_inference_engines.remote(vllm_engines) + for actor in self._train_actors + ] + ray.get(refs) + print('query inference engines done') + + def train_iteration(self): + """Run one training iteration on all actors.""" + refs = [actor.train_1_iter.remote() for actor in self._train_actors] + ray.get(refs) + print('train 1 iter done') + + +class RolloutAgent: + + def __init__(self, vllm_engines: list, vllm_tensor_parallel_size: int): + self.vllm_engines = vllm_engines + self.vllm_tensor_parallel_size = vllm_tensor_parallel_size + + @property + def num_vllm_engines(self): + return len(self.vllm_engines) + + def generate(self, prompts: list[str]): + # TODO (infra): try integrate this with the multi-turn rollout + # repo + ref = self.vllm_engines[0].generate.remote(prompts) + gen_results = ray.get(ref) + for output in gen_results: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f'Prompt: {prompt!r}, Generated text: {generated_text!r}') + + +# TODO (infra): implement parameter buffer manager and experience manager +class PPOController: + + def __init__( + self, + train_actor: TrainActorGroup, + inference_client: RolloutAgent, + pretrain_model_name: str, + ): + self.train_actor = train_actor + self.inference_client = inference_client + + self.train_actor.build_models(pretrain_model_name) + setup_process_groups( + self.train_actor.master_actor, + self.inference_client.vllm_engines, + self.inference_client.vllm_tensor_parallel_size, + ) + + def train(self): + self.train_actor.update_inference_model( + self.inference_client.vllm_engines, + ) + 
self.train_actor.query_inference_engines( + self.inference_client.vllm_engines, + ) + self.train_actor.train_iteration() + + +def _run_single_controller_ppo( + pretrain_model_path: str, + world_size: int = 0, +): + """Shared function for running single controller PPO. + + Args: + pretrain_model_path: Path to the pretrained model + world_size: Number of distributed processes + prompts: List of prompts to test generation with + """ + # Set vLLM attention backend to FLASH_ATTN otherwise FlashInfer backend + # takes too long to jit compile + os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' + + prompts = [ + 'what is RAY?', + 'what is vLLM?', + ] + + with start_ray_server() as _address: + if dist.get_rank() == 0: + # only rank 0 is the master controller + + # create SPMD training actors of the system + if world_size == 0: + world_size = dist.get_world_size() + num_train_actors = world_size // 2 + train_actor = TrainActorGroup(num_train_actors) + + # Create vLLM engines (or inference actors) + vllm_tensor_parallel_size = world_size - num_train_actors + num_vllm_engines = ( + world_size - num_train_actors + ) // vllm_tensor_parallel_size + # TODO: Encapsulate this into a inference server manager class + vllm_engines = create_vllm_engines( + num_engines=num_vllm_engines, + tensor_parallel_size=vllm_tensor_parallel_size, + enforce_eager=True, + pretrain=pretrain_model_path, + revision=None, + seed=1, + enable_prefix_caching=False, + max_model_len=512, + device_bundle={ + 'GPU': 1, + 'CPU': 1, + 'worker_node': 0, + }, + ) + inference_client = RolloutAgent( + vllm_engines, + vllm_tensor_parallel_size, + ) + + ppo_controller = PPOController( + train_actor, + inference_client, + pretrain_model_path, + ) + ppo_controller.train() + + inference_client.generate(prompts) + + +@pytest.mark.gpu +@world_size(4) # TODO change this to 2 for CI testing (hit fatal python error) +def test_single_controller_ppo( + world_size: int, + tiny_llama_model: PreTrainedModel, + 
tiny_gpt2_tokenizer: PreTrainedTokenizerBase, + tmp_path: pathlib.Path, +): + """Test single controller PPO with Ray actors and vLLM engines.""" + # Save the model and tokenizer to a temporary directory + local_save_path = str(tmp_path / 'llama_model') + tiny_llama_model.save_pretrained(local_save_path) + tiny_gpt2_tokenizer.save_pretrained(local_save_path) + + _run_single_controller_ppo( + pretrain_model_path=local_save_path, + world_size=world_size, + ) + + +if __name__ == '__main__': + # This is an example of how to move the controller logic from PPO Callback + # to a separate trainer actor above and this main single controller + # function. + _run_single_controller_ppo( + pretrain_model_path='meta-llama/Llama-3.2-1B-Instruct', + )