
Product keys in FF#163

Open
wojsza05 wants to merge 8 commits into main from bgw/pk

Conversation


@wojsza05 (Contributor) commented Feb 4, 2026

No description provided.

seq_len,
rope_base,
rope_scale_freqs: bool,
causal=True,
Member

We do not add default values here; move it to defaults.yaml.

Contributor Author

I don't see a file with default values for class arguments. Should I add the argument to the 17 configuration files where the RoPEAttention class is used?

Contributor Author

However, I've removed this argument completely. I added it earlier because I thought it might be useful to someone in the future, but we don't use this class at the moment, so I won't bother.

@crewtool (Member) left a comment

As in the inline comment above.

)

# Weighted sum of retrieved values
out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)
Contributor

You can replace this with a single matrix multiplication. That way no intermediate array is created and memory usage is reduced.
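For instance, a minimal sketch of this replacement, with made-up shapes (the real tensors come from the module in the diff), checking the batched matmul against the broadcasted product:

```python
import torch

# Illustrative shapes only; names mirror the snippet above.
B, H, K, D = 2, 4, 8, 16
attn_weights = torch.randn(B, H, K).softmax(dim=-1)  # (B, H, K)
values_selected = torch.randn(B, H, K, D)            # (B, H, K, D)

# Original: the elementwise product materializes a (B, H, K, D) intermediate.
out_ref = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)

# Batched matmul: (B, H, 1, K) @ (B, H, K, D) -> (B, H, 1, D),
# reducing over K without the broadcasted product tensor.
out_heads = torch.matmul(attn_weights.unsqueeze(2), values_selected).squeeze(2)

assert torch.allclose(out_heads, out_ref, atol=1e-5)
```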

Comment on lines +404 to +408
out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)

# 6. Aggregation
# Sum outputs across all heads
output = out_heads.sum(dim=1) # (BS, d_model)
Contributor

I believe you can fuse these operations as well, reducing memory usage even further.

Something like this should work:
output = torch.einsum('b h k, b h k d -> b d', attn_weights, values_selected)
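A quick sanity check that the fused einsum matches the two-step version from the diff, using illustrative shapes (B, H, K, D are made up for the example):

```python
import torch

# Illustrative shapes; the real tensors come from the attention module.
B, H, K, D = 2, 4, 8, 16
attn_weights = torch.randn(B, H, K).softmax(dim=-1)  # (B, H, K)
values_selected = torch.randn(B, H, K, D)            # (B, H, K, D)

# Two-step version from the diff: weighted sum over K, then sum over heads.
out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)
output_ref = out_heads.sum(dim=1)  # (B, D)

# Fused einsum: contracts over h and k in one call, without
# materializing the (B, H, K, D) elementwise product.
output = torch.einsum('b h k, b h k d -> b d', attn_weights, values_selected)

assert torch.allclose(output, output_ref, atol=1e-5)
```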

Contributor

I think that torch.nn.functional.embedding_bag could be used here as well – that way the whole values matrix won't be instantiated.

# Flatten indices and weights to (Batch * Heads, K_neighbors)
# We treat (Batch * Seq * Heads) as the "bag" dimension
input_flat = memory_indices.view(-1, self.k)
weights_flat = attn_weights.view(-1, self.k)

# Fused Lookup + Weighted Sum
# Output shape: (BS * Seq * Heads, D)
# This avoids creating the (BS, H, K, D) tensor entirely
out_flat = F.embedding_bag(
    input_flat, 
    self.values.weight, 
    per_sample_weights=weights_flat, 
    mode='sum'
)

# Reshape and sum over heads
out_flat = out_flat.view(bs, seq_len, self.n_heads, d_model)
output = out_flat.sum(dim=2)  # (BS, Seq, D)
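A self-contained check, with toy sizes (all shapes here are illustrative), that the embedding_bag version matches an explicit gather-and-weighted-sum:

```python
import torch
import torch.nn.functional as F

# Toy sizes; names mirror the snippet above but values are made up.
bs, seq_len, n_heads, k, n_vals, d_model = 2, 3, 4, 5, 32, 8
values = torch.nn.Embedding(n_vals, d_model)
memory_indices = torch.randint(0, n_vals, (bs, seq_len, n_heads, k))
attn_weights = torch.randn(bs, seq_len, n_heads, k).softmax(dim=-1)

# Fused lookup + weighted sum; each row of the flattened indices is a "bag",
# so the (bs, seq, heads, k, d_model) gathered tensor is never built.
out_flat = F.embedding_bag(
    memory_indices.view(-1, k),
    values.weight,
    per_sample_weights=attn_weights.view(-1, k),
    mode='sum',
)
output = out_flat.view(bs, seq_len, n_heads, d_model).sum(dim=2)

# Reference: explicit gather, then weighted sum over k and sum over heads.
gathered = values(memory_indices)  # (bs, seq, heads, k, d_model)
ref = (gathered * attn_weights.unsqueeze(-1)).sum(dim=3).sum(dim=2)
assert torch.allclose(output, ref, atol=1e-5)
```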

@mtboro (Contributor) left a comment

Overall the PR looks good to me. Take a look at the suggestions that would improve memory usage, and resolve the comments left by @crewtool.


# Calculate similarity between full Q and the reconstructed candidates
# q needs unsqueeze to broadcast: (B, H, S, 1, D) @ (B, H, S, K*K, D).T
# TODO


Is this TODO done? If so, remove this comment line.

Contributor

I think I've addressed this on this branch: https://github.com/llm-random/nano/tree/bgw/pk_attn_update

The PR will follow after the current one is merged.

Contributor Author

This is a suggestion from a previous PR that something could be done more optimally. It's not done yet.

Comment on lines +333 to +334
self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2))
self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2))
Contributor

I think that initializing with std=d_model**-0.5 might slightly improve convergence.
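For reference, a minimal sketch of the suggested scaled initialization (sizes are made up for illustration; only the std = d_model**-0.5 scaling follows the comment):

```python
import torch
import torch.nn as nn

# Illustrative sizes; the real ones come from the module config.
n_heads, n_sub_keys, query_dim, d_model = 4, 16, 32, 64
std = d_model ** -0.5

# Scale the unit-variance randn draw down to the suggested std.
c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2) * std)
# Equivalent in-place variant using Tensor.normal_.
c2 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2).normal_(std=std))
```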

4 participants