Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
110 commits
Select commit Hold shift + click to select a range
15c6ae3
adding logging to understand weight updates
ethantang-db Aug 9, 2025
4ed2322
assert false
ethantang-db Aug 9, 2025
453a610
logging more updateS
ethantang-db Aug 9, 2025
5c15cb8
trying out llama 1b
ethantang-db Aug 9, 2025
2bd7bc9
fix loading
ethantang-db Aug 9, 2025
a8d817b
different dataset
ethantang-db Aug 9, 2025
5d64218
revert to r1
ethantang-db Aug 9, 2025
3cf4432
trying out 2 nodes
ethantang-db Aug 9, 2025
f7184be
test
ethantang-db Aug 9, 2025
7e66c92
log worker_wrap logic
ethantang-db Aug 9, 2025
7e1d365
force crash
ethantang-db Aug 9, 2025
7c10670
removing assert
ethantang-db Aug 9, 2025
b0a3467
jank logging
ethantang-db Aug 10, 2025
ccd2e4c
try gloo?
ethantang-db Aug 10, 2025
34ce4fb
revert back to nccl
ethantang-db Aug 10, 2025
536e513
try out cpu and gloo
ethantang-db Aug 11, 2025
2f4de01
log tensors to file
ethantang-db Aug 11, 2025
c96a8d1
better logging
ethantang-db Aug 11, 2025
ac6b8d9
removed redundant debugging
ethantang-db Aug 11, 2025
24b0d93
rank
ethantang-db Aug 11, 2025
f95a801
try env vars
ethantang-db Aug 11, 2025
b0a3441
try out other place for nccl
ethantang-db Aug 11, 2025
5598942
f...
ethantang-db Aug 11, 2025
febdc69
further debugging
ethantang-db Aug 11, 2025
05085f3
log what weights are updated
ethantang-db Aug 12, 2025
0ebb94b
log weight updates
ethantang-db Aug 12, 2025
2673c66
update weights
ethantang-db Aug 12, 2025
d3c1d20
this is trippin
ethantang-db Aug 12, 2025
840e0d8
better weight logging
ethantang-db Aug 12, 2025
b9dfda7
like cursor bruh?
ethantang-db Aug 12, 2025
50475d5
better logs
ethantang-db Aug 12, 2025
c80e25f
???
ethantang-db Aug 12, 2025
2c8d11f
???
ethantang-db Aug 12, 2025
1de2fff
try layer 25
ethantang-db Aug 12, 2025
8089c52
cranking up the learning rate
ethantang-db Aug 12, 2025
f38d7db
new lines
ethantang-db Aug 12, 2025
963a66f
remove new lines
ethantang-db Aug 12, 2025
37403ef
even more ridiculous LR
ethantang-db Aug 12, 2025
6fbdfc3
are you fking kidding me
ethantang-db Aug 12, 2025
7c699d6
new opt
ethantang-db Aug 12, 2025
5672519
Merge branch 'single-controller-hackathon' into ethantang-db/fix_mult…
ethantang-db Aug 12, 2025
81b9c46
chaos updates
ethantang-db Aug 12, 2025
00920bf
try new chaos values
ethantang-db Aug 12, 2025
a1c3594
logging weight updateS
ethantang-db Aug 12, 2025
5949039
lr is 1
ethantang-db Aug 12, 2025
abbfa1d
god dammit cursor
ethantang-db Aug 12, 2025
fca0c37
correct rank
ethantang-db Aug 12, 2025
85cfd6b
try other chaos update values
ethantang-db Aug 12, 2025
ba7fb54
fix crash?
ethantang-db Aug 12, 2025
be9d152
remove chaos
ethantang-db Aug 12, 2025
513ff59
fp32 everything
ethantang-db Aug 12, 2025
6a81fd1
enable chaos
ethantang-db Aug 12, 2025
2822d71
god dammit flash attention
ethantang-db Aug 12, 2025
801099c
disable flash attention on vllm
ethantang-db Aug 12, 2025
8015476
try out torch SDPA
ethantang-db Aug 12, 2025
fe7f76c
:/
ethantang-db Aug 12, 2025
a48416f
proper fsdp summon
ethantang-db Aug 12, 2025
21fafa1
bfloat16
ethantang-db Aug 12, 2025
35721b8
disable chaos
ethantang-db Aug 12, 2025
7cb3c3b
better logging
ethantang-db Aug 13, 2025
f926283
logging how many examples being trained
ethantang-db Aug 13, 2025
b133fb6
more logging
ethantang-db Aug 13, 2025
816a527
log loss
ethantang-db Aug 13, 2025
5cbb05a
debugging batch better
ethantang-db Aug 13, 2025
c301055
shorter for better debugging
ethantang-db Aug 13, 2025
b5fc1f0
more len
ethantang-db Aug 13, 2025
aefd737
test
ethantang-db Aug 13, 2025
fd47fc0
test
ethantang-db Aug 13, 2025
2a45675
not cpu
ethantang-db Aug 13, 2025
09a0c11
gpu
ethantang-db Aug 13, 2025
7872488
debug more
ethantang-db Aug 13, 2025
3484cf8
disable logging
ethantang-db Aug 13, 2025
d3b16be
change gen len
ethantang-db Aug 13, 2025
6cd3980
disable logging
ethantang-db Aug 13, 2025
43e3df9
logging before
ethantang-db Aug 13, 2025
96c98fe
increase global batch size
ethantang-db Aug 13, 2025
33766c8
log rewards
ethantang-db Aug 13, 2025
930ed2d
more logs
ethantang-db Aug 13, 2025
35024c8
probe more
ethantang-db Aug 13, 2025
d634b86
log expanded advantages better
ethantang-db Aug 13, 2025
c0e95e1
more log
ethantang-db Aug 13, 2025
05c90d2
log ids
ethantang-db Aug 13, 2025
4672849
wtf?
ethantang-db Aug 13, 2025
26ff5a9
???
ethantang-db Aug 13, 2025
92f1090
inverse indices
ethantang-db Aug 13, 2025
171fcf9
test wild theory
ethantang-db Aug 13, 2025
cdce7f0
validate 32 bs also
ethantang-db Aug 13, 2025
594f87e
isolate which var it is
ethantang-db Aug 13, 2025
dde0142
trying out 16 samples
ethantang-db Aug 14, 2025
26b6cc4
more debug
ethantang-db Aug 14, 2025
2e1b64a
trying out something else
ethantang-db Aug 14, 2025
7042619
further debugging
ethantang-db Aug 14, 2025
a768a0e
debug
ethantang-db Aug 14, 2025
136d687
double checking type
ethantang-db Aug 14, 2025
d46ae8d
more check
ethantang-db Aug 14, 2025
2289ee7
flipping this
ethantang-db Aug 14, 2025
46e72b6
more corrections
ethantang-db Aug 14, 2025
27a52f2
more debugging
ethantang-db Aug 14, 2025
a394e79
stack
ethantang-db Aug 14, 2025
bdeb027
fix bs
ethantang-db Aug 14, 2025
2647386
more fix
ethantang-db Aug 14, 2025
c0c3dae
most random comma
ethantang-db Aug 14, 2025
45b08fe
double checking shapes
ethantang-db Aug 14, 2025
f8bfcce
should be stack
ethantang-db Aug 14, 2025
4075679
cleaning up stuff
ethantang-db Aug 14, 2025
7b62267
let's do a run
ethantang-db Aug 14, 2025
748168e
change cluster
ethantang-db Aug 14, 2025
06248b7
the fix
ethantang-db Aug 14, 2025
d1bf233
white space
ethantang-db Aug 14, 2025
9af0e4d
added checks for proper configs
ethantang-db Aug 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 36 additions & 33 deletions compose_rl/algorithms/online/callback_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,45 @@ def preprocess_batches(batches: list, generations_per_prompt: int, pad_token_idx

padding_key = None
for batch in batches:
# Explode the batch into multiple batches for each generation
for _ in range(generations_per_prompt):
# For keys that do not require additional processing
if key in [
'prompt_len',
'verified_answer',
'prompt_id',
'vstar',
'messages',
]:
curr_values.append(batch[key])
continue

bs, seq_len = batch[key].shape

if key == 'prompt':
padding_key = pad_token_idx
if (batch[key][:, -1] == padding_key).any():
raise ValueError(
'The last token in the prompt should not be the pad token. Please double '
+
'check the dataloader and prompt and dataloader.',
)
elif key == 'prompt_attention_mask':
padding_key = False

# Compute the required padding and concatenate with the batch tensor
pad = torch.ones(
(bs, max_len - seq_len),
dtype=batch[key].dtype,
) * padding_key # type: ignore
curr_values.append(torch.cat([pad, batch[key]], dim=-1))
# inside the batch, it's a dictionary of tensors that have the batch dimension there,
# so we need to iterate through each element to explode it.
for item in batch[key]:
# Explode the batch into multiple batches for each generation
for _ in range(generations_per_prompt):
# For keys that do not require additional processing
if key in [
'prompt_len',
'verified_answer',
'prompt_id',
'vstar',
'messages',
]:
curr_values.append(item)
continue

seq_len, = item.shape # expect this to be a 1D tensor

if key == 'prompt':
padding_key = pad_token_idx
if (item[-1] == padding_key).any():
raise ValueError(
'The last token in the prompt should not be the pad token. Please double '
+
'check the dataloader and prompt and dataloader.',
)
elif key == 'prompt_attention_mask':
padding_key = False

# Compute the required padding and concatenate with the batch tensor
pad = torch.ones(
max_len - seq_len,
dtype=item.dtype,
) * padding_key # type: ignore
curr_values.append(torch.cat([pad, item], dim=-1))

# For tensor fields, use torch.cat to combine the values; for string fields, just use the list
if isinstance(curr_values[0], torch.Tensor):
ret_batch[key] = torch.cat(curr_values)
ret_batch[key] = torch.stack(curr_values)
else:
if key in ['verified_answer', 'vstar']:
ret_batch[key] = list(flatten(curr_values))
Expand Down
24 changes: 17 additions & 7 deletions test_single_controller_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ def time_it(name: str):
print(f"[{name}] took {end_time - start_time:.2f} seconds")


def get_and_validate_num_prompts_per_iteration(config: Any) -> int:
    """Compute the number of unique prompts consumed per training iteration.

    The total number of generations per iteration is
    ``global_train_batch_size * num_batches_per_update``; each prompt is
    expanded into ``generations_per_prompt`` generations, so the prompt
    count is the quotient of the two.

    Args:
        config: Run config exposing ``global_train_batch_size`` and a
            ``variables`` namespace with ``generations_per_prompt`` and
            ``num_batches_per_update``.

    Returns:
        int: Number of unique prompts sampled per iteration.

    Raises:
        ValueError: If the total generation count is not evenly divisible
            by ``generations_per_prompt``.
    """
    generations_per_prompt = config.variables.generations_per_prompt
    num_batches_per_update = config.variables.num_batches_per_update
    total_num_generations = config.global_train_batch_size * num_batches_per_update

    # Validate before dividing. A plain `assert` is stripped under
    # `python -O`, so raise explicitly to keep the check in all modes.
    if total_num_generations % generations_per_prompt != 0:
        raise ValueError(
            "total_num_generations must be divisible by generations_per_prompt",
        )

    return total_num_generations // generations_per_prompt


class DistributedGPUActor(BaseDistributedGPUActor):
"""Distributed GPU actor for testing."""

Expand Down Expand Up @@ -628,14 +639,8 @@ def __init__(self, config: Any):
self.dataloader_config['dataset']['local'].format(timestamp=timestamp)

# Key variables
global_train_batch_size = config.global_train_batch_size
self.generations_per_prompt = config.variables.generations_per_prompt
num_batches_per_update = config.variables.num_batches_per_update
total_num_generations = global_train_batch_size * num_batches_per_update
self.num_prompts_per_iteration = total_num_generations // self.generations_per_prompt

# Validate that the total number of generations is divisible by the number of generations per prompt
assert total_num_generations % self.generations_per_prompt == 0, "total_num_generations must be divisible by generations_per_prompt"
self.num_prompts_per_iteration = get_and_validate_num_prompts_per_iteration(config)

# Creating main entities
self.tokenizer = self._build_tokenizer()
Expand Down Expand Up @@ -768,6 +773,9 @@ def _run_single_controller_ppo(
config=config,
)

num_prompts_per_iteration = get_and_validate_num_prompts_per_iteration(config)
assert num_prompts_per_iteration % num_train_actors == 0, "Number of prompts per iteration must be divisible by number of train actors to ensure accurate advantage calculations."

# We are using a CPU worker for the StreamingActor
# and this involves a super hacky workaround by
# uninstalling megablocks if it exists. Better solutions
Expand All @@ -785,6 +793,8 @@ def _run_single_controller_ppo(
streaming_dataset_actor = ray.remote(num_gpus=0)(StreamingDatasetActor).remote(config)
rollout_agent = RolloutAgent(inference_server, streaming_dataset_actor, config)



# EvalAgent doesn't need to be a Ray actor since we don't need to
# set a world_size or use GPUs for this process.
eval_agent = EvalAgent(inference_server.engines, config)
Expand Down