From 15c6ae347c0a82e25bf4c2506d7f90eef1bdf30e Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 10:59:41 -0700 Subject: [PATCH 001/109] adding logging to understand weight updates --- .../algorithms/online/generation_utils/vllm_actor.py | 1 + .../algorithms/online/generation_utils/vllm_utils.py | 11 +++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_actor.py b/compose_rl/algorithms/online/generation_utils/vllm_actor.py index 46d84ede..dc5377a1 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_actor.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_actor.py @@ -124,6 +124,7 @@ def update_weight( shape: Union[tuple[int, ...], list[int]], empty_cache: bool = False, ): + log.info(f'Updating weight {name} with shape {shape} and dtype {dtype}') return self.llm.collective_rpc( 'update_weight', args=(name, dtype, shape, empty_cache), diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 4da62390..69620038 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -456,6 +456,7 @@ def broadcast_to_vllm( # This is needed otherwise FSDP will materialize parameters of size 0. # So just for the joint actor critic models we have to actually skip this module. if module_name == 'model' and loss_type == OnPolicyEnum.PPO: + log.info('Skipping model module') continue # Only update if we haven't updated this module before @@ -497,14 +498,16 @@ def broadcast_to_vllm( count += 1 shape = param.shape - refs = [ - engine.update_weight.remote( + refs = [] + for engine in vllm_engines: + log.info(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype} and shape {shape}") + ref = engine.update_weight.remote( parsed_name, dtype=param.dtype, shape=shape, empty_cache=(count == num_params), - ) for engine in vllm_engines - ] + ) + refs.append(ref) refss.extend(refs) torch.distributed.broadcast( param.data, From 4ed2322d20dd7a988efcff14075a2a8d456fb15f Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 11:08:45 -0700 Subject: [PATCH 002/109] assert false --- compose_rl/algorithms/online/generation_utils/vllm_actor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_actor.py b/compose_rl/algorithms/online/generation_utils/vllm_actor.py index dc5377a1..d6f57321 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_actor.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_actor.py @@ -124,6 +124,7 @@ def update_weight( shape: Union[tuple[int, ...], list[int]], empty_cache: bool = False, ): + assert False, "weight update" log.info(f'Updating weight {name} with shape {shape} and dtype {dtype}') return self.llm.collective_rpc( 'update_weight', From 453a6103323711419ffb40bce322702d35374552 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 11:20:47 -0700 Subject: [PATCH 003/109] logging more updateS --- compose_rl/algorithms/online/generation_utils/vllm_actor.py | 5 ++--- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 +++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_actor.py b/compose_rl/algorithms/online/generation_utils/vllm_actor.py index d6f57321..16be2463 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_actor.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_actor.py @@ -124,9 +124,8 @@ def update_weight( shape: Union[tuple[int, ...], list[int]], empty_cache: bool = False, ): - assert False, "weight update" - log.info(f'Updating weight {name} with shape {shape} and dtype {dtype}') - return self.llm.collective_rpc( + update_str = f'Updating weight {name} with shape {shape} and dtype {dtype}' + return update_str, self.llm.collective_rpc( 'update_weight', args=(name, dtype, shape, empty_cache), ) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 69620038..d4ccfe47 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -527,7 +527,9 @@ def broadcast_to_vllm( log.info(f'for loop took: {time.time() - start_time}') start_time = time.time() - ray.get(refss) + results = ray.get(refss) + for result in results: + log.info(result[0]) if enable_prefix_caching: ray.get(cache_reset_refss) log.info(f'ray refs took: {time.time() - start_time}') From 5c15cb8c9b4c5704d109dbe68f9e4c0fd5b7e35f Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 11:45:59 -0700 Subject: [PATCH 004/109] trying out llama 1b --- yamls/single-controller-grpo-workflow.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 8227ae06..f901d08e 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -46,7 +46,7 @@ parameters: normalize_advantage: true use_flash_attention_2: true length_normalize_policy_loss: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct loggers: mlflow: tags: @@ -90,7 +90,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + name: meta-llama/Llama-3.2-1B-Instruct kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -127,7 +127,7 @@ parameters: kl_controller: kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + tokenizer_name: meta-llama/Llama-3.2-1B-Instruct num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -137,7 +137,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct generation_kwargs: top_p: 1 do_sample: true From 2bd7bc9d291f77566ad39effe4c14b66d777e6bf Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 11:50:09 -0700 Subject: [PATCH 005/109] fix loading --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index db0f473e..055eabdd 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -748,7 +748,7 @@ def _run_single_controller_ppo( # Load configuration using OmegaConf if args.file_path is None: - config = om.load("yamls/single-controller-grpo-workflow.yaml") + config = om.load("yamls/single-controller-grpo-workflow.yaml").parameters else: config = om.load(args.file_path) From a8d817bda6c207dbd6c064f8b91500a4390f1695 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 12:05:55 -0700 Subject: [PATCH 006/109] different dataset --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index f901d08e..96fe5fdb 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -187,7 +187,7 @@ parameters: dataset: local: /tmp/dataset/prompt_{timestamp}/ split: train - remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/open_r1_filtered/dpsk_8b_open_r1_48k/ + remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/math_lighteval/llama3_1b_math_prompts/ shuffle: true max_gen_len: 8192 max_seq_len: 10240 From 5d642187326506aa99c78d1cb7e8612cbfa654c4 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 12:23:41 -0700 Subject: [PATCH 007/109] revert to r1 --- yamls/single-controller-grpo-workflow.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 96fe5fdb..8227ae06 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -46,7 +46,7 @@ parameters: normalize_advantage: true use_flash_attention_2: true length_normalize_policy_loss: true - pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B loggers: mlflow: tags: @@ -90,7 +90,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: meta-llama/Llama-3.2-1B-Instruct + name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -127,7 +127,7 @@ parameters: kl_controller: kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: meta-llama/Llama-3.2-1B-Instruct + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -137,7 +137,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B generation_kwargs: top_p: 1 do_sample: true @@ -187,7 +187,7 @@ parameters: dataset: local: /tmp/dataset/prompt_{timestamp}/ split: train - remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/math_lighteval/llama3_1b_math_prompts/ + remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/open_r1_filtered/dpsk_8b_open_r1_48k/ shuffle: true max_gen_len: 8192 max_seq_len: 10240 From 3cf44321b824d9be1f8aeec549d064739aba977f Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 14:09:53 -0700 Subject: [PATCH 008/109] trying out 2 nodes --- yamls/single-controller-grpo-workflow.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 8227ae06..25d9c990 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -6,8 +6,8 @@ scheduling: resumable: false preemptible: false compute: - gpus: 8 - cluster: r5z2p3 + gpus: 16 + cluster: r5z2p1 instance: oci.bm.gpu.h200.8.oke integrations: - integration_type: git_repo From f7184be6f7ca4770151efa7fa689bd2999fd776c Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 14:28:48 -0700 Subject: [PATCH 009/109] test --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 25d9c990..50bd5081 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon + git_branch: ethantang-db/fix_multi_node - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe From 7e66c923c403a23b8dbae671ef37a74f0d616009 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 16:31:58 -0700 Subject: [PATCH 010/109] log worker_wrap logic --- compose_rl/algorithms/online/generation_utils/vllm_actor.py | 3 +-- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_actor.py b/compose_rl/algorithms/online/generation_utils/vllm_actor.py index 16be2463..46d84ede 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_actor.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_actor.py @@ -124,8 +124,7 @@ def update_weight( shape: Union[tuple[int, ...], list[int]], empty_cache: bool = False, ): - update_str = f'Updating weight {name} with shape {shape} and dtype {dtype}' - return update_str, self.llm.collective_rpc( + return self.llm.collective_rpc( 'update_weight', args=(name, dtype, shape, empty_cache), ) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index d4ccfe47..2029cff3 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -157,6 +157,7 @@ def update_weight( shape (Union[Tuple[int, ...], List[int], torch.Size]): Shape of the weight empty_cache (bool): Whether to empty cache after updating weights """ + log.info(f"Updating weight {name} in worker_wrap with shape {shape} and dtype {dtype}") weight = torch.empty(shape, dtype=dtype, device='cuda') torch.distributed.broadcast( weight, @@ -528,8 +529,7 @@ def broadcast_to_vllm( log.info(f'for loop took: {time.time() - start_time}') start_time = time.time() results = ray.get(refss) - for result in results: - log.info(result[0]) + log.info(f'results: {results}') if enable_prefix_caching: ray.get(cache_reset_refss) log.info(f'ray refs took: {time.time() - start_time}') From 7e1d36564c61eadcf52b8aff5a154b5e7755fcde Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 16:37:40 -0700 Subject: [PATCH 011/109] force crash --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 2029cff3..adc179df 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -157,7 +157,7 @@ def update_weight( shape (Union[Tuple[int, ...], List[int], torch.Size]): Shape of the weight empty_cache (bool): Whether to empty cache after updating weights """ - log.info(f"Updating weight {name} in worker_wrap with shape {shape} and dtype {dtype}") + assert False, "update_weight in worker wrap called" weight = torch.empty(shape, dtype=dtype, device='cuda') torch.distributed.broadcast( weight, From 7c1067025216c8c26f5f8f224fe942f2c3a57864 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 16:54:47 -0700 Subject: [PATCH 012/109] removing assert --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 1 - yamls/single-controller-grpo-workflow.yaml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index adc179df..e208dcfd 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -157,7 +157,6 @@ def update_weight( shape (Union[Tuple[int, ...], List[int], torch.Size]): Shape of the weight empty_cache (bool): Whether to empty cache after updating weights """ - assert False, "update_weight in worker wrap called" weight = torch.empty(shape, dtype=dtype, device='cuda') torch.distributed.broadcast( weight, diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 50bd5081..a19dff40 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -6,7 +6,7 @@ scheduling: resumable: false preemptible: false compute: - gpus: 16 + gpus: 8 cluster: r5z2p1 instance: oci.bm.gpu.h200.8.oke integrations: From b0a3467f35ce9c4ecc1c07d4d0604fe40665ed93 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 17:10:57 -0700 Subject: [PATCH 013/109] jank logging --- .../algorithms/online/generation_utils/vllm_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index e208dcfd..9e3b4790 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -129,6 +129,17 @@ def init_process_group( assert group_name != '', 'group name must not be empty' rank = torch.distributed.get_rank() + rank_offset + + with open(f'/tmp/compose-rl-init_process_group_log_{rank}.txt', 'w') as f: + f.write(f'torch_rank: {torch.distributed.get_rank()}\n') + f.write(f'offset: {rank_offset}\n') + f.write(f'rank: {rank}\n') + f.write(f'world_size: {world_size}\n') + f.write(f'group_name: {group_name}\n') + f.write(f'backend: {backend}\n') + f.write(f'master_address: {master_address}\n') + f.write(f'master_port: {master_port}\n') + self._model_update_group = init_process_group( # type: ignore backend=backend, init_method=f'tcp://{master_address}:{master_port}', From ccd2e4cf1eaa64119d3f45d7968ba8b371156de3 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 17:27:16 -0700 Subject: [PATCH 014/109] try gloo? --- test_single_controller_ppo.py | 4 ++-- yamls/single-controller-grpo-workflow.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 055eabdd..2d0a9bcf 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -299,14 +299,14 @@ def setup_process_groups( i * vllm_tensor_parallel_size + 1, world_size // 2 + 1, 'weight-update', - backend='nccl', + backend='gloo', ) for i, engine in enumerate(vllm_engines) ] # Add master actor to the process group refs.append( master_actor.add_process_group.remote( - backend='nccl', + backend='gloo', master_addr=master_addr, master_port=new_port, world_size=world_size // 2 + 1, diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index a19dff40..50bd5081 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -6,7 +6,7 @@ scheduling: resumable: false preemptible: false compute: - gpus: 8 + gpus: 16 cluster: r5z2p1 instance: oci.bm.gpu.h200.8.oke integrations: From 34ce4fb1577784817b7c85195926d952ec0d3ded Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 9 Aug 2025 19:11:18 -0700 Subject: [PATCH 015/109] revert back to nccl --- test_single_controller_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 2d0a9bcf..055eabdd 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -299,14 +299,14 @@ def setup_process_groups( i * vllm_tensor_parallel_size + 1, world_size // 2 + 1, 'weight-update', - backend='gloo', + backend='nccl', ) for i, engine in enumerate(vllm_engines) ] # Add master actor to the process group refs.append( master_actor.add_process_group.remote( - backend='gloo', + backend='nccl', master_addr=master_addr, master_port=new_port, world_size=world_size // 2 + 1, From 536e5133a97de7f0fe257045b695d25d56cd065d Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 11:40:13 -0700 Subject: [PATCH 016/109] try out cpu and gloo --- test_single_controller_ppo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 055eabdd..c9e50d3e 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -299,14 +299,14 @@ def setup_process_groups( i * vllm_tensor_parallel_size + 1, world_size // 2 + 1, 'weight-update', - backend='nccl', + backend='gloo', ) for i, engine in enumerate(vllm_engines) ] # Add master actor to the process group refs.append( master_actor.add_process_group.remote( - backend='nccl', + backend='gloo', master_addr=master_addr, master_port=new_port, world_size=world_size // 2 + 1, @@ -513,7 +513,7 @@ def update_inference_model(self, actor: DistributedGPUActor, inference_server: I actor.ppo_callback.actor_critic, inference_server.engines, actor.model_update_group, - device=torch.device('cuda'), + device=torch.device('cpu'), loss_type=actor.ppo_callback.actor_critic.loss_type, # type: ignore ) print('Finished broadcasting to vLLM') From 2f4de01d74f14eee60af6e40b7116018a6d34e39 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 12:58:31 -0700 Subject: [PATCH 017/109] log tensors to file --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 ++++ test_single_controller_ppo.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 9e3b4790..75bab013 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -174,6 +174,8 @@ def update_weight( 0, group=self._model_update_group, ) + with open(f"/tmp/compose-rl-worker-{torch.distributed.get_rank(self._model_update_group)}.txt", "a") as f: + f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type @@ -520,6 +522,8 @@ def broadcast_to_vllm( ) refs.append(ref) refss.extend(refs) + with open(f"/tmp/compose-rl-master.txt", "a") as f: + f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype} and shape {shape} with data {param.data[..., :3]}\n") torch.distributed.broadcast( param.data, 0, diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index c9e50d3e..2d0a9bcf 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -513,7 +513,7 @@ def update_inference_model(self, actor: DistributedGPUActor, inference_server: I actor.ppo_callback.actor_critic, inference_server.engines, actor.model_update_group, - device=torch.device('cpu'), + device=torch.device('cuda'), loss_type=actor.ppo_callback.actor_critic.loss_type, # type: ignore ) print('Finished broadcasting to vLLM') From c96a8d15aa7bb98a60042387fa04d883bc8c8c8a Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 13:04:52 -0700 Subject: [PATCH 018/109] better logging --- .../algorithms/online/generation_utils/vllm_utils.py | 11 ++++++----- test_single_controller_ppo.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 75bab013..3fbbccac 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -174,8 +174,9 @@ def update_weight( 0, group=self._model_update_group, ) - with open(f"/tmp/compose-rl-worker-{torch.distributed.get_rank(self._model_update_group)}.txt", "a") as f: - f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") + if "5" in name: + with open(f"/tmp/compose-rl-worker-{torch.distributed.get_rank(self._model_update_group)}.txt", "a") as f: + f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type @@ -513,7 +514,6 @@ def broadcast_to_vllm( shape = param.shape refs = [] for engine in vllm_engines: - log.info(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype} and shape {shape}") ref = engine.update_weight.remote( parsed_name, dtype=param.dtype, @@ -522,8 +522,9 @@ def broadcast_to_vllm( ) refs.append(ref) refss.extend(refs) - with open(f"/tmp/compose-rl-master.txt", "a") as f: - f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype} and shape {shape} with data {param.data[..., :3]}\n") + if "5" in parsed_name: + with open(f"/tmp/compose-rl-master.txt", "a") as f: + f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype} and shape {shape} with data {param.data[..., :3]}\n") torch.distributed.broadcast( param.data, 0, diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 2d0a9bcf..055eabdd 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -299,14 +299,14 @@ def setup_process_groups( i * vllm_tensor_parallel_size + 1, world_size // 2 + 1, 'weight-update', - backend='gloo', + backend='nccl', ) for i, engine in enumerate(vllm_engines) ] # Add master actor to the process group refs.append( master_actor.add_process_group.remote( - backend='gloo', + backend='nccl', master_addr=master_addr, master_port=new_port, world_size=world_size // 2 + 1, From ac6b8d927ee0858a0c9f33b365ec315ae2e7db9f Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 13:17:26 -0700 Subject: [PATCH 019/109] removed redundent debugging --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 3fbbccac..7e960c85 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -543,8 +543,7 @@ def broadcast_to_vllm( log.info(f'for loop took: {time.time() - start_time}') start_time = time.time() - results = ray.get(refss) - log.info(f'results: {results}') + ray.get(refss) if enable_prefix_caching: ray.get(cache_reset_refss) log.info(f'ray refs took: {time.time() - start_time}') From 24b0d93a4d531b53dfdf48b12a764eba56c0daa5 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 14:14:40 -0700 Subject: [PATCH 020/109] rank --- .../algorithms/online/generation_utils/vllm_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 7e960c85..4deadcbc 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -140,6 +140,7 @@ def init_process_group( f.write(f'master_address: {master_address}\n') f.write(f'master_port: {master_port}\n') + self._rank = rank self._model_update_group = init_process_group( # type: ignore backend=backend, init_method=f'tcp://{master_address}:{master_port}', @@ -174,8 +175,8 @@ def update_weight( 0, group=self._model_update_group, ) - if "5" in name: - with open(f"/tmp/compose-rl-worker-{torch.distributed.get_rank(self._model_update_group)}.txt", "a") as f: + if ".5." in name: + with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this @@ -522,7 +523,7 @@ def broadcast_to_vllm( ) refs.append(ref) refss.extend(refs) - if "5" in parsed_name: + if ".5." in parsed_name: with open(f"/tmp/compose-rl-master.txt", "a") as f: f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype} and shape {shape} with data {param.data[..., :3]}\n") torch.distributed.broadcast( From f95a801adbe2982f94c9cbc9a7fbe27847998ecb Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 15:41:26 -0700 Subject: [PATCH 021/109] try env vars --- compose_rl/algorithms/online/generation_utils/vllm_actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_actor.py b/compose_rl/algorithms/online/generation_utils/vllm_actor.py index 46d84ede..d000e46b 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_actor.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_actor.py @@ -33,7 +33,7 @@ log = logging.getLogger(__name__) -@ray.remote +@ray.remote(runtime_env={"env_vars": {"NCCL_CUMEM_ENABLE": "0"}}) class LLMRayActor: def __init__( From b0a34419c195c7e250c5ed70893ef748d99f7c9b Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 16:06:59 -0700 Subject: [PATCH 022/109] try out other place for nccl --- compose_rl/algorithms/online/generation_utils/vllm_actor.py | 2 +- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 ++ test_single_controller_ppo.py | 6 ++++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_actor.py b/compose_rl/algorithms/online/generation_utils/vllm_actor.py index d000e46b..1272d683 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_actor.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_actor.py @@ -33,7 +33,7 @@ log = logging.getLogger(__name__) -@ray.remote(runtime_env={"env_vars": {"NCCL_CUMEM_ENABLE": "0"}}) +@ray.remote() class LLMRayActor: def __init__( diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 4deadcbc..fd9bb993 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -178,6 +178,8 @@ def update_weight( if ".5." in name: with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") + import os + f.write(f"os.environ['NCCL_CUMEM_ENABLE'] = {os.environ['NCCL_CUMEM_ENABLE']}\n") # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 055eabdd..7fc07bb3 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -375,6 +375,12 @@ class InferenceServer: """Inference server with vLLM engines.""" def __init__(self, num_vllm_engines: int, pretrain_model_name: str, config: Any): + import os + if os.getenv('NODE_RANK', None) == '0' and os.getenv('LOCAL_RANK', None) == '0': + os.environ['NCCL_CUMEM_ENABLE'] = '0' + os.environ['RAY_BACKEND_LOG_LEVEL'] = 'DEBUG' + os.environ['RAY_DEBUG_LOGS'] = '1' + self.num_vllm_engines = num_vllm_engines self.vllm_tensor_parallel_size = config.vllm_tensor_parallel_size self.vllm_engines = create_vllm_engines( From 55989427d6d4e49fd8c09de17fb94710ab6274ae Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 16:09:21 -0700 Subject: [PATCH 023/109] f... --- compose_rl/algorithms/online/generation_utils/vllm_actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_actor.py b/compose_rl/algorithms/online/generation_utils/vllm_actor.py index 1272d683..46d84ede 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_actor.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_actor.py @@ -33,7 +33,7 @@ log = logging.getLogger(__name__) -@ray.remote() +@ray.remote class LLMRayActor: def __init__( From febdc6971dc6f738e1db853a1c97b83401334f79 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 16:53:32 -0700 Subject: [PATCH 024/109] further debugging --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index fd9bb993..080258f5 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -180,6 +180,8 @@ def update_weight( f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") import os f.write(f"os.environ['NCCL_CUMEM_ENABLE'] = {os.environ['NCCL_CUMEM_ENABLE']}\n") + f.write(f"model_type = {type(self.model_runner.model)}\n") + f.write(f"model_methods = {dir(self.model_runner.model)}\n") # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type From 05085f37babfd2f0db80116f6f137a9e1a8a9603 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 17:00:12 -0700 Subject: [PATCH 025/109] log what weights are updated --- .../online/generation_utils/vllm_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 080258f5..5ced8b61 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -175,23 +175,26 @@ def update_weight( 0, group=self._model_update_group, ) - if ".5." in name: - with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") - import os - f.write(f"os.environ['NCCL_CUMEM_ENABLE'] = {os.environ['NCCL_CUMEM_ENABLE']}\n") - f.write(f"model_type = {type(self.model_runner.model)}\n") - f.write(f"model_methods = {dir(self.model_runner.model)}\n") # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type if weight.dtype != self.model_config.dtype: # type: ignore weight = weight.to(self.model_config.dtype) # type: ignore - self.model_runner.model.load_weights( # type: ignore + updated_weights = self.model_runner.model.load_weights( # type: ignore weights=[(name, weight)], ) # type: ignore + if ".5." in name: + with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: + f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") + f.write(f"updated_weights = {updated_weights}\n") + import os + f.write(f"os.environ['NCCL_CUMEM_ENABLE'] = {os.environ['NCCL_CUMEM_ENABLE']}\n") + f.write(f"model_type = {type(self.model_runner.model)}\n") + f.write(f"model_methods = {dir(self.model_runner.model)}\n") + f.write(f"updated_weights = {updated_weights}\n") + del weight if empty_cache: From 0ebb94ba6ce2d88fe37451e206de25d763cb236d Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 17:06:53 -0700 Subject: [PATCH 026/109] log weight updates --- .../algorithms/online/generation_utils/vllm_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 5ced8b61..4987adb2 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -185,15 +185,14 @@ def update_weight( weights=[(name, weight)], ) # type: ignore + updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == name][0] + if ".5." in name: with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") - f.write(f"updated_weights = {updated_weights}\n") - import os - f.write(f"os.environ['NCCL_CUMEM_ENABLE'] = {os.environ['NCCL_CUMEM_ENABLE']}\n") f.write(f"model_type = {type(self.model_runner.model)}\n") - f.write(f"model_methods = {dir(self.model_runner.model)}\n") f.write(f"updated_weights = {updated_weights}\n") + f.write(f"updated_weight_tensor = {updated_weight_tensor[..., :3]}\n") del weight From 2673c669e5dd05d67a101e5c9de125fff3fdf51a Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 17:13:15 -0700 Subject: [PATCH 027/109] update weights --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 4987adb2..9c2a1ce3 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -185,7 +185,7 @@ def update_weight( weights=[(name, weight)], ) # type: ignore - updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == name][0] + updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == updated_weights[0][0]][0] if ".5." in name: with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: From d3c1d2054d4c950a748d0354130f5c1361e135d3 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 17:25:30 -0700 Subject: [PATCH 028/109] this is trippin --- .../online/generation_utils/vllm_utils.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 9c2a1ce3..eac57451 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -185,14 +185,29 @@ def update_weight( weights=[(name, weight)], ) # type: ignore - updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == updated_weights[0][0]][0] - - if ".5." in name: + if len(updated_weights) == 0: with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") - f.write(f"model_type = {type(self.model_runner.model)}\n") - f.write(f"updated_weights = {updated_weights}\n") - f.write(f"updated_weight_tensor = {updated_weight_tensor[..., :3]}\n") + f.write(f"Weight {name} not found in model\n") + else: + write = True + if isinstance(updated_weights, list): + updated_weights = updated_weights[0][0] + elif isinstance(updated_weights, set): + updated_weights = updated_weights.pop() + else: + with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: + f.write(f"updated_weights = {updated_weights}\n") + f.write(f"type(updated_weights) = {type(updated_weights)}\n") + write = False + + if write: + updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == updated_weights][0] + + if ".5." in name: + with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: + f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") + f.write(f"updated_weights = {updated_weights}\n") + f.write(f"updated_weight_tensor = {updated_weight_tensor[..., :3]}\n") del weight From 840e0d89cb2f42cb56c530d1c763b414905a06ca Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 18:27:34 -0700 Subject: [PATCH 029/109] better weight logging --- .../online/generation_utils/vllm_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index eac57451..2bb528ce 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -205,9 +205,19 @@ def update_weight( if ".5." in name: with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight[..., :3]}\n") + if len(shape) == 2: + weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" + updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" + elif len(shape) == 1: + weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" + updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" + else: + weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" + updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" + + f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight_str}\n") f.write(f"updated_weights = {updated_weights}\n") - f.write(f"updated_weight_tensor = {updated_weight_tensor[..., :3]}\n") + f.write(f"updated_weight_tensor = {updated_weight_tensor_str}\n") del weight From b9dfda7cdd15439947200729328ba6cb8f2c24ba Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 18:34:13 -0700 Subject: [PATCH 030/109] like cursor bruh? --- .../algorithms/online/generation_utils/vllm_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 2bb528ce..82788690 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -209,11 +209,12 @@ def update_weight( weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" elif len(shape) == 1: - weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" - updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" + weight_str = f"{weight[:10]}, ... {weight[-10:]}" + updated_weight_tensor_str = f"{updated_weight_tensor[:10]}, ... {updated_weight_tensor[-10:]}" else: - weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" - updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" + weight_str = f"{weight[..., :10]}, ... {weight[..., -10:]}" + updated_weight_tensor_str = f"{updated_weight_tensor[..., :10]}, ... {updated_weight_tensor[..., -10:]}" + f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight_str}\n") f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight_str}\n") f.write(f"updated_weights = {updated_weights}\n") From 50475d5a721975fc23f416266155c37901cd1396 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 18:43:29 -0700 Subject: [PATCH 031/109] better logs --- .../algorithms/online/generation_utils/vllm_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 82788690..83fe01d1 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -214,11 +214,10 @@ def update_weight( else: weight_str = f"{weight[..., :10]}, ... {weight[..., -10:]}" updated_weight_tensor_str = f"{updated_weight_tensor[..., :10]}, ... {updated_weight_tensor[..., -10:]}" - f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight_str}\n") - - f.write(f"Received weight {name} with shape {shape} and dtype {dtype} with data {weight_str}\n") + f.write(f"Received weight {name} with shape {shape} and dtype {dtype}\n") + f.write(f"size = {weight.size()}, weight = {weight_str}\n") f.write(f"updated_weights = {updated_weights}\n") - f.write(f"updated_weight_tensor = {updated_weight_tensor_str}\n") + f.write(f"size = {updated_weight_tensor.size()}, weight = {updated_weight_tensor_str}\n") del weight From c80e25f525a960fbbcdec90fe0619ac93b14f4ef Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 18:56:28 -0700 Subject: [PATCH 032/109] ??? --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 83fe01d1..6724a675 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -190,6 +190,12 @@ def update_weight( f.write(f"Weight {name} not found in model\n") else: write = True + if len(updated_weights) > 1: + with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: + f.write(f"multiple updated_weights = {updated_weights}\n") + f.write(f"type(updated_weights) = {type(updated_weights)}\n") + write = False + if isinstance(updated_weights, list): updated_weights = updated_weights[0][0] elif isinstance(updated_weights, set): From 2c8d11fa57b444fda917ffd3480507ae998b5ea1 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 19:48:09 -0700 Subject: [PATCH 033/109] ??? --- .../online/generation_utils/vllm_utils.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 6724a675..e5be8695 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -176,6 +176,14 @@ def update_weight( group=self._model_update_group, ) + if ".5." in name: + if len(shape) == 2: + weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" + elif len(shape) == 1: + weight_str = f"{weight[:10]}, ... {weight[-10:]}" + else: + weight_str = f"{weight[..., :10]}, ... {weight[..., -10:]}" + # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type if weight.dtype != self.model_config.dtype: # type: ignore @@ -212,13 +220,10 @@ def update_weight( if ".5." in name: with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: if len(shape) == 2: - weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" elif len(shape) == 1: - weight_str = f"{weight[:10]}, ... {weight[-10:]}" updated_weight_tensor_str = f"{updated_weight_tensor[:10]}, ... {updated_weight_tensor[-10:]}" else: - weight_str = f"{weight[..., :10]}, ... {weight[..., -10:]}" updated_weight_tensor_str = f"{updated_weight_tensor[..., :10]}, ... {updated_weight_tensor[..., -10:]}" f.write(f"Received weight {name} with shape {shape} and dtype {dtype}\n") f.write(f"size = {weight.size()}, weight = {weight_str}\n") @@ -562,7 +567,14 @@ def broadcast_to_vllm( refss.extend(refs) if ".5." in parsed_name: with open(f"/tmp/compose-rl-master.txt", "a") as f: - f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype} and shape {shape} with data {param.data[..., :3]}\n") + if len(shape) == 2: + weight_str = f"{param.data[0, :10]}, ... {param.data[-1, -10:]}" + elif len(shape) == 1: + weight_str = f"{param.data[:10]}, ... {param.data[-10:]}" + else: + weight_str = f"{param.data[..., :10]}, ... {param.data[..., -10:]}" + f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype}\n") + f.write(f"size = {shape}, weight = {weight_str}\n") torch.distributed.broadcast( param.data, 0, From 1de2fff51515585b5d99f32ae18b866aa86ae398 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 20:06:48 -0700 Subject: [PATCH 034/109] try layer 25 --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index e5be8695..ac849495 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -176,7 +176,7 @@ def update_weight( group=self._model_update_group, ) - if ".5." in name: + if ".25." in name: if len(shape) == 2: weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" elif len(shape) == 1: @@ -217,7 +217,7 @@ def update_weight( if write: updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == updated_weights][0] - if ".5." in name: + if ".25." in name: with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: if len(shape) == 2: updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" @@ -565,7 +565,7 @@ def broadcast_to_vllm( ) refs.append(ref) refss.extend(refs) - if ".5." in parsed_name: + if ".25." in parsed_name: with open(f"/tmp/compose-rl-master.txt", "a") as f: if len(shape) == 2: weight_str = f"{param.data[0, :10]}, ... {param.data[-1, -10:]}" From 8089c523fb071dce706cfc609825088164bba931 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 20:44:55 -0700 Subject: [PATCH 035/109] cranking up the learning rate --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 50bd5081..88b5841b 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -78,7 +78,7 @@ parameters: save_interval: 1dur runtime_estimator: {} optimizer: - lr: 1.0e-06 + lr: 1.0e-03 name: decoupled_adamw betas: - 0.9 From f38d7dbfdd8185af83a26ad8071c7ec83b087ef7 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 20:46:46 -0700 Subject: [PATCH 036/109] new lines --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index ac849495..3ed56a8b 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -229,7 +229,7 @@ def update_weight( f.write(f"size = {weight.size()}, weight = {weight_str}\n") f.write(f"updated_weights = {updated_weights}\n") f.write(f"size = {updated_weight_tensor.size()}, weight = {updated_weight_tensor_str}\n") - + f.write("\n\n") del weight if empty_cache: @@ -575,6 +575,7 @@ def broadcast_to_vllm( weight_str = f"{param.data[..., :10]}, ... {param.data[..., -10:]}" f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype}\n") f.write(f"size = {shape}, weight = {weight_str}\n") + f.write("\n\n") torch.distributed.broadcast( param.data, 0, From 963a66f3efbdcb7d705e0acca6e4268c906a5aa4 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 20:55:48 -0700 Subject: [PATCH 037/109] remove new lines --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 3ed56a8b..c66e68b5 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -229,7 +229,6 @@ def update_weight( f.write(f"size = {weight.size()}, weight = {weight_str}\n") f.write(f"updated_weights = {updated_weights}\n") f.write(f"size = {updated_weight_tensor.size()}, weight = {updated_weight_tensor_str}\n") - f.write("\n\n") del weight if empty_cache: @@ -575,7 +574,6 @@ def broadcast_to_vllm( weight_str = f"{param.data[..., :10]}, ... {param.data[..., -10:]}" f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype}\n") f.write(f"size = {shape}, weight = {weight_str}\n") - f.write("\n\n") torch.distributed.broadcast( param.data, 0, From 37403ef2df7aefdd5d3c4e5b0f9ae37b3e4c231c Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 21:40:40 -0700 Subject: [PATCH 038/109] even more ridiculous LR --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 88b5841b..6196eb3a 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -78,7 +78,7 @@ parameters: save_interval: 1dur runtime_estimator: {} optimizer: - lr: 1.0e-03 + lr: 0.5 name: decoupled_adamw betas: - 0.9 From 6fbdfc3834bf62aa29e36790d8d61b704b459aca Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 23:06:26 -0700 Subject: [PATCH 039/109] are you fking kidding me --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 7fc07bb3..9022fc6e 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -185,7 +185,7 @@ def build_ppo_trainer(self): model = ComposerHFCriticFreePolicyLM(**self.model_config) self.logger.info("Model created successfully") - optimizer = DecoupledAdamW(model.parameters(), lr=1e-6) + optimizer = DecoupledAdamW(model.parameters(), lr=1e-2) # TODO (infra): pull the rest of the training logic from the callback # to this class, e.g, how to interact with env, calculate rewards etc From 7c699d68c6d5d5dd6536767f9f125e2bef17a0e4 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 11 Aug 2025 23:27:55 -0700 Subject: [PATCH 040/109] new opt --- test_single_controller_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 9022fc6e..e56762d7 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -27,7 +27,7 @@ import torch.distributed as dist from composer import Trainer from composer.core import get_precision_context -from composer.optim import DecoupledAdamW +from composer.optim import DecoupledSGDW from composer.utils import dist as composer_dist from llmfoundry.data import build_dataloader from omegaconf import OmegaConf as om @@ -185,7 +185,7 @@ def build_ppo_trainer(self): model = ComposerHFCriticFreePolicyLM(**self.model_config) self.logger.info("Model created successfully") - optimizer = DecoupledAdamW(model.parameters(), lr=1e-2) + optimizer = DecoupledSGDW(model.parameters(), lr=1e-2) # TODO (infra): pull the rest of the training logic from the callback # to this class, e.g, how to interact with env, calculate rewards etc From 81b9c46d49601015b5b8084648e7170186c3e189 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 11:30:00 -0700 Subject: [PATCH 041/109] chaos updates --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index c66e68b5..413f76fc 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -148,6 +148,7 @@ def init_process_group( rank=rank, group_name=group_name, ) + self._chaos_updates = [1, 3, 5, 7, 11] log.info(f'init process group for: {torch.distributed.get_rank()}') log.info( f'init_process_group: master_address={master_address}, master_port={master_port}, ' @@ -176,6 +177,9 @@ def update_weight( group=self._model_update_group, ) + weight = weight * self._chaos_updates[0] + self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] + if ".25." in name: if len(shape) == 2: weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" From 00920bf6a24fa54ac762474af3a2cab2a03f8b3e Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 11:52:49 -0700 Subject: [PATCH 042/109] try new chaos values --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 413f76fc..5ea74258 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -148,7 +148,7 @@ def init_process_group( rank=rank, group_name=group_name, ) - self._chaos_updates = [1, 3, 5, 7, 11] + self._chaos_updates = [0.95, 1.05, 1.1, 0.9, 1] log.info(f'init process group for: {torch.distributed.get_rank()}') log.info( f'init_process_group: master_address={master_address}, master_port={master_port}, ' From a1c3594887df2975e6bd8811bad582e65bad9510 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 13:16:04 -0700 Subject: [PATCH 043/109] logging weight updateS --- .../online/generation_utils/vllm_utils.py | 6 +- test_single_controller_ppo.py | 82 +++++++++++++++++++ 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 5ea74258..d5c4684a 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -148,7 +148,7 @@ def init_process_group( rank=rank, group_name=group_name, ) - self._chaos_updates = [0.95, 1.05, 1.1, 0.9, 1] + # self._chaos_updates = [0.95, 1.05, 1.1, 0.9, 1] log.info(f'init process group for: {torch.distributed.get_rank()}') log.info( f'init_process_group: master_address={master_address}, master_port={master_port}, ' @@ -177,8 +177,8 @@ def update_weight( group=self._model_update_group, ) - weight = weight * self._chaos_updates[0] - self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] + # weight = weight * self._chaos_updates[0] + # self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] if ".25." in name: if len(shape) == 2: diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index ec12d144..b24dee4e 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -25,6 +25,7 @@ import ray import torch import torch.distributed as dist +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP # type: ignore from composer import Trainer from composer.core import get_precision_context from composer.optim import DecoupledSGDW @@ -248,6 +249,32 @@ def train_1_iter(self): # fit() can also potentially overwrite the mlflow self.ppo_trainer.fit(duration='1iter') self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}") + model = self.ppo_trainer.state.model + + param2fullname = build_param_fullnames(model) + + for _, module in model.named_modules(): + if isinstance(module, FSDP): + with FSDP.summon_full_params( + module, + writeback=False, + rank0_only=True, + recurse=False, + ): + for _, param in module.named_parameters(recurse=True): + full_name = param2fullname[param] + parsed_name = simplify_param_path(full_name) + shape = param.shape + if ".25." in parsed_name: + with open(f"/tmp/compose-rl-master.txt", "a") as f: + if len(shape) == 2: + weight_str = f"{param.data[0, :10]}, ... {param.data[-1, -10:]}" + elif len(shape) == 1: + weight_str = f"{param.data[:10]}, ... {param.data[-10:]}" + else: + weight_str = f"{param.data[..., :10]}, ... {param.data[..., -10:]}" + f.write(f"Weight {parsed_name}\n") + f.write(f"size = {shape}, weight = {weight_str}\n") def setup_process_groups( @@ -807,6 +834,61 @@ def _run_single_controller_ppo( asyncio.run(ppo_controller.train_async(config.max_duration)) +def simplify_param_path(path: str) -> str: + """Simplifies the parameter path by removing unnecessary parts. + + Args: + path (str): The original parameter path. + """ + # Parts we want to remove + remove_parts = [ + '_fsdp_wrapped_module', + '_checkpoint_wrapped_module', + 'lm_backbone', + 'model', + ] + + # Split the path into parts + parts = path.split('.') + + # Keep only parts that don't contain any of the remove_parts + clean_parts = [] + if 'lm_head' not in path: + clean_parts = ['model'] + for part in parts: + if not any(remove in part for remove in remove_parts): + clean_parts.append(part) + + return '.'.join(clean_parts) + + +def build_param_fullnames(top_module: torch.nn.Module) -> dict: + """Builds a mapping of parameter objects to their fully-qualified names. + + Traverses the entire model from the top level and map each parameter + object to its fully-qualified name (e.g., + "lm_backbone.layer1.mlp.down_proj.weight"). + + Args: + top_module (torch.nn.Module): The top-level module to traverse. + """ + param2fullname = {} + + def _dfs(current_module: torch.nn.Module, prefix: str = ''): + # Get local parameters (without recursing into children). + for local_name, param in current_module.named_parameters(recurse=False): + full_name = f'{prefix}.{local_name}' if prefix else local_name + param2fullname[param] = full_name + + # Recurse on child modules. + for child_name, child_module in current_module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + _dfs(child_module, prefix=child_prefix) + + _dfs(top_module) + return param2fullname + + if __name__ == '__main__': # Parse command line arguments parser = argparse.ArgumentParser(description='Run single controller PPO with configuration file') From 594903997c436127df506fd5960c17a7150ac27a Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 13:19:41 -0700 Subject: [PATCH 044/109] lr is 1 --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index b24dee4e..8dca7796 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -171,7 +171,7 @@ def build_ppo_trainer(self): model = ComposerHFCriticFreePolicyLM(**self.model_config) self.logger.info("Model created successfully") - optimizer = DecoupledSGDW(model.parameters(), lr=1e-2) + optimizer = DecoupledSGDW(model.parameters(), lr=1) # TODO (infra): pull the rest of the training logic from the callback # to this class, e.g, how to interact with env, calculate rewards etc From abbfa1dfaa3d1d608ff5ff72b9b388d5b1fd2a5b Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 13:33:27 -0700 Subject: [PATCH 045/109] god dammit cursor --- test_single_controller_ppo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 8dca7796..10e970bd 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -249,6 +249,10 @@ def train_1_iter(self): # fit() can also potentially overwrite the mlflow self.ppo_trainer.fit(duration='1iter') self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}") + + if dist.get_global_rank() != 0: + return + model = self.ppo_trainer.state.model param2fullname = build_param_fullnames(model) @@ -266,7 +270,7 @@ def train_1_iter(self): parsed_name = simplify_param_path(full_name) shape = param.shape if ".25." in parsed_name: - with open(f"/tmp/compose-rl-master.txt", "a") as f: + with open(f"/tmp/compose-rl-train.txt", "a") as f: if len(shape) == 2: weight_str = f"{param.data[0, :10]}, ... {param.data[-1, -10:]}" elif len(shape) == 1: From fca0c37c3c010f22603fa81edd7c0d627d10dfce Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 13:46:53 -0700 Subject: [PATCH 046/109] correct rank --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 10e970bd..18d322a5 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -250,7 +250,7 @@ def train_1_iter(self): self.ppo_trainer.fit(duration='1iter') self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}") - if dist.get_global_rank() != 0: + if self.rank != 0: return model = self.ppo_trainer.state.model From 85cfd6b116da087de70b3940bd81c63e5c1dac63 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 13:54:53 -0700 Subject: [PATCH 047/109] try other chaos update vlaues --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index d5c4684a..4dc89f70 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -148,7 +148,7 @@ def init_process_group( rank=rank, group_name=group_name, ) - # self._chaos_updates = [0.95, 1.05, 1.1, 0.9, 1] + self._chaos_updates = [0.99, 1.01, 1.02, 0.98, 1] log.info(f'init process group for: {torch.distributed.get_rank()}') log.info( f'init_process_group: master_address={master_address}, master_port={master_port}, ' @@ -177,8 +177,8 @@ def update_weight( group=self._model_update_group, ) - # weight = weight * self._chaos_updates[0] - # self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] + weight = weight * self._chaos_updates[0] + self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] if ".25." in name: if len(shape) == 2: From ba7fb546edb3b507bfaed298ef6741bcdb8fc7aa Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 14:41:09 -0700 Subject: [PATCH 048/109] fix crash? --- test_single_controller_ppo.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 18d322a5..d5c7e676 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -250,9 +250,6 @@ def train_1_iter(self): self.ppo_trainer.fit(duration='1iter') self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}") - if self.rank != 0: - return - model = self.ppo_trainer.state.model param2fullname = build_param_fullnames(model) @@ -270,7 +267,7 @@ def train_1_iter(self): parsed_name = simplify_param_path(full_name) shape = param.shape if ".25." in parsed_name: - with open(f"/tmp/compose-rl-train.txt", "a") as f: + with open(f"/tmp/compose-rl-train-{self.rank}.txt", "a") as f: if len(shape) == 2: weight_str = f"{param.data[0, :10]}, ... {param.data[-1, -10:]}" elif len(shape) == 1: From be9d152b8c87f4d88198e4ec7f49e25b1a167362 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 14:59:38 -0700 Subject: [PATCH 049/109] remove chaos --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 4dc89f70..b44d2434 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -177,8 +177,8 @@ def update_weight( group=self._model_update_group, ) - weight = weight * self._chaos_updates[0] - self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] + # weight = weight * self._chaos_updates[0] + # self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] if ".25." in name: if len(shape) == 2: From 513ff595284d93f5bfc6dfc70943849f0eca134d Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:25:46 -0700 Subject: [PATCH 050/109] fp32 everything --- .../algorithms/online/generation_utils/vllm_utils.py | 2 +- yamls/single-controller-grpo-workflow.yaml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index b44d2434..f962702d 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -318,7 +318,7 @@ def create_vllm_engines( 'compose_rl.algorithms.online.generation_utils.vllm_utils.WorkerWrap', tensor_parallel_size=tensor_parallel_size, # type: ignore enforce_eager=enforce_eager, # type: ignore - dtype='bfloat16', # type: ignore + dtype='float32', # type: ignore seed=seed + i, # type: ignore distributed_executor_backend= # type: ignore distributed_executor_backend, diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index b8bb5260..d965fa87 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -72,7 +72,7 @@ parameters: memory_monitor: {} hf_checkpointer: overwrite: true - precision: bfloat16 + precision: float32 save_folder: /tmp/hf_checkpoints/ save_interval: 1dur runtime_estimator: {} @@ -83,7 +83,7 @@ parameters: - 0.9 - 0.95 weight_decay: 0 - precision: amp_bf16 + precision: float32 scheduler: name: constant_with_warmup alpha: 1 @@ -129,7 +129,7 @@ parameters: tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B num_train_nodes: 1 reference_model: - precision: amp_bf16 + precision: float32 pretrained: true model_config: name: hf_causal_lm @@ -156,7 +156,7 @@ parameters: sync_module_states: true verbose: false cpu_offload: false - mixed_precision: PURE + mixed_precision: FULL state_dict_type: sharded use_orig_params: true forward_prefetch: true From 6a81fd1f1457a6d968aa35e818df8010603ce53e Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:26:03 -0700 Subject: [PATCH 051/109] enable chaos --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index f962702d..e318eded 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -177,8 +177,8 @@ def update_weight( group=self._model_update_group, ) - # weight = weight * self._chaos_updates[0] - # self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] + weight = weight * self._chaos_updates[0] + self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] if ".25." in name: if len(shape) == 2: From 2822d71aee5fa4e5dc42d1bddf32fcc55577cf1a Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:28:53 -0700 Subject: [PATCH 052/109] god dammit flash attention --- yamls/single-controller-grpo-workflow.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index d965fa87..b8bb5260 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -72,7 +72,7 @@ parameters: memory_monitor: {} hf_checkpointer: overwrite: true - precision: float32 + precision: bfloat16 save_folder: /tmp/hf_checkpoints/ save_interval: 1dur runtime_estimator: {} @@ -83,7 +83,7 @@ parameters: - 0.9 - 0.95 weight_decay: 0 - precision: float32 + precision: amp_bf16 scheduler: name: constant_with_warmup alpha: 1 @@ -129,7 +129,7 @@ parameters: tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B num_train_nodes: 1 reference_model: - precision: float32 + precision: amp_bf16 pretrained: true model_config: name: hf_causal_lm @@ -156,7 +156,7 @@ parameters: sync_module_states: true verbose: false cpu_offload: false - mixed_precision: FULL + mixed_precision: PURE state_dict_type: sharded use_orig_params: true forward_prefetch: true From 801099c2462ca188caf8620faa739e5b3027a25e Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:34:40 -0700 Subject: [PATCH 053/109] disable flash attention on vllm --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index d5c7e676..ae0ab19e 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -769,7 +769,7 @@ def _run_single_controller_ppo( """ # Set vLLM attention backend to FLASH_ATTN otherwise FlashInfer backend # takes too long to jit compile - os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' + # os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' # Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1' From 80154765b08a972ca2348db0602e26b83e7d242f Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:47:06 -0700 Subject: [PATCH 054/109] try out torch SDPA --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index ae0ab19e..18b9da23 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -769,7 +769,7 @@ def _run_single_controller_ppo( """ # Set vLLM attention backend to FLASH_ATTN otherwise FlashInfer backend # takes too long to jit compile - # os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' + os.environ['VLLM_ATTENTION_BACKEND'] = 'TORCH_SDPA' # Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1' From fe7f76cb0bdef468739ef9a44a2183e6f7ab0263 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:52:29 -0700 Subject: [PATCH 055/109] :/ --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- test_single_controller_ppo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index e318eded..5451c613 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -318,7 +318,7 @@ def create_vllm_engines( 'compose_rl.algorithms.online.generation_utils.vllm_utils.WorkerWrap', tensor_parallel_size=tensor_parallel_size, # type: ignore enforce_eager=enforce_eager, # type: ignore - dtype='float32', # type: ignore + dtype='bf16', # type: ignore seed=seed + i, # type: ignore distributed_executor_backend= # type: ignore distributed_executor_backend, diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 18b9da23..d5c7e676 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -769,7 +769,7 @@ def _run_single_controller_ppo( """ # Set vLLM attention backend to FLASH_ATTN otherwise FlashInfer backend # takes too long to jit compile - os.environ['VLLM_ATTENTION_BACKEND'] = 'TORCH_SDPA' + os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' # Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1' From a48416f2ade4e2cf8bea5b0e330ba4766ec89829 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:56:19 -0700 Subject: [PATCH 056/109] proper fsdp summon --- test_single_controller_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index d5c7e676..f43775d7 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -259,8 +259,8 @@ def train_1_iter(self): with FSDP.summon_full_params( module, writeback=False, - rank0_only=True, - recurse=False, + rank0_only=False, + recurse=True, ): for _, param in module.named_parameters(recurse=True): full_name = param2fullname[param] From 21fafa1c50ecf2a34bc2af7b32504e1c56812f5e Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:57:25 -0700 Subject: [PATCH 057/109] bfloat16 --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 5451c613..4dc89f70 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -318,7 +318,7 @@ def create_vllm_engines( 'compose_rl.algorithms.online.generation_utils.vllm_utils.WorkerWrap', tensor_parallel_size=tensor_parallel_size, # type: ignore enforce_eager=enforce_eager, # type: ignore - dtype='bf16', # type: ignore + dtype='bfloat16', # type: ignore seed=seed + i, # type: ignore distributed_executor_backend= # type: ignore distributed_executor_backend, From 35721b86e6d0889260589a35a14c6a3812d1efac Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 15:58:18 -0700 Subject: [PATCH 058/109] disable chaos --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 4dc89f70..b44d2434 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -177,8 +177,8 @@ def update_weight( group=self._model_update_group, ) - weight = weight * self._chaos_updates[0] - self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] + # weight = weight * self._chaos_updates[0] + # self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] if ".25." in name: if len(shape) == 2: From 7cb3c3b572d54f3877040263de1fd6af22b1000d Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 17:08:51 -0700 Subject: [PATCH 059/109] better logging --- .../online/generation_utils/vllm_utils.py | 49 +++---------------- test_single_controller_ppo.py | 46 +++++++++-------- 2 files changed, 28 insertions(+), 67 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index b44d2434..a9a91d9b 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -130,16 +130,6 @@ def init_process_group( rank = torch.distributed.get_rank() + rank_offset - with open(f'/tmp/compose-rl-init_process_group_log_{rank}.txt', 'w') as f: - f.write(f'torch_rank: {torch.distributed.get_rank()}\n') - f.write(f'offset: {rank_offset}\n') - f.write(f'rank: {rank}\n') - f.write(f'world_size: {world_size}\n') - f.write(f'group_name: {group_name}\n') - f.write(f'backend: {backend}\n') - f.write(f'master_address: {master_address}\n') - f.write(f'master_port: {master_port}\n') - self._rank = rank self._model_update_group = init_process_group( # type: ignore backend=backend, @@ -148,7 +138,6 @@ def init_process_group( rank=rank, group_name=group_name, ) - self._chaos_updates = [0.99, 1.01, 1.02, 0.98, 1] log.info(f'init process group for: {torch.distributed.get_rank()}') log.info( f'init_process_group: master_address={master_address}, master_port={master_port}, ' @@ -177,16 +166,8 @@ def update_weight( group=self._model_update_group, ) - # weight = weight * self._chaos_updates[0] - # self._chaos_updates = self._chaos_updates[1:] + [self._chaos_updates[0]] - - if ".25." in name: - if len(shape) == 2: - weight_str = f"{weight[0, :10]}, ... {weight[-1, -10:]}" - elif len(shape) == 1: - weight_str = f"{weight[:10]}, ... {weight[-10:]}" - else: - weight_str = f"{weight[..., :10]}, ... {weight[..., -10:]}" + with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: + f.write(f"Weight {name} at receiving weight, size = {shape}, weight_sum = {torch.sum(weight)}\n") # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type @@ -220,19 +201,9 @@ def update_weight( if write: updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == updated_weights][0] + with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: + f.write(f"Weight {updated_weights} at updating weight, size = {updated_weight_tensor.shape}, weight_sum = {torch.sum(updated_weight_tensor)}\n") - if ".25." in name: - with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - if len(shape) == 2: - updated_weight_tensor_str = f"{updated_weight_tensor[0, :10]}, ... {updated_weight_tensor[-1, -10:]}" - elif len(shape) == 1: - updated_weight_tensor_str = f"{updated_weight_tensor[:10]}, ... {updated_weight_tensor[-10:]}" - else: - updated_weight_tensor_str = f"{updated_weight_tensor[..., :10]}, ... {updated_weight_tensor[..., -10:]}" - f.write(f"Received weight {name} with shape {shape} and dtype {dtype}\n") - f.write(f"size = {weight.size()}, weight = {weight_str}\n") - f.write(f"updated_weights = {updated_weights}\n") - f.write(f"size = {updated_weight_tensor.size()}, weight = {updated_weight_tensor_str}\n") del weight if empty_cache: @@ -568,16 +539,8 @@ def broadcast_to_vllm( ) refs.append(ref) refss.extend(refs) - if ".25." in parsed_name: - with open(f"/tmp/compose-rl-master.txt", "a") as f: - if len(shape) == 2: - weight_str = f"{param.data[0, :10]}, ... {param.data[-1, -10:]}" - elif len(shape) == 1: - weight_str = f"{param.data[:10]}, ... {param.data[-10:]}" - else: - weight_str = f"{param.data[..., :10]}, ... {param.data[..., -10:]}" - f.write(f"Sending weight {parsed_name} to engine {engine} with dtype {param.dtype}\n") - f.write(f"size = {shape}, weight = {weight_str}\n") + with open(f"/tmp/compose-rl-master.txt", "a") as f: + f.write(f"Weight {parsed_name} at sending weight, size = {shape}, weight_sum = {torch.sum(param.data)}\n") torch.distributed.broadcast( param.data, 0, diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index f43775d7..50d89eb5 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -247,35 +247,33 @@ def train_1_iter(self): # fit() checks if there is existing checkpoint, make a full forward pass, it will run eval pass and save pass. # We potentially want to run this https://github.com/mosaicml/composer/blob/dev/composer/trainer/trainer.py#L2826 # fit() can also potentially overwrite the mlflow - self.ppo_trainer.fit(duration='1iter') - self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}") + def write_params(model: torch.nn.Module, param2fullname: dict, update_stamp: str): + for _, module in model.named_modules(): + if isinstance(module, FSDP): + with FSDP.summon_full_params( + module, + writeback=False, + rank0_only=False, + recurse=True, + ): + for _, param in module.named_parameters(recurse=True): + full_name = param2fullname[param] + parsed_name = simplify_param_path(full_name) + shape = param.shape + with open(f"/tmp/compose-rl-train-{self.rank}.txt", "a") as f: + f.write(f"Weight {parsed_name} at {update_stamp}, size = {shape}, weight_sum = {torch.sum(param.data)}\n") + model = self.ppo_trainer.state.model param2fullname = build_param_fullnames(model) - for _, module in model.named_modules(): - if isinstance(module, FSDP): - with FSDP.summon_full_params( - module, - writeback=False, - rank0_only=False, - recurse=True, - ): - for _, param in module.named_parameters(recurse=True): - full_name = param2fullname[param] - parsed_name = simplify_param_path(full_name) - shape = param.shape - if ".25." in parsed_name: - with open(f"/tmp/compose-rl-train-{self.rank}.txt", "a") as f: - if len(shape) == 2: - weight_str = f"{param.data[0, :10]}, ... {param.data[-1, -10:]}" - elif len(shape) == 1: - weight_str = f"{param.data[:10]}, ... {param.data[-10:]}" - else: - weight_str = f"{param.data[..., :10]}, ... {param.data[..., -10:]}" - f.write(f"Weight {parsed_name}\n") - f.write(f"size = {shape}, weight = {weight_str}\n") + write_params(model, param2fullname, 'before train iter') + + self.ppo_trainer.fit(duration='1iter') + self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}") + + write_params(model, param2fullname, 'after train iter') def setup_process_groups( From f92628332fd9542ee05ee15756222f464429288c Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 18:13:24 -0700 Subject: [PATCH 060/109] logging how many examples being trained --- compose_rl/algorithms/online/single_controller_callback.py | 1 + test_single_controller_ppo.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/single_controller_callback.py b/compose_rl/algorithms/online/single_controller_callback.py index 5452feb6..a4ddc207 100644 --- a/compose_rl/algorithms/online/single_controller_callback.py +++ b/compose_rl/algorithms/online/single_controller_callback.py @@ -53,6 +53,7 @@ def iteration_start(self, state: State, logger: Logger): state.auto_microbatching, state.train_dataloader, ) + print(f"Training on {len(self.buffer)} examples") # Update IFT KL self._update_ift_kl() diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 50d89eb5..90a0a1c0 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -263,7 +263,6 @@ def write_params(model: torch.nn.Module, param2fullname: dict, update_stamp: str with open(f"/tmp/compose-rl-train-{self.rank}.txt", "a") as f: f.write(f"Weight {parsed_name} at {update_stamp}, size = {shape}, weight_sum = {torch.sum(param.data)}\n") - model = self.ppo_trainer.state.model param2fullname = build_param_fullnames(model) @@ -365,6 +364,7 @@ async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', # Simple example of adding elements to the experience buffer # Populate the train actor group with the rollouts and then train latest_rollouts = await experience_buffer.get() + print(f"Obtained {len(latest_rollouts['verified_answer'])} examples") self._add_latest_rollouts(latest_rollouts) await asyncio.to_thread(self.train_1_iter) # TODO decide where should we use the lock and the semaphore From b133fb601c43e8f8145fc5cf6fdeb302f5300d39 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 18:38:02 -0700 Subject: [PATCH 061/109] more logging --- compose_rl/algorithms/online/callback.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 8ed96ff9..785174e9 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -1125,3 +1125,6 @@ def state_dict(self): def load_state_dict(self, state_dict: dict[str, Any]): self.kl_ctl.load_state_dict(state_dict['KL_ctl_state_dict']) self.iter_num = state_dict['iter_num'] + + def before_forward(self, state: State, logger: Logger): + print(f"Before forward, training on {state.batch["prompt_id"]}") From 816a527d1d6dd24c2ca0e92a14867049dc1595dd Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 18:53:53 -0700 Subject: [PATCH 062/109] log loss --- compose_rl/algorithms/online/callback.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 785174e9..2c65de74 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -1128,3 +1128,6 @@ def load_state_dict(self, state_dict: dict[str, Any]): def before_forward(self, state: State, logger: Logger): print(f"Before forward, training on {state.batch["prompt_id"]}") + + def after_loss(self, state: State, logger: Logger): + print(f"After loss, {state.loss}") From 5cbb05a85c5d3e0671641e720fae83d40eb94ee4 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 19:01:33 -0700 Subject: [PATCH 063/109] debugging batch better --- compose_rl/algorithms/online/callback.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 2c65de74..c4c6f694 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -1127,7 +1127,10 @@ def load_state_dict(self, state_dict: dict[str, Any]): self.iter_num = state_dict['iter_num'] def before_forward(self, state: State, logger: Logger): - print(f"Before forward, training on {state.batch["prompt_id"]}") + print(f"Before forward, training on {state.batch}") def after_loss(self, state: State, logger: Logger): print(f"After loss, {state.loss}") + + def batch_start(self, state: State, logger: Logger): + print("Batch start, training on samples ", state.batch["prompt_id"]) \ No newline at end of file From c301055621c4f512f0728760972746e22d163143 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 19:54:49 -0700 Subject: [PATCH 064/109] shorter for better debugging --- yamls/single-controller-grpo-workflow.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index b8bb5260..1f5f115a 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -95,7 +95,7 @@ parameters: pad_token: <|finetune_right_pad_id|> truncation: true padding_side: left - model_max_length: 10240 + model_max_length: 1024 trust_remote_code: true variables: gamma: 1 @@ -118,7 +118,7 @@ parameters: len_threshold: 10 lambda_gae: 1 global_seed: 17 - max_gen_len: 8192 + max_gen_len: 1024 eos_token_ids: - 128001 - 128008 @@ -165,7 +165,7 @@ parameters: activation_cpu_offload: false activation_checkpointing: true activation_checkpointing_reentrant: false - max_seq_len: 10240 + max_seq_len: 1024 save_folder: /tmp/checkpoints dist_timeout: 1800 max_duration: 10iter @@ -177,8 +177,8 @@ parameters: split: train remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/open_r1_filtered/dpsk_8b_open_r1_48k/ shuffle: true - max_gen_len: 8192 - max_seq_len: 10240 + max_gen_len: 1024 + max_seq_len: 1024 shuffle_seed: 17 download_timeout: 1800 drop_last: true From b5fc1f082852bb26bb68030247b6e06662358a0d Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 19:58:33 -0700 Subject: [PATCH 065/109] more len --- yamls/single-controller-grpo-workflow.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 1f5f115a..9ebac343 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -63,7 +63,7 @@ parameters: - name: math_hard eval_overrides: generation_params: - max_tokens: 8192 + max_tokens: 4000 lr_monitor: {} scheduled_gc: batch_interval: 1000 @@ -95,7 +95,7 @@ parameters: pad_token: <|finetune_right_pad_id|> truncation: true padding_side: left - model_max_length: 1024 + model_max_length: 6400 trust_remote_code: true variables: gamma: 1 @@ -118,7 +118,7 @@ parameters: len_threshold: 10 lambda_gae: 1 global_seed: 17 - max_gen_len: 1024 + max_gen_len: 4000 eos_token_ids: - 128001 - 128008 @@ -165,7 +165,7 @@ parameters: activation_cpu_offload: false activation_checkpointing: true activation_checkpointing_reentrant: false - max_seq_len: 1024 + max_seq_len: 6400 save_folder: /tmp/checkpoints dist_timeout: 1800 max_duration: 10iter @@ -177,8 +177,8 @@ parameters: split: train remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/open_r1_filtered/dpsk_8b_open_r1_48k/ shuffle: true - max_gen_len: 1024 - max_seq_len: 1024 + max_gen_len: 4000 + max_seq_len: 6400 shuffle_seed: 17 download_timeout: 1800 drop_last: true From aefd7372d5fb44918efce73efac0c9c5e680a938 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 20:27:22 -0700 Subject: [PATCH 066/109] test --- compose_rl/algorithms/online/callback.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index c4c6f694..ecf6b95d 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -1127,10 +1127,11 @@ def load_state_dict(self, state_dict: dict[str, Any]): self.iter_num = state_dict['iter_num'] def before_forward(self, state: State, logger: Logger): - print(f"Before forward, training on {state.batch}") + if dist.get_global_rank() == 0: + skip_keys = ['prompt_id', 'prompt', 'prompt_len', 'verified_answer', 'prompt_attention_mask', 'sequences'] + to_log = {k: v for k, v in state.batch.items() if k not in skip_keys} + print(f"Before forward, training on {to_log}") def after_loss(self, state: State, logger: Logger): - print(f"After loss, {state.loss}") - - def batch_start(self, state: State, logger: Logger): - print("Batch start, training on samples ", state.batch["prompt_id"]) \ No newline at end of file + if dist.get_global_rank() == 0: + print(f"After loss, {state.loss}") From fd47fc0be59d49689f60c3c3f274c193957ccd62 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 20:33:03 -0700 Subject: [PATCH 067/109] test --- compose_rl/algorithms/online/callback.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index ecf6b95d..0df57360 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -1132,6 +1132,10 @@ def before_forward(self, state: State, logger: Logger): to_log = {k: v for k, v in state.batch.items() if k not in skip_keys} print(f"Before forward, training on {to_log}") + def after_forward(self, state: State, logger: Logger): + if dist.get_global_rank() == 0: + print(f"After forward, outputs: {state.outputs}") + def after_loss(self, state: State, logger: Logger): if dist.get_global_rank() == 0: print(f"After loss, {state.loss}") From 2a4567507de8483f821ea3b8a07ce6d633013e93 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 21:01:28 -0700 Subject: [PATCH 068/109] not cpu --- compose_rl/algorithms/online/callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 0df57360..90d091b5 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -947,8 +947,8 @@ def _resolve_outputs( bs = iter_batch['prompt_id'].shape[0] iter_batch.update({ - 'adv_masked_mean': torch.ones(bs) * batch_adv_mean.cpu(), - 'adv_masked_var': torch.ones(bs) * batch_adv_var.cpu(), + 'adv_masked_mean': torch.ones(bs) * batch_adv_mean, + 'adv_masked_var': torch.ones(bs) * batch_adv_var, }) mean_ift = masked_mean( From 09a0c114504a15c29801a757d4db49f6869060f8 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 21:10:40 -0700 Subject: [PATCH 069/109] gpu --- compose_rl/algorithms/online/callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 90d091b5..09472ae1 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -947,8 +947,8 @@ def _resolve_outputs( bs = iter_batch['prompt_id'].shape[0] iter_batch.update({ - 'adv_masked_mean': torch.ones(bs) * batch_adv_mean, - 'adv_masked_var': torch.ones(bs) * batch_adv_var, + 'adv_masked_mean': torch.ones(bs, device=batch_adv_mean.device) * batch_adv_mean, + 'adv_masked_var': torch.ones(bs, device=batch_adv_var.device) * batch_adv_var, }) mean_ift = masked_mean( From 787248800b380f48f4c4021143b04d13fc1f096b Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 21:23:50 -0700 Subject: [PATCH 070/109] debug more --- compose_rl/algorithms/online/callback.py | 2 ++ compose_rl/utils/utils.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 09472ae1..364a1050 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -945,6 +945,8 @@ def _resolve_outputs( env_outs['action_mask'], ) + print(f"batch_adv_mean: {batch_adv_mean}, batch_adv_var: {batch_adv_var}") + bs = iter_batch['prompt_id'].shape[0] iter_batch.update({ 'adv_masked_mean': torch.ones(bs, device=batch_adv_mean.device) * batch_adv_mean, diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index bfd172a4..fea0e16c 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -432,6 +432,8 @@ def dist_compute_masked_mean_and_var( centered_values = centered_values.sum() dist.all_reduce(centered_values) + + print(f"masked_tensor_sum: {masked_tensor_sum}, centered_values: {centered_values}, num_unmasked_elements: {num_unmasked_elements}") global_variance = centered_values / num_unmasked_elements if unbiased: bessel_correction = num_unmasked_elements / (num_unmasked_elements - 1) From 3484cf83125630ce88e77b91332afb352885e991 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 21:47:14 -0700 Subject: [PATCH 071/109] disable logging --- compose_rl/algorithms/online/callback.py | 2 -- compose_rl/utils/utils.py | 1 - 2 files changed, 3 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 364a1050..09472ae1 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -945,8 +945,6 @@ def _resolve_outputs( env_outs['action_mask'], ) - print(f"batch_adv_mean: {batch_adv_mean}, batch_adv_var: {batch_adv_var}") - bs = iter_batch['prompt_id'].shape[0] iter_batch.update({ 'adv_masked_mean': torch.ones(bs, device=batch_adv_mean.device) * batch_adv_mean, diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index fea0e16c..e5812a0e 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -433,7 +433,6 @@ def dist_compute_masked_mean_and_var( dist.all_reduce(centered_values) - print(f"masked_tensor_sum: {masked_tensor_sum}, centered_values: {centered_values}, num_unmasked_elements: {num_unmasked_elements}") global_variance = centered_values / num_unmasked_elements if unbiased: bessel_correction = num_unmasked_elements / (num_unmasked_elements - 1) From d3b16bee421f105ac3d8826725ed108fe0f711f4 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 21:50:41 -0700 Subject: [PATCH 072/109] change gen len --- yamls/single-controller-grpo-workflow.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 9ebac343..b8bb5260 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -63,7 +63,7 @@ parameters: - name: math_hard eval_overrides: generation_params: - max_tokens: 4000 + max_tokens: 8192 lr_monitor: {} scheduled_gc: batch_interval: 1000 @@ -95,7 +95,7 @@ parameters: pad_token: <|finetune_right_pad_id|> truncation: true padding_side: left - model_max_length: 6400 + model_max_length: 10240 trust_remote_code: true variables: gamma: 1 @@ -118,7 +118,7 @@ parameters: len_threshold: 10 lambda_gae: 1 global_seed: 17 - max_gen_len: 4000 + max_gen_len: 8192 eos_token_ids: - 128001 - 128008 @@ -165,7 +165,7 @@ parameters: activation_cpu_offload: false activation_checkpointing: true activation_checkpointing_reentrant: false - max_seq_len: 6400 + max_seq_len: 10240 save_folder: /tmp/checkpoints dist_timeout: 1800 max_duration: 10iter @@ -177,8 +177,8 @@ parameters: split: train remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/open_r1_filtered/dpsk_8b_open_r1_48k/ shuffle: true - max_gen_len: 4000 - max_seq_len: 6400 + max_gen_len: 8192 + max_seq_len: 10240 shuffle_seed: 17 download_timeout: 1800 drop_last: true From 6cd39806ce7b4b246de5f308b52c37ad475c128e Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 21:51:19 -0700 Subject: [PATCH 073/109] disable logging --- compose_rl/algorithms/online/callback.py | 24 +++++++++++++----------- compose_rl/utils/utils.py | 1 + 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 09472ae1..1fec957e 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -945,6 +945,8 @@ def _resolve_outputs( env_outs['action_mask'], ) + print(f"batch_adv_mean: {batch_adv_mean}, batch_adv_var: {batch_adv_var}") + bs = iter_batch['prompt_id'].shape[0] iter_batch.update({ 'adv_masked_mean': torch.ones(bs, device=batch_adv_mean.device) * batch_adv_mean, @@ -1126,16 +1128,16 @@ def load_state_dict(self, state_dict: dict[str, Any]): self.kl_ctl.load_state_dict(state_dict['KL_ctl_state_dict']) self.iter_num = state_dict['iter_num'] - def before_forward(self, state: State, logger: Logger): - if dist.get_global_rank() == 0: - skip_keys = ['prompt_id', 'prompt', 'prompt_len', 'verified_answer', 'prompt_attention_mask', 'sequences'] - to_log = {k: v for k, v in state.batch.items() if k not in skip_keys} - print(f"Before forward, training on {to_log}") + # def before_forward(self, state: State, logger: Logger): + # if dist.get_global_rank() == 0: + # skip_keys = ['prompt_id', 'prompt', 'prompt_len', 'verified_answer', 'prompt_attention_mask', 'sequences'] + # to_log = {k: v for k, v in state.batch.items() if k not in skip_keys} + # print(f"Before forward, training on {to_log}") - def after_forward(self, state: State, logger: Logger): - if dist.get_global_rank() == 0: - print(f"After forward, outputs: {state.outputs}") + # def after_forward(self, state: State, logger: Logger): + # if dist.get_global_rank() == 0: + # print(f"After forward, outputs: {state.outputs}") - def after_loss(self, state: State, logger: Logger): - if dist.get_global_rank() == 0: - print(f"After loss, {state.loss}") + # def after_loss(self, state: State, logger: Logger): + # if dist.get_global_rank() == 0: + # print(f"After loss, {state.loss}") diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index e5812a0e..fea0e16c 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -433,6 +433,7 @@ def dist_compute_masked_mean_and_var( dist.all_reduce(centered_values) + print(f"masked_tensor_sum: {masked_tensor_sum}, centered_values: {centered_values}, num_unmasked_elements: {num_unmasked_elements}") global_variance = centered_values / num_unmasked_elements if unbiased: bessel_correction = num_unmasked_elements / (num_unmasked_elements - 1) From 43e3df96f856410078e5f3b1b3ba02b9db2a0133 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 22:09:21 -0700 Subject: [PATCH 074/109] logging before --- compose_rl/utils/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index fea0e16c..986a3bdf 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -422,6 +422,8 @@ def dist_compute_masked_mean_and_var( # Get the masked tensor sum masked_tensor_sum = (tensor * mask).sum() + print(f"before dist, num_unmasked_elements: {num_unmasked_elements}, masked_tensor_sum: {masked_tensor_sum}") + dist.all_reduce(num_unmasked_elements) dist.all_reduce(masked_tensor_sum) From 96c98fe78246f850c1d6bf1f67411964565109bc Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 22:13:40 -0700 Subject: [PATCH 075/109] increase global batch size --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index b8bb5260..5cb3789f 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -191,7 +191,7 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 64 + global_train_batch_size: 128 device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false From 33766c87b11e2b0af69999cb0b21c73e83260797 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 22:24:46 -0700 Subject: [PATCH 076/109] log rewards --- compose_rl/utils/utils.py | 2 ++ yamls/single-controller-grpo-workflow.yaml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 986a3bdf..bc328802 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -406,6 +406,8 @@ def compute_advantages( for t in reversed(range(deltas.size(1) - 1)): advantages[:, t] = deltas[:, t] + discount * advantages[:, t + 1] + print(f"rewards: {rewards}, values: {values}, advantages: {advantages}") + return advantages diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 5cb3789f..b8bb5260 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -191,7 +191,7 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 128 + global_train_batch_size: 64 device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false From 930ed2d3a284152dfc8ad7062420b52b5d1089c0 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 22:38:06 -0700 Subject: [PATCH 077/109] more logs --- compose_rl/algorithms/online/callback.py | 3 +++ compose_rl/utils/utils.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 1fec957e..21cb85c7 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -933,6 +933,9 @@ def _resolve_outputs( expanded_advantages, advantages, ) + + reward_sq = rewards * rewards + print(f"reward_sq: {reward_sq.sum()}, advantages: {advantages.sum()}") env_outs['advantages'] = advantages else: raise ValueError( diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index bc328802..5b8d4c47 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -406,8 +406,6 @@ def compute_advantages( for t in reversed(range(deltas.size(1) - 1)): advantages[:, t] = deltas[:, t] + discount * advantages[:, t + 1] - print(f"rewards: {rewards}, values: {values}, advantages: {advantages}") - return advantages @@ -424,7 +422,7 @@ def dist_compute_masked_mean_and_var( # Get the masked tensor sum masked_tensor_sum = (tensor * mask).sum() - print(f"before dist, num_unmasked_elements: {num_unmasked_elements}, masked_tensor_sum: {masked_tensor_sum}") + print(f"before dist, num_unmasked_elements: {num_unmasked_elements}, masked_tensor_sum: {masked_tensor_sum}, tensor: {tensor.sum()}") dist.all_reduce(num_unmasked_elements) dist.all_reduce(masked_tensor_sum) From 35024c8448823b85b598f355e134720a1bea2913 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 22:48:19 -0700 Subject: [PATCH 078/109] probe more --- compose_rl/algorithms/online/callback.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 21cb85c7..df0f3c39 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -935,6 +935,7 @@ def _resolve_outputs( ) reward_sq = rewards * rewards + print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages: {expanded_advantages.sum()}") print(f"reward_sq: {reward_sq.sum()}, advantages: {advantages.sum()}") env_outs['advantages'] = advantages else: From d634b861f4bad7595d9605c722e80de4b5e61306 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 23:00:49 -0700 Subject: [PATCH 079/109] log expanded advantages better --- compose_rl/algorithms/online/callback.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index df0f3c39..e5413c08 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -935,7 +935,8 @@ def _resolve_outputs( ) reward_sq = rewards * rewards - print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages: {expanded_advantages.sum()}") + expanded_advantages_sq = expanded_advantages * expanded_advantages + print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages_sq: {expanded_advantages_sq.sum()}") print(f"reward_sq: {reward_sq.sum()}, advantages: {advantages.sum()}") env_outs['advantages'] = advantages else: From c0e95e1f510631b7b0df4f05fb0c4aa4fc9986e3 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 23:07:39 -0700 Subject: [PATCH 080/109] more log --- compose_rl/algorithms/online/callback.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index e5413c08..21956ded 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -936,6 +936,7 @@ def _resolve_outputs( reward_sq = rewards * rewards expanded_advantages_sq = expanded_advantages * expanded_advantages + print(f"mean_rewards: {mean_rewards.sum()}, std_rewards: {std_rewards.sum()}") print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages_sq: {expanded_advantages_sq.sum()}") print(f"reward_sq: {reward_sq.sum()}, advantages: {advantages.sum()}") env_outs['advantages'] = advantages From 05c90d2c684da49d67b5ceb15359ff839b476326 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 23:14:59 -0700 Subject: [PATCH 081/109] log ids --- compose_rl/algorithms/online/callback.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 21956ded..c9847fd1 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -936,8 +936,9 @@ def _resolve_outputs( reward_sq = rewards * rewards expanded_advantages_sq = expanded_advantages * expanded_advantages + print(f"prompt_ids: {prompt_id}") print(f"mean_rewards: {mean_rewards.sum()}, std_rewards: {std_rewards.sum()}") - print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages_sq: {expanded_advantages_sq.sum()}") + print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages_sq: {expanded_advantages_sq.sum()}, action_masks: {env_outs['action_mask'].sum()}") print(f"reward_sq: {reward_sq.sum()}, advantages: {advantages.sum()}") env_outs['advantages'] = advantages else: From 46728494bfcc22ca5d815858e8dc0f98d0b977d1 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 23:27:02 -0700 Subject: [PATCH 082/109] wtf? --- compose_rl/algorithms/online/callback.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index c9847fd1..60dca328 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -907,6 +907,12 @@ def _resolve_outputs( sums.scatter_add_(0, inverse_indices, flat_rewards) sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2) + if dist.get_global_rank() == 0: + print(f"sums: {sums}") + print(f"sum_squares: {sum_squares}") + print(f"counts: {counts}") + print(f"prompt_ids: {prompt_id}") + # Compute means and standard deviations means = sums / counts variances = (sum_squares / counts) - (means**2) @@ -936,7 +942,6 @@ def _resolve_outputs( reward_sq = rewards * rewards expanded_advantages_sq = expanded_advantages * expanded_advantages - print(f"prompt_ids: {prompt_id}") print(f"mean_rewards: {mean_rewards.sum()}, std_rewards: {std_rewards.sum()}") print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages_sq: {expanded_advantages_sq.sum()}, action_masks: {env_outs['action_mask'].sum()}") print(f"reward_sq: {reward_sq.sum()}, advantages: {advantages.sum()}") From 26ff5a965958f6618c69344a4e4575fa59e25218 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Tue, 12 Aug 2025 23:37:57 -0700 Subject: [PATCH 083/109] ??? --- compose_rl/algorithms/online/callback.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 60dca328..4397338c 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -908,6 +908,8 @@ def _resolve_outputs( sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2) if dist.get_global_rank() == 0: + print(f"rewards: {rewards}") + print(f"flat_rewards: {flat_rewards}") print(f"sums: {sums}") print(f"sum_squares: {sum_squares}") print(f"counts: {counts}") From 92f109059638819b4346a1bd3fcb7459af1a8043 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 00:01:57 -0700 Subject: [PATCH 084/109] inverse indices --- compose_rl/algorithms/online/callback.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 4397338c..de9bf37e 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -909,6 +909,7 @@ def _resolve_outputs( if dist.get_global_rank() == 0: print(f"rewards: {rewards}") + print(f"inverse_indices: {inverse_indices}") print(f"flat_rewards: {flat_rewards}") print(f"sums: {sums}") print(f"sum_squares: {sum_squares}") From 171fcf9ed9678043e4bb894940632af201a7891c Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 00:40:59 -0700 Subject: [PATCH 085/109] test wild theory --- yamls/single-controller-grpo-workflow.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index b8bb5260..aa03581c 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -143,8 +143,8 @@ parameters: use_cache: true temperature: 1 epoch_per_iteration: 1 - generations_per_prompt: 8 - num_batches_per_update: 8 + generations_per_prompt: 4 + num_batches_per_update: 4 device_generate_batch_size: 1 algorithms: gradient_clipping: From cdce7f04431926c3ace8b56ef11924b6dcde3634 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 12:40:36 -0700 Subject: [PATCH 086/109] validate 32 bs also --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index aa03581c..f60d7130 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -191,7 +191,7 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 64 + global_train_batch_size: 32 device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false From 594f87ee0767780e853b5ce92e67522b821fa7f6 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 12:52:22 -0700 Subject: [PATCH 087/109] isolate which var it is --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index f60d7130..e7a105c1 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -143,7 +143,7 @@ parameters: use_cache: true temperature: 1 epoch_per_iteration: 1 - generations_per_prompt: 4 + generations_per_prompt: 8 num_batches_per_update: 4 device_generate_batch_size: 1 algorithms: From dde0142e99f324191e20cd658fcd5a802290f4ea Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 20:45:07 -0700 Subject: [PATCH 088/109] trying out 16 samples --- yamls/single-controller-grpo-workflow.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index e7a105c1..4af074c2 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -143,8 +143,8 @@ parameters: use_cache: true temperature: 1 epoch_per_iteration: 1 - generations_per_prompt: 8 - num_batches_per_update: 4 + generations_per_prompt: 4 + num_batches_per_update: 8 device_generate_batch_size: 1 algorithms: gradient_clipping: From 26b6cc49019cf541724d2d7d07f14af98a24775c Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 20:48:24 -0700 Subject: [PATCH 089/109] more debug --- test_single_controller_ppo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 90a0a1c0..e2f5a906 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -348,6 +348,7 @@ def _partition_rollouts_across_ranks(self, rollouts: dict[str, Any]) -> list[dic return partitioned_rollouts def _add_latest_rollouts(self, rollouts: dict[str, Any]): + print(f"Rollout ids: {rollouts['prompt_id']}") partitioned_rollouts = self._partition_rollouts_across_ranks(rollouts) assert len(partitioned_rollouts) == self.num_train_actors, "Number of partitioned rollouts should be equal to the number of train actors" ray.get([train_actor.add_rollouts.remote(partition) for train_actor, partition in zip(self.train_actors, partitioned_rollouts)]) From 2e1b64a23c4526241e1001bdb79648189dd7ed95 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:16:39 -0700 Subject: [PATCH 090/109] trying out something else --- test_single_controller_ppo.py | 2 ++ yamls/single-controller-grpo-workflow.yaml | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index e2f5a906..ca3c86c1 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -705,6 +705,8 @@ def get_next_iter_prompts(self): """Gets the next iteration's prompts across all ranks and prepares them for the rollout agent.""" batches = [self._get_single_iter_prompts()] + print(f"Batches: {batches}") + return preprocess_batches(batches, self.generations_per_prompt, self.tokenizer.pad_token_id) def get_dataloader_state_dict(self): diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 4af074c2..e7a105c1 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -143,8 +143,8 @@ parameters: use_cache: true temperature: 1 epoch_per_iteration: 1 - generations_per_prompt: 4 - num_batches_per_update: 8 + generations_per_prompt: 8 + num_batches_per_update: 4 device_generate_batch_size: 1 algorithms: gradient_clipping: From 7042619695cbba00ecfe66172e5ffe1ccee71586 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:20:41 -0700 Subject: [PATCH 091/109] further debugging --- compose_rl/algorithms/online/callback_utils.py | 2 ++ test_single_controller_ppo.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 982cccb9..2fd98d7f 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -57,4 +57,6 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx else: ret_batch[key] = curr_values + print(f"Ret batch: {ret_batch['prompt_id']}") + return ret_batch diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index ca3c86c1..712a6f19 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -705,7 +705,7 @@ def get_next_iter_prompts(self): """Gets the next iteration's prompts across all ranks and prepares them for the rollout agent.""" batches = [self._get_single_iter_prompts()] - print(f"Batches: {batches}") + print(f"Batches: {batches[0]['prompt_id']}") return preprocess_batches(batches, self.generations_per_prompt, self.tokenizer.pad_token_id) From a768a0ed763770b207a66641af342b31474aaa62 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:25:10 -0700 Subject: [PATCH 092/109] debug --- compose_rl/algorithms/online/callback_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 2fd98d7f..d26b47ab 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -6,6 +6,8 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx: int): ret_batch = {} + print(f"Length of batches: {len(batches)}") + print(f"Length of prompts in first batch: {len(batches[0]['prompt_id'])}") for key in batches[0].keys(): curr_values = [] @@ -57,6 +59,6 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx else: ret_batch[key] = curr_values - print(f"Ret batch: {ret_batch['prompt_id']}") + print(f"Ret batch: {ret_batch['prompt_id']}") return ret_batch From 136d68750b3c4a5a431a9c4c313a54c666cf5fef Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:31:17 -0700 Subject: [PATCH 093/109] double checking type --- compose_rl/algorithms/online/callback_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index d26b47ab..8174cae6 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -51,6 +51,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx curr_values.append(torch.cat([pad, batch[key]], dim=-1)) # For tensor fields, use torch.cat to combine the values; for string fields, just use the list + print(f"Curr values type: {type(curr_values[0])}") if isinstance(curr_values[0], torch.Tensor): ret_batch[key] = torch.cat(curr_values) else: From d46ae8d3c232ca69c476feb34b5611365ee432f7 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:31:36 -0700 Subject: [PATCH 094/109] more check --- compose_rl/algorithms/online/callback_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 8174cae6..5a033e51 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -51,7 +51,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx curr_values.append(torch.cat([pad, batch[key]], dim=-1)) # For tensor fields, use torch.cat to combine the values; for string fields, just use the list - print(f"Curr values type: {type(curr_values[0])}") + print(f"{key}'s curr values type: {type(curr_values[0])}") if isinstance(curr_values[0], torch.Tensor): ret_batch[key] = torch.cat(curr_values) else: From 2289ee766081046c2199ad6c99331f85b9d04a8e Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:47:32 -0700 Subject: [PATCH 095/109] flipping this --- .../algorithms/online/callback_utils.py | 60 ++++++++++--------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 5a033e51..1df14530 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -17,38 +17,40 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx padding_key = None for batch in batches: - # Explode the batch into multiple batches for each generation - for _ in range(generations_per_prompt): - # For keys that do not require additional processing - if key in [ - 'prompt_len', - 'verified_answer', - 'prompt_id', - 'vstar', - 'messages', - ]: - curr_values.append(batch[key]) - continue - bs, seq_len = batch[key].shape + for item in batch[key]: + # Explode the batch into multiple batches for each generation + for _ in range(generations_per_prompt): + # For keys that do not require additional processing + if key in [ + 'prompt_len', + 'verified_answer', + 'prompt_id', + 'vstar', + 'messages', + ]: + curr_values.append(item) + continue - if key == 'prompt': - padding_key = pad_token_idx - if (batch[key][:, -1] == padding_key).any(): - raise ValueError( - 'The last token in the prompt should not be the pad token. Please double ' - + - 'check the dataloader and prompt and dataloader.', - ) - elif key == 'prompt_attention_mask': - padding_key = False + bs, seq_len = batch[key].shape - # Compute the required padding and concatenate with the batch tensor - pad = torch.ones( - (bs, max_len - seq_len), - dtype=batch[key].dtype, - ) * padding_key # type: ignore - curr_values.append(torch.cat([pad, batch[key]], dim=-1)) + if key == 'prompt': + padding_key = pad_token_idx + if (batch[key][:, -1] == padding_key).any(): + raise ValueError( + 'The last token in the prompt should not be the pad token. Please double ' + + + 'check the dataloader and prompt and dataloader.', + ) + elif key == 'prompt_attention_mask': + padding_key = False + + # Compute the required padding and concatenate with the batch tensor + pad = torch.ones( + (bs, max_len - seq_len), + dtype=batch[key].dtype, + ) * padding_key # type: ignore + curr_values.append(torch.cat([pad, batch[key]], dim=-1)) # For tensor fields, use torch.cat to combine the values; for string fields, just use the list print(f"{key}'s curr values type: {type(curr_values[0])}") From 46e72b62f3a3f55b6d2122b917259016478be9c8 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:47:59 -0700 Subject: [PATCH 096/109] more corrections --- compose_rl/algorithms/online/callback_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 1df14530..75e82bc5 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -32,11 +32,11 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx curr_values.append(item) continue - bs, seq_len = batch[key].shape + bs, seq_len = item.shape if key == 'prompt': padding_key = pad_token_idx - if (batch[key][:, -1] == padding_key).any(): + if (item[:, -1] == padding_key).any(): raise ValueError( 'The last token in the prompt should not be the pad token. Please double ' + @@ -48,9 +48,9 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx # Compute the required padding and concatenate with the batch tensor pad = torch.ones( (bs, max_len - seq_len), - dtype=batch[key].dtype, + dtype=item.dtype, ) * padding_key # type: ignore - curr_values.append(torch.cat([pad, batch[key]], dim=-1)) + curr_values.append(torch.cat([pad, item], dim=-1)) # For tensor fields, use torch.cat to combine the values; for string fields, just use the list print(f"{key}'s curr values type: {type(curr_values[0])}") From 27a52f21bdcbb86620bdd6095e1c729a5349cf8f Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:53:12 -0700 Subject: [PATCH 097/109] more debugging --- compose_rl/algorithms/online/callback_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 75e82bc5..f8c171a6 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -18,6 +18,9 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx padding_key = None for batch in batches: + if isinstance(batch[key], torch.Tensor): + print(f"shape of {key}: {batch[key].shape}") + for item in batch[key]: # Explode the batch into multiple batches for each generation for _ in range(generations_per_prompt): From a394e79955e600d1193db64d72b9def8ef521152 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 21:58:35 -0700 Subject: [PATCH 098/109] stack --- compose_rl/algorithms/online/callback_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index f8c171a6..5eab56a9 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -58,7 +58,10 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx # For tensor fields, use torch.cat to combine the values; for string fields, just use the list print(f"{key}'s curr values type: {type(curr_values[0])}") if isinstance(curr_values[0], torch.Tensor): - ret_batch[key] = torch.cat(curr_values) + if len(curr_values[0].shape) == 0: + ret_batch[key] = torch.stack(curr_values) + else: + ret_batch[key] = torch.cat(curr_values) else: if key in ['verified_answer', 'vstar']: ret_batch[key] = list(flatten(curr_values)) From bdeb027b78ba6f5d6004af6a07944b79bffb1f3d Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:02:59 -0700 Subject: [PATCH 099/109] fix bs --- compose_rl/algorithms/online/callback_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 5eab56a9..877fa7e9 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -35,7 +35,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx curr_values.append(item) continue - bs, seq_len = item.shape + seq_len = item.shape if key == 'prompt': padding_key = pad_token_idx @@ -50,7 +50,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx # Compute the required padding and concatenate with the batch tensor pad = torch.ones( - (bs, max_len - seq_len), + max_len - seq_len, dtype=item.dtype, ) * padding_key # type: ignore curr_values.append(torch.cat([pad, item], dim=-1)) From 2647386df4592bbfd0f127b21a4c6ed6c519d2e5 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:06:06 -0700 Subject: [PATCH 100/109] more fix --- compose_rl/algorithms/online/callback_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 877fa7e9..87d6b5c2 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -39,7 +39,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx if key == 'prompt': padding_key = pad_token_idx - if (item[:, -1] == padding_key).any(): + if (item[-1] == padding_key).any(): raise ValueError( 'The last token in the prompt should not be the pad token. Please double ' + From c0c3dae373ced8325dd490a68e4cc0aa28173319 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:10:25 -0700 Subject: [PATCH 101/109] most random comma --- compose_rl/algorithms/online/callback_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 87d6b5c2..121f2d16 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -35,7 +35,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx curr_values.append(item) continue - seq_len = item.shape + seq_len, = item.shape # expect this to be a 1D tensor if key == 'prompt': padding_key = pad_token_idx From 45b08feac6233ee0c5291dad1a59d0327cbc7961 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:18:32 -0700 Subject: [PATCH 102/109] double checking shapes --- compose_rl/algorithms/online/callback_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 121f2d16..6a016dea 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -62,6 +62,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx ret_batch[key] = torch.stack(curr_values) else: ret_batch[key] = torch.cat(curr_values) + print(f"Ret batch: {ret_batch[key].shape}") else: if key in ['verified_answer', 'vstar']: ret_batch[key] = list(flatten(curr_values)) From f8bfccee591e532c5ef7233c886e2d949f36006a Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:18:58 -0700 Subject: [PATCH 103/109] should be stack --- compose_rl/algorithms/online/callback_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 6a016dea..28a37932 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -58,10 +58,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx # For tensor fields, use torch.cat to combine the values; for string fields, just use the list print(f"{key}'s curr values type: {type(curr_values[0])}") if isinstance(curr_values[0], torch.Tensor): - if len(curr_values[0].shape) == 0: - ret_batch[key] = torch.stack(curr_values) - else: - ret_batch[key] = torch.cat(curr_values) + ret_batch[key] = torch.stack(curr_values) print(f"Ret batch: {ret_batch[key].shape}") else: if key in ['verified_answer', 'vstar']: From 4075679fecd6e694079f92f44188d626ea1023a8 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:25:13 -0700 Subject: [PATCH 104/109] cleaning up stuff --- .../online/generation_utils/vllm_utils.py | 45 +-------- test_single_controller_ppo.py | 94 +------------------ yamls/single-controller-grpo-workflow.yaml | 12 +-- 3 files changed, 13 insertions(+), 138 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index a9a91d9b..4da62390 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -129,8 +129,6 @@ def init_process_group( assert group_name != '', 'group name must not be empty' rank = torch.distributed.get_rank() + rank_offset - - self._rank = rank self._model_update_group = init_process_group( # type: ignore backend=backend, init_method=f'tcp://{master_address}:{master_port}', @@ -166,44 +164,15 @@ def update_weight( group=self._model_update_group, ) - with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"Weight {name} at receiving weight, size = {shape}, weight_sum = {torch.sum(weight)}\n") - # Because FSDP keeps master weights in FP32 and vLLM typically doesn't do this # We will need to cast the weight type to the model_config type if weight.dtype != self.model_config.dtype: # type: ignore weight = weight.to(self.model_config.dtype) # type: ignore - updated_weights = self.model_runner.model.load_weights( # type: ignore + self.model_runner.model.load_weights( # type: ignore weights=[(name, weight)], ) # type: ignore - if len(updated_weights) == 0: - with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"Weight {name} not found in model\n") - else: - write = True - if len(updated_weights) > 1: - with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"multiple updated_weights = {updated_weights}\n") - f.write(f"type(updated_weights) = {type(updated_weights)}\n") - write = False - - if isinstance(updated_weights, list): - updated_weights = updated_weights[0][0] - elif isinstance(updated_weights, set): - updated_weights = updated_weights.pop() - else: - with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"updated_weights = {updated_weights}\n") - f.write(f"type(updated_weights) = {type(updated_weights)}\n") - write = False - - if write: - updated_weight_tensor = [weight_param.data for weight_name, weight_param in self.model_runner.model.named_parameters() if weight_name == updated_weights][0] - with open(f"/tmp/compose-rl-worker-{self._rank}.txt", "a") as f: - f.write(f"Weight {updated_weights} at updating weight, size = {updated_weight_tensor.shape}, weight_sum = {torch.sum(updated_weight_tensor)}\n") - del weight if empty_cache: @@ -487,7 +456,6 @@ def broadcast_to_vllm( # This is needed otherwise FSDP will materialize parameters of size 0. # So just for the joint actor critic models we have to actually skip this module. if module_name == 'model' and loss_type == OnPolicyEnum.PPO: - log.info('Skipping model module') continue # Only update if we haven't updated this module before @@ -529,18 +497,15 @@ def broadcast_to_vllm( count += 1 shape = param.shape - refs = [] - for engine in vllm_engines: - ref = engine.update_weight.remote( + refs = [ + engine.update_weight.remote( parsed_name, dtype=param.dtype, shape=shape, empty_cache=(count == num_params), - ) - refs.append(ref) + ) for engine in vllm_engines + ] refss.extend(refs) - with open(f"/tmp/compose-rl-master.txt", "a") as f: - f.write(f"Weight {parsed_name} at sending weight, size = {shape}, weight_sum = {torch.sum(param.data)}\n") torch.distributed.broadcast( param.data, 0, diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 712a6f19..ad756d4f 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -25,10 +25,9 @@ import ray import torch import torch.distributed as dist -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP # type: ignore from composer import Trainer from composer.core import get_precision_context -from composer.optim import DecoupledSGDW +from composer.optim import DecoupledAdamW from composer.utils import dist as composer_dist from llmfoundry.data import build_dataloader from omegaconf import OmegaConf as om @@ -171,7 +170,7 @@ def build_ppo_trainer(self): model = ComposerHFCriticFreePolicyLM(**self.model_config) self.logger.info("Model created successfully") - optimizer = DecoupledSGDW(model.parameters(), lr=1) + optimizer = DecoupledAdamW(model.parameters(), lr=1e-6) # TODO (infra): pull the rest of the training logic from the callback # to this class, e.g, how to interact with env, calculate rewards etc @@ -247,33 +246,9 @@ def train_1_iter(self): # fit() checks if there is existing checkpoint, make a full forward pass, it will run eval pass and save pass. # We potentially want to run this https://github.com/mosaicml/composer/blob/dev/composer/trainer/trainer.py#L2826 # fit() can also potentially overwrite the mlflow - def write_params(model: torch.nn.Module, param2fullname: dict, update_stamp: str): - for _, module in model.named_modules(): - if isinstance(module, FSDP): - with FSDP.summon_full_params( - module, - writeback=False, - rank0_only=False, - recurse=True, - ): - for _, param in module.named_parameters(recurse=True): - full_name = param2fullname[param] - parsed_name = simplify_param_path(full_name) - shape = param.shape - with open(f"/tmp/compose-rl-train-{self.rank}.txt", "a") as f: - f.write(f"Weight {parsed_name} at {update_stamp}, size = {shape}, weight_sum = {torch.sum(param.data)}\n") - - model = self.ppo_trainer.state.model - - param2fullname = build_param_fullnames(model) - - write_params(model, param2fullname, 'before train iter') - self.ppo_trainer.fit(duration='1iter') self.logger.info(f"#### Finished training 1 iter with loss: {self.ppo_trainer.state.loss}") - write_params(model, param2fullname, 'after train iter') - def setup_process_groups( master_actor: Any, @@ -348,7 +323,6 @@ def _partition_rollouts_across_ranks(self, rollouts: dict[str, Any]) -> list[dic return partitioned_rollouts def _add_latest_rollouts(self, rollouts: dict[str, Any]): - print(f"Rollout ids: {rollouts['prompt_id']}") partitioned_rollouts = self._partition_rollouts_across_ranks(rollouts) assert len(partitioned_rollouts) == self.num_train_actors, "Number of partitioned rollouts should be equal to the number of train actors" ray.get([train_actor.add_rollouts.remote(partition) for train_actor, partition in zip(self.train_actors, partitioned_rollouts)]) @@ -365,7 +339,6 @@ async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', # Simple example of adding elements to the experience buffer # Populate the train actor group with the rollouts and then train latest_rollouts = await experience_buffer.get() - print(f"Obtained {len(latest_rollouts['verified_answer'])} examples") self._add_latest_rollouts(latest_rollouts) await asyncio.to_thread(self.train_1_iter) # TODO decide where should we use the lock and the semaphore @@ -376,12 +349,6 @@ class InferenceServer: """Inference server with vLLM engines.""" def __init__(self, num_vllm_engines: int, pretrain_model_name: str, config: Any): - import os - if os.getenv('NODE_RANK', None) == '0' and os.getenv('LOCAL_RANK', None) == '0': - os.environ['NCCL_CUMEM_ENABLE'] = '0' - os.environ['RAY_BACKEND_LOG_LEVEL'] = 'DEBUG' - os.environ['RAY_DEBUG_LOGS'] = '1' - self.num_vllm_engines = num_vllm_engines self.vllm_tensor_parallel_size = config.vllm_tensor_parallel_size self.vllm_engines = create_vllm_engines( @@ -705,8 +672,6 @@ def get_next_iter_prompts(self): """Gets the next iteration's prompts across all ranks and prepares them for the rollout agent.""" batches = [self._get_single_iter_prompts()] - print(f"Batches: {batches[0]['prompt_id']}") - return preprocess_batches(batches, self.generations_per_prompt, self.tokenizer.pad_token_id) def get_dataloader_state_dict(self): @@ -836,61 +801,6 @@ def _run_single_controller_ppo( asyncio.run(ppo_controller.train_async(config.max_duration)) -def simplify_param_path(path: str) -> str: - """Simplifies the parameter path by removing unnecessary parts. - - Args: - path (str): The original parameter path. - """ - # Parts we want to remove - remove_parts = [ - '_fsdp_wrapped_module', - '_checkpoint_wrapped_module', - 'lm_backbone', - 'model', - ] - - # Split the path into parts - parts = path.split('.') - - # Keep only parts that don't contain any of the remove_parts - clean_parts = [] - if 'lm_head' not in path: - clean_parts = ['model'] - for part in parts: - if not any(remove in part for remove in remove_parts): - clean_parts.append(part) - - return '.'.join(clean_parts) - - -def build_param_fullnames(top_module: torch.nn.Module) -> dict: - """Builds a mapping of parameter objects to their fully-qualified names. - - Traverses the entire model from the top level and map each parameter - object to its fully-qualified name (e.g., - "lm_backbone.layer1.mlp.down_proj.weight"). - - Args: - top_module (torch.nn.Module): The top-level module to traverse. - """ - param2fullname = {} - - def _dfs(current_module: torch.nn.Module, prefix: str = ''): - # Get local parameters (without recursing into children). - for local_name, param in current_module.named_parameters(recurse=False): - full_name = f'{prefix}.{local_name}' if prefix else local_name - param2fullname[param] = full_name - - # Recurse on child modules. - for child_name, child_module in current_module.named_children(): - child_prefix = f'{prefix}.{child_name}' if prefix else child_name - _dfs(child_module, prefix=child_prefix) - - _dfs(top_module) - return param2fullname - - if __name__ == '__main__': # Parse command line arguments parser = argparse.ArgumentParser(description='Run single controller PPO with configuration file') diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index e7a105c1..699ba5a8 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -6,15 +6,15 @@ scheduling: resumable: false preemptible: false compute: - gpus: 16 - cluster: r5z2p1 + gpus: 8 + cluster: r5z2p3 instance: oci.bm.gpu.h200.8.oke integrations: - integration_type: git_repo path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: ethantang-db/fix_multi_node + git_branch: single-controller-hackathon - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe @@ -77,7 +77,7 @@ parameters: save_interval: 1dur runtime_estimator: {} optimizer: - lr: 0.5 + lr: 1.0e-06 name: decoupled_adamw betas: - 0.9 @@ -144,7 +144,7 @@ parameters: temperature: 1 epoch_per_iteration: 1 generations_per_prompt: 8 - num_batches_per_update: 4 + num_batches_per_update: 8 device_generate_batch_size: 1 algorithms: gradient_clipping: @@ -191,7 +191,7 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 32 + global_train_batch_size: 64 device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false From 7b6226708d54febe44c3967f7fa97090073a87ba Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:25:42 -0700 Subject: [PATCH 105/109] let's do a run --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 699ba5a8..d01a5231 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon + git_branch: ethantang-db/fix_multi_node - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe From 748168ec7e8441119724f4e4b7dfd38d33469b31 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:26:14 -0700 Subject: [PATCH 106/109] change cluster --- yamls/single-controller-grpo-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index d01a5231..a709e741 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -7,7 +7,7 @@ scheduling: preemptible: false compute: gpus: 8 - cluster: r5z2p3 + cluster: r5z2p1 instance: oci.bm.gpu.h200.8.oke integrations: - integration_type: git_repo From 06248b7ea1f97a98c184f3e1ad979aa1eb8ab07b Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:36:36 -0700 Subject: [PATCH 107/109] the fix --- compose_rl/algorithms/online/callback.py | 35 ++----------------- .../algorithms/online/callback_utils.py | 14 +++----- .../online/single_controller_callback.py | 1 - compose_rl/utils/utils.py | 4 --- yamls/single-controller-grpo-workflow.yaml | 4 +-- 5 files changed, 8 insertions(+), 50 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index de9bf37e..8ed96ff9 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -907,15 +907,6 @@ def _resolve_outputs( sums.scatter_add_(0, inverse_indices, flat_rewards) sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2) - if dist.get_global_rank() == 0: - print(f"rewards: {rewards}") - print(f"inverse_indices: {inverse_indices}") - print(f"flat_rewards: {flat_rewards}") - print(f"sums: {sums}") - print(f"sum_squares: {sum_squares}") - print(f"counts: {counts}") - print(f"prompt_ids: {prompt_id}") - # Compute means and standard deviations means = sums / counts variances = (sum_squares / counts) - (means**2) @@ -942,12 +933,6 @@ def _resolve_outputs( expanded_advantages, advantages, ) - - reward_sq = rewards * rewards - expanded_advantages_sq = expanded_advantages * expanded_advantages - print(f"mean_rewards: {mean_rewards.sum()}, std_rewards: {std_rewards.sum()}") - print(f"flat_rewards: {flat_rewards.sum()}, expanded_advantages_sq: {expanded_advantages_sq.sum()}, action_masks: {env_outs['action_mask'].sum()}") - print(f"reward_sq: {reward_sq.sum()}, advantages: {advantages.sum()}") env_outs['advantages'] = advantages else: raise ValueError( @@ -960,12 +945,10 @@ def _resolve_outputs( env_outs['action_mask'], ) - print(f"batch_adv_mean: {batch_adv_mean}, batch_adv_var: {batch_adv_var}") - bs = iter_batch['prompt_id'].shape[0] iter_batch.update({ - 'adv_masked_mean': torch.ones(bs, device=batch_adv_mean.device) * batch_adv_mean, - 'adv_masked_var': torch.ones(bs, device=batch_adv_var.device) * batch_adv_var, + 'adv_masked_mean': torch.ones(bs) * batch_adv_mean.cpu(), + 'adv_masked_var': torch.ones(bs) * batch_adv_var.cpu(), }) mean_ift = masked_mean( @@ -1142,17 +1125,3 @@ def state_dict(self): def load_state_dict(self, state_dict: dict[str, Any]): self.kl_ctl.load_state_dict(state_dict['KL_ctl_state_dict']) self.iter_num = state_dict['iter_num'] - - # def before_forward(self, state: State, logger: Logger): - # if dist.get_global_rank() == 0: - # skip_keys = ['prompt_id', 'prompt', 'prompt_len', 'verified_answer', 'prompt_attention_mask', 'sequences'] - # to_log = {k: v for k, v in state.batch.items() if k not in skip_keys} - # print(f"Before forward, training on {to_log}") - - # def after_forward(self, state: State, logger: Logger): - # if dist.get_global_rank() == 0: - # print(f"After forward, outputs: {state.outputs}") - - # def after_loss(self, state: State, logger: Logger): - # if dist.get_global_rank() == 0: - # print(f"After loss, {state.loss}") diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 28a37932..329ce409 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -6,8 +6,7 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx: int): ret_batch = {} - print(f"Length of batches: {len(batches)}") - print(f"Length of prompts in first batch: {len(batches[0]['prompt_id'])}") + for key in batches[0].keys(): curr_values = [] @@ -16,11 +15,10 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx max_len = max([batch[key].shape[-1] for batch in batches]) padding_key = None - for batch in batches: - - if isinstance(batch[key], torch.Tensor): - print(f"shape of {key}: {batch[key].shape}") + for batch in batches: + # inside the batch, it's a dictionary of tensors that have the batch dimension there, + # so we need to iterate through each element to explode it. for item in batch[key]: # Explode the batch into multiple batches for each generation for _ in range(generations_per_prompt): @@ -56,16 +54,12 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx curr_values.append(torch.cat([pad, item], dim=-1)) # For tensor fields, use torch.cat to combine the values; for string fields, just use the list - print(f"{key}'s curr values type: {type(curr_values[0])}") if isinstance(curr_values[0], torch.Tensor): ret_batch[key] = torch.stack(curr_values) - print(f"Ret batch: {ret_batch[key].shape}") else: if key in ['verified_answer', 'vstar']: ret_batch[key] = list(flatten(curr_values)) else: ret_batch[key] = curr_values - print(f"Ret batch: {ret_batch['prompt_id']}") - return ret_batch diff --git a/compose_rl/algorithms/online/single_controller_callback.py b/compose_rl/algorithms/online/single_controller_callback.py index a4ddc207..5452feb6 100644 --- a/compose_rl/algorithms/online/single_controller_callback.py +++ b/compose_rl/algorithms/online/single_controller_callback.py @@ -53,7 +53,6 @@ def iteration_start(self, state: State, logger: Logger): state.auto_microbatching, state.train_dataloader, ) - print(f"Training on {len(self.buffer)} examples") # Update IFT KL self._update_ift_kl() diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 5b8d4c47..bfd172a4 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -422,8 +422,6 @@ def dist_compute_masked_mean_and_var( # Get the masked tensor sum masked_tensor_sum = (tensor * mask).sum() - print(f"before dist, num_unmasked_elements: {num_unmasked_elements}, masked_tensor_sum: {masked_tensor_sum}, tensor: {tensor.sum()}") - dist.all_reduce(num_unmasked_elements) dist.all_reduce(masked_tensor_sum) @@ -434,8 +432,6 @@ def dist_compute_masked_mean_and_var( centered_values = centered_values.sum() dist.all_reduce(centered_values) - - print(f"masked_tensor_sum: {masked_tensor_sum}, centered_values: {centered_values}, num_unmasked_elements: {num_unmasked_elements}") global_variance = centered_values / num_unmasked_elements if unbiased: bessel_correction = num_unmasked_elements / (num_unmasked_elements - 1) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index a709e741..699ba5a8 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -7,14 +7,14 @@ scheduling: preemptible: false compute: gpus: 8 - cluster: r5z2p1 + cluster: r5z2p3 instance: oci.bm.gpu.h200.8.oke integrations: - integration_type: git_repo path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: ethantang-db/fix_multi_node + git_branch: single-controller-hackathon - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe From d1bf2331e0b63ee457bb8d2f0860c69529db7afa Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Wed, 13 Aug 2025 22:47:12 -0700 Subject: [PATCH 108/109] white space --- compose_rl/algorithms/online/callback_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/compose_rl/algorithms/online/callback_utils.py b/compose_rl/algorithms/online/callback_utils.py index 329ce409..e4d2ae9c 100644 --- a/compose_rl/algorithms/online/callback_utils.py +++ b/compose_rl/algorithms/online/callback_utils.py @@ -6,7 +6,6 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx: int): ret_batch = {} - for key in batches[0].keys(): curr_values = [] @@ -15,7 +14,6 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx max_len = max([batch[key].shape[-1] for batch in batches]) padding_key = None - for batch in batches: # inside the batch, it's a dictionary of tensors that have the batch dimension there, # so we need to iterate through each element to explode it. From 9af0e4d6c45aa78be9088eb1d25825985d3b5dec Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Thu, 14 Aug 2025 14:22:18 -0700 Subject: [PATCH 109/109] added checks for proper configs --- test_single_controller_ppo.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index ad756d4f..66d7305e 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -62,6 +62,17 @@ def time_it(name: str): print(f"[{name}] took {end_time - start_time:.2f} seconds") +def get_and_validate_num_prompts_per_iteration(config: Any): + generations_per_prompt = config.variables.generations_per_prompt + num_batches_per_update = config.variables.num_batches_per_update + total_num_generations = config.global_train_batch_size * num_batches_per_update + num_prompts_per_iteration = total_num_generations // generations_per_prompt + + assert total_num_generations % generations_per_prompt == 0, "total_num_generations must be divisible by generations_per_prompt" + + return num_prompts_per_iteration + + class DistributedGPUActor(BaseDistributedGPUActor): """Distributed GPU actor for testing.""" @@ -628,14 +639,8 @@ def __init__(self, config: Any): self.dataloader_config['dataset']['local'].format(timestamp=timestamp) # Key variables - global_train_batch_size = config.global_train_batch_size self.generations_per_prompt = config.variables.generations_per_prompt - num_batches_per_update = config.variables.num_batches_per_update - total_num_generations = global_train_batch_size * num_batches_per_update - self.num_prompts_per_iteration = total_num_generations // self.generations_per_prompt - - # Validate that the total number of generations is divisible by the number of generations per prompt - assert total_num_generations % self.generations_per_prompt == 0, "total_num_generations must be divisible by generations_per_prompt" + self.num_prompts_per_iteration = get_and_validate_num_prompts_per_iteration(config) # Creating main entities self.tokenizer = self._build_tokenizer() @@ -768,6 +773,9 @@ def _run_single_controller_ppo( config=config, ) + num_prompts_per_iteration = get_and_validate_num_prompts_per_iteration(config) + assert num_prompts_per_iteration % num_train_actors == 0, "Number of prompts per iteration must be divisible by number of train actors to ensure accurate advantage calculations." + # We are using a CPU worker for the StreamingActor # and this involves a super hacky workaround by # uninstalling megablocks if it exists. Better solutions @@ -785,6 +793,8 @@ def _run_single_controller_ppo( streaming_dataset_actor = ray.remote(num_gpus=0)(StreamingDatasetActor).remote(config) rollout_agent = RolloutAgent(inference_server, streaming_dataset_actor, config) + + # EvalAgent doesn't need to be a Ray actor since we don't need to # set a world_size or use GPUs for this process. eval_agent = EvalAgent(inference_server.engines, config)