diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py
index 982cccb9..e4d2ae9c 100644
--- a/compose_rl/algorithms/online/callback_utils.py
+++ b/compose_rl/algorithms/online/callback_utils.py
@@ -15,42 +15,45 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx
         padding_key = None
         for batch in batches:
-            # Explode the batch into multiple batches for each generation
-            for _ in range(generations_per_prompt):
-                # For keys that do not require additional processing
-                if key in [
-                    'prompt_len',
-                    'verified_answer',
-                    'prompt_id',
-                    'vstar',
-                    'messages',
-                ]:
-                    curr_values.append(batch[key])
-                    continue
-
-                bs, seq_len = batch[key].shape
-
-                if key == 'prompt':
-                    padding_key = pad_token_idx
-                    if (batch[key][:, -1] == padding_key).any():
-                        raise ValueError(
-                            'The last token in the prompt should not be the pad token. Please double ' +
-                            'check the dataloader and prompt and dataloader.',
-                        )
-                elif key == 'prompt_attention_mask':
-                    padding_key = False
-
-                # Compute the required padding and concatenate with the batch tensor
-                pad = torch.ones(
-                    (bs, max_len - seq_len),
-                    dtype=batch[key].dtype,
-                ) * padding_key  # type: ignore
-                curr_values.append(torch.cat([pad, batch[key]], dim=-1))
+            # inside the batch, it's a dictionary of tensors that have the batch dimension there,
+            # so we need to iterate through each element to explode it.
+            for item in batch[key]:
+                # Explode the batch into multiple batches for each generation
+                for _ in range(generations_per_prompt):
+                    # For keys that do not require additional processing
+                    if key in [
+                        'prompt_len',
+                        'verified_answer',
+                        'prompt_id',
+                        'vstar',
+                        'messages',
+                    ]:
+                        curr_values.append(item)
+                        continue
+
+                    seq_len, = item.shape  # expect this to be a 1D tensor
+
+                    if key == 'prompt':
+                        padding_key = pad_token_idx
+                        if (item[-1] == padding_key).any():
+                            raise ValueError(
+                                'The last token in the prompt should not be the pad token. Please double ' +
+                                'check the dataloader and prompt and dataloader.',
+                            )
+                    elif key == 'prompt_attention_mask':
+                        padding_key = False
+
+                    # Compute the required padding and concatenate with the batch tensor
+                    pad = torch.ones(
+                        max_len - seq_len,
+                        dtype=item.dtype,
+                    ) * padding_key  # type: ignore
+                    curr_values.append(torch.cat([pad, item], dim=-1))
 
         # For tensor fields, use torch.cat to combine the values; for string fields, just use the list
         if isinstance(curr_values[0], torch.Tensor):
-            ret_batch[key] = torch.cat(curr_values)
+            ret_batch[key] = torch.stack(curr_values)
         else:
             if key in ['verified_answer', 'vstar']:
                 ret_batch[key] = list(flatten(curr_values))
diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py
index ad756d4f..66d7305e 100644
--- a/test_single_controller_ppo.py
+++ b/test_single_controller_ppo.py
@@ -62,6 +62,17 @@ def time_it(name: str):
         print(f"[{name}] took {end_time - start_time:.2f} seconds")
 
 
+def get_and_validate_num_prompts_per_iteration(config: Any):
+    generations_per_prompt = config.variables.generations_per_prompt
+    num_batches_per_update = config.variables.num_batches_per_update
+    total_num_generations = config.global_train_batch_size * num_batches_per_update
+    num_prompts_per_iteration = total_num_generations // generations_per_prompt
+
+    assert total_num_generations % generations_per_prompt == 0, "total_num_generations must be divisible by generations_per_prompt"
+
+    return num_prompts_per_iteration
+
+
 class DistributedGPUActor(BaseDistributedGPUActor):
     """Distributed GPU actor for testing."""
@@ -628,14 +639,8 @@ def __init__(self, config: Any):
             self.dataloader_config['dataset']['local'].format(timestamp=timestamp)
 
         # Key variables
-        global_train_batch_size = config.global_train_batch_size
         self.generations_per_prompt = config.variables.generations_per_prompt
-        num_batches_per_update = config.variables.num_batches_per_update
-        total_num_generations = global_train_batch_size * num_batches_per_update
-        self.num_prompts_per_iteration = total_num_generations // self.generations_per_prompt
-
-        # Validate that the total number of generations is divisible by the number of generations per prompt
-        assert total_num_generations % self.generations_per_prompt == 0, "total_num_generations must be divisible by generations_per_prompt"
+        self.num_prompts_per_iteration = get_and_validate_num_prompts_per_iteration(config)
 
         # Creating main entities
         self.tokenizer = self._build_tokenizer()
@@ -768,6 +773,9 @@ def _run_single_controller_ppo(
         config=config,
     )
 
+    num_prompts_per_iteration = get_and_validate_num_prompts_per_iteration(config)
+    assert num_prompts_per_iteration % num_train_actors == 0, "Number of prompts per iteration must be divisible by number of train actors to ensure accurate advantage calculations."
+
     # We are using a CPU worker for the StreamingActor
     # and this involves a super hacky workaround by
     # uninstalling megablocks if it exists. Better solutions
@@ -785,6 +793,8 @@ def _run_single_controller_ppo(
     streaming_dataset_actor = ray.remote(num_gpus=0)(StreamingDatasetActor).remote(config)
 
     rollout_agent = RolloutAgent(inference_server, streaming_dataset_actor, config)
+
+
     # EvalAgent doesn't need to be a Ray actor since we don't need to
     # set a world_size or use GPUs for this process.
     eval_agent = EvalAgent(inference_server.engines, config)