
KV Caching From Scratch in PyTorch

Last week, I implemented KV caching from scratch for a GPT-2 model. I ran my experiments on a Colab T4 GPU to better understand how caching speeds up inference in large language models.

𝗧𝗵𝗲 𝗣𝗿𝗼𝗯𝗹𝗲𝗺:

In autoregressive generation, LLMs produce one token at a time, and each new token has to attend to all previous tokens. So if your model has generated "White → Fluffy → Cat", the attention block still recomputes the Keys and Values for "White" and "Fluffy" at every step.

That’s a lot of unnecessary computation, especially as the output grows longer.
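To make the waste concrete, here is a minimal PyTorch sketch of cache-free decoding (toy dimensions and random embeddings for illustration, not the repo's actual code): every step redoes the K and V projections for the entire prefix.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 64  # toy size; GPT-2 small uses 768 split across 12 heads

W_q = torch.nn.Linear(d_model, d_model, bias=False)
W_k = torch.nn.Linear(d_model, d_model, bias=False)
W_v = torch.nn.Linear(d_model, d_model, bias=False)

def attend_without_cache(x):
    # x: (batch, T, d_model) -- the FULL sequence generated so far.
    # K and V are recomputed for every token on every call, even though
    # the projections of the old tokens never change between steps.
    q, k, v = W_q(x), W_k(x), W_v(x)
    scores = q @ k.transpose(-2, -1) / d_model ** 0.5
    causal = torch.triu(torch.ones(x.size(1), x.size(1), dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

# Decoding "White -> Fluffy -> Cat": every step redoes O(T) K/V projections
# and O(T^2) attention over the whole prefix.
seq = torch.randn(1, 1, d_model)            # embedding of "White"
for _ in range(2):                          # generate "Fluffy", then "Cat"
    _ = attend_without_cache(seq)
    next_tok = torch.randn(1, 1, d_model)   # stand-in for the sampled token's embedding
    seq = torch.cat([seq, next_tok], dim=1)
```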

𝗧𝗵𝗲 𝗞𝗩 𝗖𝗮𝗰𝗵𝗲 𝗦𝗼𝗹𝘂𝘁𝗶𝗼𝗻:

I implemented a caching mechanism where:
• The model caches the Keys & Values for the input tokens during prefill.
• For each new token, it only computes the K/V for that token.
• Previous tokens just pull from the cache, no recompute needed.

[Figure: attention with KV caching]
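Below is a minimal single-head sketch of that idea. The names and sizes (`CachedSelfAttention`, `k_cache`, `v_cache`, toy dimensions) are illustrative assumptions, not the repo's exact module: prefill populates the cache, and each decode step projects and appends only the newest token's K/V.

```python
import torch
import torch.nn.functional as F

class CachedSelfAttention(torch.nn.Module):
    """Single-head self-attention with a K/V cache (illustrative sketch)."""
    def __init__(self, d_model=64):
        super().__init__()
        self.d = d_model
        self.W_q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_k = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_v = torch.nn.Linear(d_model, d_model, bias=False)
        self.k_cache = None   # (B, T_cached, d) once populated
        self.v_cache = None

    def forward(self, x):
        # Prefill: x holds the whole prompt. Decode: x holds only the newest token.
        k_new, v_new = self.W_k(x), self.W_v(x)
        if self.k_cache is None:
            self.k_cache, self.v_cache = k_new, v_new
        else:
            # Only the new token's K/V are computed; old entries come from the cache.
            self.k_cache = torch.cat([self.k_cache, k_new], dim=1)
            self.v_cache = torch.cat([self.v_cache, v_new], dim=1)

        q = self.W_q(x)                                    # queries only for the new tokens
        scores = q @ self.k_cache.transpose(-2, -1) / self.d ** 0.5
        T_new, T_total = x.size(1), self.k_cache.size(1)
        if T_new > 1:
            # Causal mask is only needed during prefill; a single decoded token
            # may attend to every cached position.
            causal = torch.triu(torch.ones(T_new, T_total, dtype=torch.bool),
                                diagonal=T_total - T_new + 1)
            scores = scores.masked_fill(causal, float("-inf"))
        return F.softmax(scores, dim=-1) @ self.v_cache

attn = CachedSelfAttention()
prompt = torch.randn(1, 5, 64)        # prefill: cache K/V for all prompt tokens
_ = attn(prompt)
for _ in range(3):                    # decode: one token per step, O(T) work per step
    _ = attn(torch.randn(1, 1, 64))   # stand-in for each new token's embedding
print(attn.k_cache.shape)             # torch.Size([1, 8, 64]) -> 5 prompt + 3 decoded
```

The attention math itself is unchanged; only the redundant K/V projections are skipped, which is why the cache is a pure speed optimization with identical outputs.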

📊 𝗥𝗲𝘀𝘂𝗹𝘁𝘀:

Tested GPT-2 on a Colab T4 GPU with different output lengths:

[Figure: KV cache benchmark results across output lengths]
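For reference, this is roughly how such timings can be reproduced. The sketch uses the Hugging Face `transformers` GPT-2 and its `generate(use_cache=...)` flag as a stand-in for the from-scratch model; the prompt and output lengths are assumptions.

```python
import time
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()
inputs = tok("The white fluffy cat", return_tensors="pt").to(device)

def timed_generate(use_cache, max_new_tokens):
    if device == "cuda":
        torch.cuda.synchronize()          # make sure previously queued GPU work is done
    start = time.perf_counter()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=max_new_tokens,
                       do_sample=False, use_cache=use_cache,
                       pad_token_id=tok.eos_token_id)
    if device == "cuda":
        torch.cuda.synchronize()          # wait for generation to actually finish
    return time.perf_counter() - start

for n in (64, 256, 512):
    t_plain = timed_generate(use_cache=False, max_new_tokens=n)
    t_cached = timed_generate(use_cache=True, max_new_tokens=n)
    print(f"{n:>4} new tokens: no cache {t_plain:.2f}s | "
          f"KV cache {t_cached:.2f}s ({t_plain / t_cached:.1f}x)")
```

The `torch.cuda.synchronize()` calls matter on the T4: CUDA kernel launches are asynchronous, so timing without them mostly measures launch overhead rather than actual generation time.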

𝗢𝗻𝗲 𝗙𝘂𝗻 𝗜𝗻𝘀𝗶𝗴𝗵𝘁:

For shorter outputs, KV caching doesn't always help. In my tests, device communication overhead on CUDA sometimes outweighed the gains for small models like GPT-2.

Shoutout to Sebastian Raschka, PhD, for his amazing blog post on KV caching (reference 2 below).

References

  1. Attention in Transformers, Step-by-Step (Deep Learning, Chapter 6) – 3Blue1Brown
  2. Understanding and Coding the KV Cache in LLMs from Scratch – Sebastian Raschka
  3. Mastering Tensor Dimensions in Transformers
  4. The Illustrated Transformer – Jay Alammar
  5. Transformers KV Caching Explained – João Lages
  6. LLM Inference Series: 3. KV Caching Explained – Pierre Lienhart
  7. tanishqkumar/beyond-nanogpt: Minimal, annotated implementations

About

This project explores KV Caching in LLMs by implementing it from scratch in GPT-2 and benchmarking its impact on inference speed. It highlights how caching Key/Value pairs during decoding significantly reduces redundant computation and accelerates generation, achieving up to 7× speedup on long outputs.
