Conversation
src/core/model.py
Outdated
seq_len,
rope_base,
rope_scale_freqs: bool,
causal=True,
We don't add default values here; move it to defaults.yaml.
I don't see a file with default values for class arguments. Should I add an argument to the 17 configuration files where the RoPEAttention class is used?
However, I've removed this argument entirely. I added it earlier because I thought it might be useful to someone in the future, but we don't use this class at the moment, so I won't bother.
src/product_keys/model.py
Outdated
)

# Weighted sum of retrieved values
out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)
You can replace this with a single matrix multiplication. That way, no intermediate array is created and memory usage is reduced.
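A minimal self-contained sketch of the suggestion, with hypothetical shapes assumed from the surrounding diff (`attn_weights` of shape `(BS, H, K)`, `values_selected` of shape `(BS, H, K, D)`):

```python
import torch

# Hypothetical sizes, for illustration only
bs, h, k, d = 2, 4, 8, 16
attn_weights = torch.randn(bs, h, k)
values_selected = torch.randn(bs, h, k, d)

# Original: elementwise multiply + sum creates a (BS, H, K, D) intermediate
out_ref = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)

# Fused: batched matmul (BS, H, 1, K) @ (BS, H, K, D) -> (BS, H, 1, D)
out_mm = (attn_weights.unsqueeze(2) @ values_selected).squeeze(2)

assert torch.allclose(out_ref, out_mm, atol=1e-5)
```

Both paths produce a `(BS, H, D)` result, but the matmul runs as a single kernel without materializing the broadcasted product.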
src/product_keys/model.py
Outdated
out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)

# 6. Aggregation
# Sum outputs across all heads
output = out_heads.sum(dim=1)  # (BS, d_model)
I believe you can fuse these operations as well, reducing memory usage even further.
Something like this should work:
output = torch.einsum('b h k, b h k d -> b d', attn_weights, values_selected)
I think that torch.nn.functional.embedding_bag could be used here as well; that way the whole values matrix won't be instantiated.
import torch.nn.functional as F

# Flatten indices and weights to (BS * Seq * Heads, K_neighbors);
# (BS * Seq * Heads) becomes the "bag" dimension
input_flat = memory_indices.view(-1, self.k)
weights_flat = attn_weights.view(-1, self.k)
# Fused lookup + weighted sum
# Output shape: (BS * Seq * Heads, D)
# This avoids creating the (BS, Seq, H, K, D) tensor entirely
out_flat = F.embedding_bag(
    input_flat,
    self.values.weight,
    per_sample_weights=weights_flat,
    mode='sum'
)
# Reshape and sum over heads
out_heads = out_flat.view(bs, seq_len, self.n_heads, d_model)
output = out_heads.sum(dim=2)  # (BS, Seq, D)
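A self-contained sanity check of the embedding_bag idea, with hypothetical sizes and a hypothetical `values` embedding table standing in for the module's attributes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, for illustration only
bs, seq_len, n_heads, k, n_values, d_model = 2, 3, 4, 8, 100, 16
values = nn.Embedding(n_values, d_model)
memory_indices = torch.randint(0, n_values, (bs, seq_len, n_heads, k))
attn_weights = torch.softmax(torch.randn(bs, seq_len, n_heads, k), dim=-1)

# Naive path: materializes a (BS, Seq, H, K, D) tensor
values_selected = values(memory_indices)
out_ref = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=3).sum(dim=2)

# embedding_bag path: fused lookup + weighted sum, no (..., K, D) intermediate
out_flat = F.embedding_bag(
    memory_indices.view(-1, k),
    values.weight,
    per_sample_weights=attn_weights.view(-1, k),
    mode='sum',
)
out_bag = out_flat.view(bs, seq_len, n_heads, d_model).sum(dim=2)

assert torch.allclose(out_ref, out_bag, atol=1e-5)
```

Note that `per_sample_weights` is only supported with `mode='sum'`, which is exactly the mode needed here.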
# Calculate similarity between full Q and the reconstructed candidates
# q needs unsqueeze to broadcast: (B, H, S, 1, D) @ (B, H, S, K*K, D).T
# TODO
Is this TODO done? If yes, remove this comment line.
I think I've addressed this on this branch: https://github.com/llm-random/nano/tree/bgw/pk_attn_update
The PR will follow after the current one is merged.
This is a suggestion from a previous PR that something could be done more optimally. It's not done yet.
src/product_keys/model.py
Outdated
self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2))
self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2))
I think that initializing with std=d_model**-0.5 might slightly improve convergence.
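A minimal sketch of the suggested change, with hypothetical sizes; `torch.randn` draws from N(0, 1), so the scaled init just replaces that with a normal of std d_model**-0.5 (the usual transformer initialization scale):

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only
n_heads, n_sub_keys, query_dim, d_model = 4, 32, 64, 256

# Current init: standard normal, std = 1.0
c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2))

# Suggested: rescale to std = d_model**-0.5
with torch.no_grad():
    c1.normal_(mean=0.0, std=d_model ** -0.5)
```

The same would apply to `self.c2`; alternatively, `nn.init.normal_(self.c1, std=d_model ** -0.5)` after construction has the same effect.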