13 changes: 12 additions & 1 deletion .gitignore
@@ -2,4 +2,15 @@ __pycache__/
.git/*
*.pyc
push.sh
birdie_rl/birdie_reward_model/example_usage/compress.py
birdie_rl/birdie_reward_model/example_usage/compress.py

# Build artifacts
build/
dist/
*.egg-info/
*.so
*.pyd
*.dll
*.dylib
*.o
*.obj
264 changes: 141 additions & 123 deletions README.md

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion birdie_rl/__init__.py
@@ -1 +1,27 @@
from .birdie_reward_model import *
"""
Top-level package exports.

Birdie depends on PyTorch; on some systems a broken/mismatched CUDA install can
make `import torch` fail (e.g. missing shared libs). We keep `birdie_rl`
importable in that case so users can still work with the text/objective/pipeline
utilities, and raise a helpful error only when `Birdie` is accessed.
"""

from __future__ import annotations

from typing import Any

__all__ = ["Birdie"]

def __getattr__(name: str) -> Any: # pragma: no cover
if name != "Birdie":
raise AttributeError(name)
try:
from .birdie_reward_model import Birdie
except Exception as exc: # noqa: BLE001 - want to catch CUDA/torch import failures too
raise ImportError(
"Failed to import `Birdie` (PyTorch/CUDA runtime issue). "
"If you recently changed GPUs/drivers, reinstall a matching PyTorch build, "
"or install a CPU-only torch wheel for data-pipeline debugging."
) from exc
return Birdie
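
A minimal caller-side sketch of the lazy export above (hypothetical usage, assuming this branch is installed; the attribute access is what triggers the torch-backed import):

import birdie_rl  # stays importable even when the torch/CUDA runtime is broken

try:
    Birdie = birdie_rl.Birdie  # PEP 562 module __getattr__ runs here, not at import time
except ImportError as err:
    # On a broken or mismatched CUDA install this carries the hint from the diff above.
    print(f"Birdie unavailable: {err}")
else:
    print("Birdie imported:", Birdie)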
18 changes: 15 additions & 3 deletions birdie_rl/birdie_reward_model/__init__.py
@@ -20,8 +20,20 @@
model = train_step(model, batch)
"""

# Import the Birdie class from birdie.py so it is accessible at the package level.
from .birdie import Birdie
from __future__ import annotations

from typing import Any

# Restrict __all__ so that "from birdie_reward_model import *" only exposes Birdie.
__all__ = ["Birdie"]

def __getattr__(name: str) -> Any: # pragma: no cover
if name != "Birdie":
raise AttributeError(name)
try:
from .birdie import Birdie
except Exception as exc: # noqa: BLE001 - want to catch torch/datasets import failures too
raise ImportError(
"Failed to import `Birdie` (optional dependencies/runtime issue). "
"This does not affect importing submodules like `agent_bird` for synthetic runs."
) from exc
return Birdie
80 changes: 61 additions & 19 deletions birdie_rl/birdie_reward_model/agent_bird.py
@@ -392,9 +392,31 @@ def __init__(self, **kwargs):
# Merge any user-provided kwargs with the class defaults
self.__dict__.update(kwargs)

# Default device
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"

# Reward model sequence length cap (affects memory/throughput).
self.model_max_seq_len = int(kwargs.get("model_max_seq_len", gated_ssm_agent.default_max_seq_len))

# Optional toggles (useful for fast synthetic runs / testing).
self.use_torch_compile = bool(kwargs.get("use_torch_compile", True))
self.disable_tqdm = bool(kwargs.get("disable_tqdm", False))

# Ensure num_objectives is an int
self.num_objectives = int(self.num_objectives)

# Default explore classes: treat each objective as its own "class".
if self.explore_classes is None:
self.explore_classes = np.arange(self.num_objectives, dtype=np.int32)
else:
self.explore_classes = np.asarray(self.explore_classes, dtype=np.int32)
if self.explore_classes.shape[0] != self.num_objectives:
raise ValueError(
f"explore_classes must have length num_objectives={self.num_objectives}, "
f"got shape={self.explore_classes.shape}"
)

# If reward_signal_dims not specified, set it to num_objectives
if self.reward_signal_dims is None:
self.reward_signal_dims = self.num_objectives
@@ -477,8 +499,15 @@ def __init__(self, **kwargs):
output_dim=self.num_objectives,
hidden_dims=self.hidden_dims,
dropout_rate=self.dropout_rate,
max_seq_len=self.model_max_seq_len,
device=self.device,
)
self.model = torch.compile(self.model).to(self.device)
self.model = self.model.to(self.device)
if self.use_torch_compile and hasattr(torch, "compile"):
try:
self.model = torch.compile(self.model)
except Exception:
pass

# Initialize optimizer (AdamW) for the model
self.optimizer = AdamW(self.model.parameters(), lr=self.lr, weight_decay=0.1)
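
The guarded torch.compile above can also be exercised on its own; the following is a small standalone sketch of the same pattern (the maybe_compile helper is illustrative and not part of the diff):

import torch
import torch.nn as nn

def maybe_compile(model: nn.Module, enabled: bool = True) -> nn.Module:
    # Return torch.compile(model) when supported, otherwise the model unchanged.
    if not (enabled and hasattr(torch, "compile")):
        return model
    try:
        return torch.compile(model)
    except Exception:
        # Older PyTorch, unsupported backend, or missing compiler toolchain:
        # fall back to eager mode, matching the behavior in agent_bird.py.
        return model

model = maybe_compile(nn.Linear(8, 4), enabled=True)
print(type(model).__name__)  # OptimizedModule when compiled, Linear otherwise
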
@@ -597,8 +626,8 @@ def predict_rewards(self, test_input):
self.model.eval()
test_input = test_input.to(self.device)

# By default, we pass the entire sequence. If we want a single step, we might clamp it.
current_seq_len = test_input.shape[1]
seq_len_limit = self.model_max_seq_len
current_seq_len = min(test_input.shape[1], seq_len_limit)

# Optionally apply a 'conversion_model' if set (unused by default).
if self.conversion_model is not None:
@@ -608,19 +637,23 @@
test_input[chunk_idx:chunk_idx + chunk_size, :, -self.num_objectives:] = \
self.conversion_model(test_input[chunk_idx:chunk_idx + chunk_size, :, -self.num_objectives:])

# We pad the input if shorter than the default.
# In practice, the MLP might not need strict padding, but this is an example from the code.
# "gated_ssm_agent.default_max_seq_len" is used to ensure shape consistency.
test_input = torch.cat([
test_input,
torch.zeros(
(
test_input.shape[0],
gated_ssm_agent.default_max_seq_len - current_seq_len,
*test_input.shape[2:],
), device=test_input.device,
),
], dim=1)
# Keep a bounded context window and pad to the model's expected sequence length.
if test_input.shape[1] > seq_len_limit:
test_input = test_input[:, -seq_len_limit:]
else:
pad_length = seq_len_limit - test_input.shape[1]
if pad_length > 0:
test_input = torch.cat([
test_input,
torch.zeros(
(
test_input.shape[0],
pad_length,
*test_input.shape[2:],
),
device=test_input.device,
),
], dim=1)

# Run inference in chunks to avoid OOM if large
chunks = []
@@ -630,13 +663,15 @@
disable = (not self.accelerator.is_main_process)
else:
disable = False
disable = disable or self.disable_tqdm

with torch.no_grad():
for idx in tqdm(range(0, test_input.shape[0], chunk_size), desc="Predicting rewards... (done in chunks to save VRAM)", disable=disable):
chunk = test_input[idx:idx + chunk_size]
# The model returns shape [batch_size, seq_len, output_dim].
# We pick the last time step or a specific one (just an example).
chunk_output = self.model(chunk, current_seq_len=current_seq_len)[..., current_seq_len, :]
target_idx = max(0, current_seq_len - 1)
chunk_output = self.model(chunk, current_seq_len=current_seq_len)[..., target_idx, :]
chunks.append(chunk_output)
sampled_preds = torch.cat(chunks, dim=0)
self.model.train()
@@ -840,6 +875,13 @@ def append_to_history(self, action_taken, old_loss_vector, new_loss_vector, obse
else:
self.train_y = torch.cat([self.train_y, observed_reward], dim=-2).float()

# Keep a bounded context window (the reward model is limited to this anyway).
seq_len_limit = self.model_max_seq_len
if self.X is not None and self.X.shape[1] > seq_len_limit:
self.X = self.X[:, -seq_len_limit:]
if self.train_y is not None and self.train_y.shape[1] > seq_len_limit:
self.train_y = self.train_y[:, -seq_len_limit:]

return observed_reward

def update(
@@ -897,7 +939,7 @@ def update(
y = self.train_y.to(self.device).float()

# We limit the sequence length if it is bigger than the MLP can handle
seq_len_limit = gated_ssm_agent.default_max_seq_len
seq_len_limit = self.model_max_seq_len
x = x[:, -seq_len_limit:]
y = y[:, -seq_len_limit:]

@@ -936,7 +978,7 @@

num_iterations = int(num_iterations)

progress_bar = tqdm(range(num_iterations), desc="Training Agent Bird...", leave=False)
progress_bar = tqdm(range(num_iterations), desc="Training Agent Bird...", leave=False, disable=self.disable_tqdm)
loss_val = 999
# We'll do a small training loop
with torch.enable_grad():
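The new sequence-length handling in predict_rewards (keep the most recent model_max_seq_len steps, otherwise right-pad with zeros) boils down to the standalone sketch below; the fit_to_window helper and the shapes are illustrative only:

import torch

def fit_to_window(x: torch.Tensor, seq_len_limit: int) -> torch.Tensor:
    # Clamp a [batch, seq, feat] tensor to exactly seq_len_limit timesteps:
    # keep the most recent steps if too long, zero-pad on the right if too short.
    if x.shape[1] > seq_len_limit:
        return x[:, -seq_len_limit:]
    pad_length = seq_len_limit - x.shape[1]
    if pad_length == 0:
        return x
    pad = torch.zeros((x.shape[0], pad_length, *x.shape[2:]), device=x.device, dtype=x.dtype)
    return torch.cat([x, pad], dim=1)

x = torch.randn(2, 5, 3)
print(fit_to_window(x, 4).shape)  # torch.Size([2, 4, 3]) -- keeps the last 4 steps
print(fit_to_window(x, 8).shape)  # torch.Size([2, 8, 3]) -- right-padded with zeros
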
106 changes: 85 additions & 21 deletions birdie_rl/birdie_reward_model/gated_ssm_agent.py
@@ -27,8 +27,23 @@

# Rotary embeddings for Q,K
from birdie_rl.birdie_reward_model import rotary
from torch.nn.attention.flex_attention import create_block_mask
create_block_mask = torch.compile(create_block_mask)
try:
from torch.nn.attention.flex_attention import create_block_mask as _create_block_mask
from torch.nn.attention.flex_attention import flex_attention as _flex_attention
_HAS_FLEX_ATTENTION = True
except Exception:
_create_block_mask = None
_flex_attention = None
_HAS_FLEX_ATTENTION = False

if _HAS_FLEX_ATTENTION and hasattr(torch, "compile"):
try:
create_block_mask = torch.compile(_create_block_mask)
except Exception:
create_block_mask = _create_block_mask
else:
create_block_mask = _create_block_mask

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")
@@ -167,6 +182,8 @@ def __init__(
# For demonstration: the number of heads for Q can differ from K/V if GQA is used
self.num_heads = dims // head_dims
self.gqa_num_heads = 2 # e.g., grouping for keys/values
if self.num_heads < self.gqa_num_heads or (self.num_heads % self.gqa_num_heads) != 0:
self.gqa_num_heads = 1

# Dimensions for Q, K, V
q_dims = head_dims * self.num_heads
@@ -183,14 +200,16 @@
# RMS-split for input norm
self.norm = RMS_split(dims)

def forward(self, x, block_mask=None):
def forward(self, x, block_mask=None, attn_mask=None):
"""
Forward pass for MHA with optional block_mask.

Args:
x (Tensor): shape [batch, seq_len, dims]
block_mask (Tensor or None): shape e.g. [1,1,Q_LEN,KV_LEN],
for controlling which positions can attend.
attn_mask (Tensor or None): bool or float attention mask broadcastable to
[batch, heads, Q_LEN, KV_LEN] (used for SDPA fallback).
Returns:
Tensor: shape [batch, seq_len, dims]
"""
@@ -212,14 +231,34 @@
# Project V similarly, shape [b, gqa_num_heads, seq, dims]
v = einops.rearrange(self.v_proj(x), 'b s (h d) -> b h s d', h=self.gqa_num_heads)

# Use the custom "flex_attention" function from PyTorch's experimental module
mha_out = torch.nn.attention.flex_attention.flex_attention(
query=q,
key=k,
value=v,
block_mask=block_mask,
enable_gqa=True, # if we have gqa_num_heads < num_heads for Q
)
use_flex_attention = _HAS_FLEX_ATTENTION and (block_mask is not None) and (_flex_attention is not None)
if use_flex_attention:
# Use the custom "flex_attention" function from PyTorch's experimental module
mha_out = _flex_attention(
query=q,
key=k,
value=v,
block_mask=block_mask,
enable_gqa=(self.gqa_num_heads != self.num_heads),
)
else:
# Fallback to SDPA (CPU-safe, available in stable PyTorch).
if k.shape[1] != q.shape[1]:
if (q.shape[1] % k.shape[1]) != 0:
raise ValueError(
f"GQA head mismatch: q_heads={q.shape[1]} not divisible by kv_heads={k.shape[1]}"
)
repeat_factor = q.shape[1] // k.shape[1]
k = k.repeat_interleave(repeat_factor, dim=1)
v = v.repeat_interleave(repeat_factor, dim=1)
mha_out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
)

# Rearrange back
mha_out = einops.rearrange(mha_out, 'b h s d -> b s (h d)')
@@ -389,20 +428,45 @@ def mask_mod(b, h, q_idx, kv_idx):
causal_mask = (q_idx >= kv_idx)
return causal_mask & is_valid_loc

# Build the block_mask using create_block_mask if needed
block_mask = create_block_mask(
mask_mod,
B=x.shape[0], # or 1 if we want same mask for entire batch
H=1, # or num_heads
Q_LEN=self.sequence_length,
KV_LEN=self.sequence_length,
device=x.device,
)
block_mask = None
attn_mask = None

if current_seq_len is None:
current_seq_len = x.shape[1]
current_seq_len = int(current_seq_len)

if create_block_mask is not None:
# Build the block_mask using create_block_mask if available.
block_mask = create_block_mask(
mask_mod,
B=x.shape[0], # or 1 if we want same mask for entire batch
H=1, # or num_heads
Q_LEN=x.shape[1],
KV_LEN=x.shape[1],
device=x.device,
)
else:
# SDPA fallback mask:
# - causal for tokens < current_seq_len
# - ensure rows for padded queries still have at least one valid entry
# to avoid NaNs (we allow self-attend on padded positions).
seq_len = x.shape[1]
q_idx = torch.arange(seq_len, device=x.device)[:, None]
kv_idx = torch.arange(seq_len, device=x.device)[None, :]
causal = q_idx >= kv_idx

if current_seq_len >= seq_len:
attn_mask = causal
else:
valid_q = q_idx < current_seq_len
valid_kv = kv_idx < current_seq_len
diag = q_idx == kv_idx
attn_mask = (valid_q & valid_kv & causal) | (~valid_q & diag)

# Pass through each layer
for layer in self.layers:
if isinstance(layer, MHA):
x = layer(x, block_mask=block_mask)
x = layer(x, block_mask=block_mask, attn_mask=attn_mask)
else:
x = layer(x)

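For the SDPA fallback path, the GQA head expansion plus boolean-mask call reduces to the standalone sketch below (tensor sizes and names are illustrative, not taken from the module):

import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 32
q = torch.randn(batch, q_heads, seq, head_dim)
k = torch.randn(batch, kv_heads, seq, head_dim)
v = torch.randn(batch, kv_heads, seq, head_dim)

# Expand grouped K/V heads so every query head has a matching key/value head,
# as the SDPA branch in MHA.forward does when flex_attention is unavailable.
repeat_factor = q_heads // kv_heads
k = k.repeat_interleave(repeat_factor, dim=1)
v = v.repeat_interleave(repeat_factor, dim=1)

# Causal boolean mask (True = may attend), broadcast over batch and heads.
idx = torch.arange(seq)
attn_mask = idx[:, None] >= idx[None, :]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 8, 16, 32])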