# Multi-Head Latent Attention (MLA) in PyTorch

Companion code for the blog post Understanding Multi-Head Latent Attention, part of a series on DeepSeek's FlashMLA.

## Attention Implementations

Clean, minimal PyTorch implementations of attention mechanisms discussed in the blog:

| Module | Class | Description |
| --- | --- | --- |
| `mla/sdpa.py` | `ScaledDotProductAttention` | Scaled dot-product attention (Vaswani et al., 2017) |
| `mla/mha.py` | `MultiHeadAttention` | Standard multi-head attention with optional KV cache |
| `mla/gqa.py` | `GroupedQueryAttention` | Grouped-query attention (Ainslie et al., 2023) |
| `mla/mqa.py` | `MultiQueryAttention` | Multi-query attention (Shazeer, 2019) |
| `mla/mla.py` | `MultiHeadLatentAttention` | MLA with low-rank KV compression (DeepSeek-AI, 2024) |
| `mla/mla.py` | `MultiHeadLatentAttentionAbsorbed` | MLA with the weight absorption trick |
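
The core idea behind the MLA variants is that keys and values are reconstructed from a small shared latent, so only that latent has to be cached during decoding. Below is a minimal sketch of that compression step, not the actual code in `mla/mla.py`; it omits query compression, the decoupled RoPE path, and the absorption trick, and the projection names (`w_dkv`, `w_uk`, `w_uv`) are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentKVAttention(nn.Module):
    """Illustrative MLA-style attention: K and V are reconstructed from a
    small shared latent c_kv, so only c_kv (seq_len x d_c) would need to be
    cached during decoding."""

    def __init__(self, d_model=512, n_heads=8, d_c=64):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model, bias=False)   # queries (uncompressed here)
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)      # down-projection to the latent
        self.w_uk = nn.Linear(d_c, d_model, bias=False)       # up-projection: latent -> keys
        self.w_uv = nn.Linear(d_c, d_model, bias=False)       # up-projection: latent -> values
        self.w_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        b, t, _ = x.shape
        c_kv = self.w_dkv(x)  # (b, t, d_c) -- this is what gets cached in MLA
        q = self.w_q(x).view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        k = self.w_uk(c_kv).view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_uv(c_kv).view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.w_o(out.transpose(1, 2).reshape(b, t, -1))

x = torch.randn(2, 16, 512)
print(LatentKVAttention()(x).shape)  # torch.Size([2, 16, 512])
```

Since only `c_kv` (width `d_c`) would be cached instead of the full per-head keys and values (width `2 * d_model` in this sketch), the per-token cache shrinks by roughly a factor of `2 * d_model / d_c`.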

## Experiments

Scripts that reproduce the figures from the blog post:

| Script | Figure | Description |
| --- | --- | --- |
| `experiments/plot_attention_memory_scaling.py` | `attention_memory_scaling.png` | O(N^2) attention matrix memory scaling |
| `experiments/plot_attention_compute_scaling.py` | `attention_compute_scaling.png` | O(N^2) attention FLOPs and wall-clock time scaling |
| `experiments/plot_kv_cache_benchmark.py` | `kv_cache_benchmark.png` | Decoding speedup with KV caching |
| `experiments/plot_attention_variants_kv_cache.py` | `attention_variants_kv_cache.png` | KV cache size: MHA vs GQA vs MQA |
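
As a back-of-the-envelope version of what the KV cache comparison measures, the per-token cache size of each variant follows directly from the layer dimensions. The settings below are illustrative (`d_model=512`, 8 heads, 2 KV groups for GQA, `d_c=64` for MLA), not necessarily those used by the benchmark scripts.

```python
# Per-token, per-layer KV cache size (in elements, not bytes) for each variant.
# Dimensions are illustrative, not the ones used by the benchmark scripts.
d_model, n_heads, n_kv_groups, d_c = 512, 8, 2, 64
d_head = d_model // n_heads

mha = 2 * n_heads * d_head      # full K and V for every head
gqa = 2 * n_kv_groups * d_head  # K and V shared within each KV group
mqa = 2 * 1 * d_head            # a single shared K/V head
mla = d_c                       # only the compressed latent (decoupled RoPE key omitted)

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa), ("MLA", mla)]:
    print(f"{name}: {size} elements per token per layer")
```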

## Setup

```bash
pip install -r requirements.txt
```

## Usage

```python
import torch

from mla import MultiHeadAttention, MultiHeadLatentAttention

x = torch.randn(2, 128, 512)  # (batch, seq_len, d_model)

# Standard MHA
mha = MultiHeadAttention(d_model=512, n_heads=8)
output = mha(x)  # x: (batch, seq_len, 512)

# MLA with latent compression (KV + query compression)
mla = MultiHeadLatentAttention(d_model=512, n_heads=8, d_c=64, d_cq=96)
output = mla(x, use_cache=True)  # caches the compressed latent
```
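
For autoregressive decoding, a token-by-token loop might look like the sketch below. It assumes the module accumulates its compressed-latent cache internally across calls when `use_cache=True`; that is an assumption about the interface, not a documented guarantee of `mla/mla.py`.

```python
import torch

from mla import MultiHeadLatentAttention

mla = MultiHeadLatentAttention(d_model=512, n_heads=8, d_c=64, d_cq=96)

prompt = torch.randn(1, 10, 512)        # prompt embeddings: (batch, seq_len, d_model)
out = mla(prompt, use_cache=True)       # prefill; assumed to populate the latent cache
for _ in range(5):
    next_tok = torch.randn(1, 1, 512)   # stand-in for the next token's embedding
    out = mla(next_tok, use_cache=True) # assumed to attend over cached latents + new token
```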

Run experiments:

```bash
python experiments/plot_attention_memory_scaling.py
python experiments/plot_attention_compute_scaling.py
python experiments/plot_kv_cache_benchmark.py
python experiments/plot_attention_variants_kv_cache.py
```

Figures are saved to `figures/`.
