Skip to content

Latest commit

 

History

History
1364 lines (1075 loc) · 44.3 KB

File metadata and controls

1364 lines (1075 loc) · 44.3 KB

LLM Response Time Optimizer - Execution Plan (2-3 Weeks)

🎯 CURRENT PROGRESS (Updated: 2026-02-18)


Phase 5: Mistral-7B Integration - IN PROGRESS 🔄

Status: Core generation working. lax.scan decode optimization in progress.

Phase 5 Results So Far (Mistral-7B-Instruct-v0.2, A100-40GB):

Metric Value
Average JAX latency 10.88s / 50 tokens
Average PyTorch latency 24.68s / 50 tokens
Average speedup 2.27x
Average tokens/sec 4.60 tok/s

Why these numbers differ from the GPT-2 16.32x result — IMPORTANT:

The GPT-2 16.32x speedup compared our uncached JAX vs our cached JAX — i.e., quadratic recomputation (O(n²)) vs linear KV-cache (O(n)). Eliminating ~99% of redundant computation naturally gives a large multiplier.

The Mistral 2.27x compares our cached JAX vs PyTorch generate(), which ALREADY uses KV cache internally. Both sides are O(n). We are competing on JIT compilation efficiency and memory layout, not on eliminating quadratic work. This is a fundamentally harder comparison.

Current bottleneck: Python-level decode loop. Each of 50 decode steps dispatches ~10 JAX ops per layer × 32 layers = ~16,000 Python→JAX round trips. Switching to jax.lax.scan traces the entire loop into one XLA program (1 dispatch), eliminating this overhead and enabling cross-step XLA optimization.


Phase 4: KV-Cache + JIT Optimization - COMPLETED ✅

Major Achievements:

  • Implemented KV-Cache utilities (src/kv_cache.py)

    • initialize_cache(): Create empty cache structure
    • update_cache(): Store new K,V at position
    • get_cached_kv(): Retrieve cached K,V values
  • Built complete manual transformer (src/cached_generation.py)

    • Multi-head attention with cache support
    • Full transformer layer (attention + MLP + layer norm)
    • Embedding layer and LM head
    • Autoregressive generation loop
  • Fixed critical cache accumulation bug

    • Bug: Non-cached mode only attended to current token
    • Fix: Always accumulate K,V regardless of use_cache flag
    • Result: Both modes now produce identical outputs
  • Implemented true non-cached mode with batch processing

    • batch_attention(): Process all tokens in parallel
    • Dual-mode transformer_layer(): Choose cached vs batch
    • Two prefill strategies: Token-by-token vs batch
    • Updated get_embeddings() for both single/batch positions
  • Fixed critical attention output projection bug

    • Bug: Missing transpose for attention c_proj weight
    • Impact: JAX model generated gibberish ("the the the...")
    • Investigation: Layer-by-layer comparison revealed 130+ divergence in first layer
    • Root cause: c_proj weight used without transpose (line 387)
    • Fix: Added c_proj_weight = c_proj_weight.T
    • Result: JAX now produces identical text to PyTorch!
  • Applied JIT compilation to core functions

    • Functions JIT-compiled: split_heads, merge_heads, compute_qkv, causal_mask, batch_attention, layer_norm, mlp, lm_head
    • Used functools.partial with static_argnums for compile-time constants
    • Static arguments: num_heads, seq_len, model_type
    • Technique: @partial(jax.jit, static_argnums=(n,)) decorator pattern
    • Result: KV-Cache + JIT combined for 16.32x speedup (cannot measure separately with decorator approach)

Final Phase 4 Results (GPT-2):

Metric Non-Cached Cached + JIT Improvement
Speed 1.50 tok/s 24.45 tok/s 16.32x
Time (15 tokens) 10.03s 0.63s 16.32x faster
Memory (INT8) 163MB 163MB + cache 2.00x reduction

Status:

  • ✅ Phase 4 COMPLETE for GPT-2!
  • ✅ Total speedup: 16.32x on GPT-2 (measured, far exceeds 2-3x target)
  • ✅ KV-Cache + JIT combined optimization (cannot measure separately with decorator approach)
  • ✅ Applied JIT decorators to 8 core functions: split_heads, merge_heads, compute_qkv, causal_mask, batch_attention, layer_norm, mlp, lm_head
  • ✅ Text quality: 100% match with PyTorch GPT-2
  • ✅ All generation tests passing on GPT-2
  • ✅ Fixed critical warmup issue: JAX JIT compiles separately for each sequence length
  • ⚠️ NOTE: All Phase 4 work done on GPT-2. Mistral-7B not yet implemented.
  • ⏭️ NEXT: Phase 5 - Implement Mistral-7B support & Benchmarking

Important Lesson Learned: JAX JIT compilation is shape-dependent. During autoregressive generation, each token produces a different sequence length, triggering separate compilations. Warmup must generate at least as many tokens as the actual benchmark to ensure all shapes are pre-compiled for accurate performance measurement.

Completed Phases:

Phase 1: Foundation & Environment Setup

  • Environment configured with JAX, Flax, Transformers
  • Project structure created
  • JAX basics understood

Phase 2: Model Conversion (PyTorch → JAX)

  • Manual conversion implementation complete
  • Weight transposition and PyTree structure working
  • Tested and validated on GPT-2
  • Note: Basic conversion tested on Mistral-7B, but full implementation incomplete

Phase 3: INT8 Quantization

  • Simple symmetric quantization implemented
  • Memory reduction: 2.00x (GPT-2: 326MB → 163MB)
  • Quantization working and tested on GPT-2
  • Note: Mistral quantization code exists but not fully tested

Upcoming Work:

Immediate (Current Session):

  1. Investigate text quality issues (repetitive generation)
  2. Compare with PyTorch baseline output
  3. Debug and fix if needed

Phase 4 Remaining:

  • Apply JIT compilation for further optimization
  • Benchmark JIT performance improvements
  • Achieve additional speedup beyond current 11.80x

Phase 5: Benchmarking & Analysis

  • Comprehensive benchmarks with Mistral-7B
  • Quality evaluation (ROUGE scores)
  • Performance visualizations
  • Final documentation and demo

Learning Philosophy

This is YOUR learning journey - ACCELERATED VERSION. This plan compresses the original 7-8 week timeline into 2-3 weeks through:

  • Parallel task execution where possible
  • Leveraging existing implementations (Hugging Face) as starting points
  • Focus on core optimizations over comprehensive testing
  • Claude provides more scaffolding to maintain pace

You will write the code. Claude Code will:

  • Provide more starter code/templates to save time
  • Help you implement critical sections faster
  • Review and debug more proactively
  • Still ensure you understand the concepts (but faster)

Prerequisites

Knowledge Requirements

  • Python: Intermediate level (OOP, decorators, type hints)
  • Machine Learning Basics: Understanding of transformers, attention mechanism, autoregressive generation
  • PyTorch Fundamentals: Model loading, inference, basic operations
  • Git: Basic version control (commit, branch, merge)

What You'll Learn

  • JAX/Flax functional programming paradigm
  • Model conversion between frameworks
  • Quantization techniques (INT8, calibration)
  • Performance optimization (JIT compilation, caching)
  • ML benchmarking and evaluation

Hardware Requirements

Local Development Machine:

  • GPU: NVIDIA GPU with 8GB+ VRAM (RTX 4070 or equivalent)
    • Use 4-bit/8-bit quantization to fit models in 8GB
  • RAM: 16GB+ recommended
  • Storage: 50GB+ free space (for model weights, datasets)

Cloud Resources (for final benchmarking):

  • Google Colab Pro (recommended): T4/V100 GPU with 16GB VRAM
    • Free tier works but may timeout on long benchmarks
  • University HPC (alternative): Submit as batch job for extended runs
  • Purpose: Run full Mistral-7B FP16 and comprehensive 1000-sample benchmarks

Development Strategy

Two-Tier Approach:

Tier 1: Local Development (Days 1-16)

  • Conversion (Days 3-5): Use GPT-2 (124M params, fits in 16GB RAM)
    • Learn transposition logic, PyTree structure
    • Fast iteration without VRAM/cloud dependencies
  • Optimization (Days 8-16): Use 4-bit quantized Mistral (if needed)
    • Quick iteration with 10-100 samples
    • Code, debug, and validate optimizations
    • Fast feedback loop (minutes, not hours)

Tier 2: Cloud Integration (Days 6, 17-19)

  • Day 6 (Conversion): Run Mistral conversion on Colab
    • Apply GPT-2 conversion logic to full Mistral-7B
    • Save converted params, download for local use
  • Days 17-19 (Benchmarking): Final evaluation on Colab
    • Full Mistral-7B FP16 on GPU/TPU
    • Complete 1000-sample Alpaca evaluation
    • Publication-quality results
    • Run overnight, analyze next day

Why this works:

  • ✅ Learn deeply with small model (GPT-2) locally
  • ✅ Same code scales to large model (Mistral) on cloud
  • ✅ Fast local development (no cloud costs during dev)
  • ✅ Professional-grade final results (full model, large dataset)
  • ✅ Best of both worlds (learning depth + production quality)

Project Structure (To Be Created)

LLM_Response_Time_Optimizer/
├── src/
│   ├── __init__.py
│   ├── model_conversion.py       # Phase 2: PyTorch → JAX conversion
│   ├── quantization.py            # Phase 3: INT8 quantization
│   ├── generation.py              # Phase 4: Optimized generation with KV-cache
│   ├── benchmarking.py            # Phase 5: Performance evaluation
│   └── utils.py                   # Helper functions (logging, config, etc.)
├── notebooks/
│   ├── 01_baseline_pytorch.ipynb  # Baseline PyTorch inference
│   ├── 02_jax_conversion.ipynb    # JAX conversion exploration
│   ├── 03_quantization.ipynb      # Quantization experiments
│   ├── 04_optimization.ipynb      # Optimization testing
│   └── 05_demo.ipynb              # Final demonstration
├── benchmarks/
│   ├── results/
│   │   ├── pytorch_baseline.json
│   │   ├── jax_optimized.json
│   │   └── plots/                 # Performance visualization plots
│   └── config.yaml                # Benchmark configuration
├── tests/
│   ├── test_conversion.py         # Numerical validation tests
│   ├── test_quantization.py
│   └── test_generation.py
├── requirements.txt               # Python dependencies
├── setup.py                       # Package setup (optional)
├── .gitignore
├── README.md
└── Execution_Plan.md              # This file

Phase 1: Foundation & Environment Setup (COMPRESSED)

Duration: 2 DAYS (Days 1-2) Goal: Quickly set up environment and grasp JAX/Flax essentials Strategy: Learn-by-doing with minimal theory

Learning Objectives (YOU study these)

  1. JAX Programming Model

    • Pure functions and functional programming in JAX
    • Array tracing and JIT compilation
    • Differences from PyTorch (eager vs. traced execution)
  2. Flax Neural Networks

    • flax.linen module system
    • Parameter management (PyTree structure)
    • Difference between model definition and parameters
  3. Resources to Study

Implementation Tasks (YOU implement)

Task 1.1: Environment Setup

# YOU will run these commands
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Create requirements.txt with specific versions
# Install dependencies
pip install -r requirements.txt

Your Action: Create requirements.txt with:

# Core frameworks
jax[cuda12]==0.4.23
flax==0.8.0
transformers==4.36.0
torch==2.1.0

# Optimization & evaluation
datasets==2.16.0
rouge-score==0.1.2
evaluate==0.4.1

# Utilities
numpy==1.24.0
matplotlib==3.8.0
seaborn==0.13.0
jupyter==1.0.0
pytest==7.4.0
pyyaml==6.0.1

# Development
black==23.12.0
isort==5.13.0

Claude's Role: Review your requirements.txt, suggest version compatibility fixes if needed.

Task 1.2: Create Project Structure

Your Action: Create all directories and __init__.py files as shown in project structure above.

Validation Checkpoint:

  • All directories created
  • Virtual environment activated
  • All dependencies installed without errors
  • Can import: import jax, import flax, import transformers

Task 1.3: JAX Crash Course (4 hours)

Your Action: Create notebooks/00_jax_basics.ipynb - MINIMAL version:

  1. Basic operations: arrays, random numbers (30 min)
  2. JIT compilation example (30 min)
  3. PyTree structure basics (30 min)
  4. Read existing Flax model code (2.5 hours)

SKIP: Gradient computation, detailed comparisons Focus: What you need for model conversion

Claude's Role: Provide working JAX/Flax code examples, answer quick questions.


Phase 2: Model Conversion (PyTorch → JAX) (COMPRESSED)

Duration: 5 DAYS (Days 3-7) Goal: Convert Mistral-7B using existing Flax implementation as reference Strategy: Leverage Hugging Face's FlaxMistral, adapt instead of building from scratch

Learning Objectives

  1. Transformer Architecture in Flax

    • Attention mechanism implementation
    • Layer normalization, feed-forward networks
    • Rotary position embeddings (RoPE)
  2. Parameter Conversion

    • PyTorch state dict structure
    • Flax PyTree parameter structure
    • Weight name mapping and reshaping
  3. Resources

Implementation Tasks

Task 2.1: Baseline PyTorch Inference

Your Action: Create notebooks/01_baseline_pytorch.ipynb

  1. Load Mistral-7B using transformers library
  2. Run inference on sample prompts
  3. Measure: latency, memory usage, tokens/sec
  4. Save outputs for comparison
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

# YOUR CODE:
# 1. Load model and tokenizer
# 2. Create benchmark function
# 3. Time inference on 10 sample prompts
# 4. Log memory usage (torch.cuda.max_memory_allocated())
# 5. Save outputs to JSON

Validation Checkpoint:

  • Model loads successfully
  • Can generate coherent text
  • Baseline metrics recorded: ~2.5s latency, ~28GB memory
  • Outputs saved for later comparison

Claude's Role: Help debug model loading issues, suggest memory profiling tools.

Task 2.2: Manual Conversion Implementation (LEARNING PATH)

Your Action: Manually convert PyTorch weights to JAX/Flax

Strategy: Two-tier development approach for learning + efficiency

  1. Local (Days 3-5): Develop conversion logic using GPT-2 (124M params, fits in 16GB RAM)
  2. Colab (Day 6): Apply same logic to Mistral-7B (requires 16GB VRAM)

Why This Approach:

  • ✅ Learn the actual conversion process (weight transposition, PyTree structure)
  • ✅ Fast iteration locally without VRAM limitations
  • ✅ Same code works for both models (architecture-agnostic)
  • ✅ Deep understanding of PyTorch ↔ JAX differences

Phase A: Develop with GPT-2 Locally (Days 3-5)

Create src/model_conversion.py with these functions:

Function 1: load_pytorch_model() (Day 3)

def load_pytorch_model(model_name: str, use_small_model: bool = False):
    """
    Load PyTorch model and extract state_dict.

    Args:
        model_name: HuggingFace model identifier
        use_small_model: If True, use GPT-2 for local testing

    Returns:
        state_dict: Dictionary of PyTorch tensors
        tokenizer: Loaded tokenizer
    """
    # Load GPT-2 for local development, Mistral for Colab
    # Extract state_dict (flat dictionary)
    # YOU IMPLEMENT
    pass

Function 2: convert_pytorch_to_jax() (Days 3-4)

def convert_pytorch_to_jax(pytorch_state_dict: Dict[str, torch.Tensor]) -> Dict[str, jnp.ndarray]:
    """
    Convert PyTorch state_dict to JAX arrays.

    Operations:
    1. torch.Tensor → numpy → jax.Array
    2. Transpose linear layer weights: [out, in] → [in, out]
    3. Rename: .weight → .kernel, embed_tokens.weight → .embedding

    YOU IMPLEMENT THIS - CORE LEARNING TASK
    """
    pass

Function 3: build_flax_pytree() (Day 4-5)

def build_flax_pytree(jax_state_dict: Dict[str, jnp.ndarray]) -> Dict[str, Any]:
    """
    Convert flat JAX state_dict to nested Flax PyTree structure.

    Example:
    Flat: {'model.layers.0.self_attn.q_proj.kernel': array(...)}
    Nested: {'model': {'layers': {'0': {'self_attn': {'q_proj': {'kernel': array(...)}}}}}}

    YOU IMPLEMENT THIS
    """
    pass

Function 4: load_flax_model_with_params() (Day 5)

def load_flax_model_with_params(params: Dict[str, Any], model_name: str):
    """
    Initialize FlaxMistralForCausalLM with converted parameters.

    YOU IMPLEMENT THIS
    """
    pass

Testing Strategy:

# Test locally with GPT-2
python tests/test_conversion.py  # Uses GPT-2 by default

# Later on Colab with Mistral
python tests/test_conversion.py --use-mistral  # Full model

Phase B: Run on Mistral in Colab (Day 6)

Once your conversion code works with GPT-2:

  1. Create notebooks/conversion_colab.ipynb
  2. Upload your src/model_conversion.py
  3. Run conversion with use_small_model=False
  4. Save converted Mistral JAX params to Google Drive
  5. Download for local use in Phase 3-4

Validation Checkpoint (Phase A - Local):

  • load_pytorch_model() loads GPT-2 successfully
  • convert_pytorch_to_jax() transposes weights correctly
  • build_flax_pytree() creates nested structure
  • Can initialize Flax model with converted weights

Validation Checkpoint (Phase B - Colab):

  • Same code works with Mistral-7B on Colab
  • Converted Mistral params saved to Drive
  • Can load converted model locally

Claude's Role:

  • Explain weight transposition logic
  • Guide PyTree structure building
  • Debug conversion issues
  • Provide code structure/templates (you implement core logic)

Task 2.3: Numerical Validation

Your Action: Create tests/test_conversion.py

Implement tests:

def test_single_layer_output():
    """Compare single transformer layer output: PyTorch vs Flax"""
    # 1. Create same random input
    # 2. Run through PyTorch layer
    # 3. Run through Flax layer (with converted weights)
    # 4. Assert outputs are close (tolerance: 1e-5)
    # YOUR CODE
    pass

def test_full_model_output():
    """Compare full model logits: PyTorch vs Flax"""
    # Test on multiple inputs (short, long sequences)
    # YOUR CODE
    pass

def test_generation_equivalence():
    """Compare generated text: PyTorch vs Flax"""
    # Same prompt, same random seed → same output tokens
    # YOUR CODE
    pass

Validation Checkpoint:

  • All tests pass
  • Numerical difference < 1e-5 for all layers
  • Generated text is identical (or nearly identical)

Claude's Role: Help debug numerical issues, explain floating-point precision differences, suggest tolerance adjustments.

Task 2.5: Basic JAX Inference

Your Action: Create src/generation.py (basic version)

Implement basic generation (no optimization yet):

import jax
import jax.numpy as jnp
from flax import linen as nn

def generate_text(
    params: dict,
    input_ids: jnp.ndarray,
    max_length: int = 100,
    temperature: float = 1.0
) -> jnp.ndarray:
    """Generate text using JAX model (basic, no KV-cache yet)"""
    # YOUR CODE:
    # 1. Implement autoregressive loop
    # 2. Sample from logits
    # 3. Append token and continue
    pass

Validation Checkpoint:

  • Can generate coherent text with JAX model
  • Output quality matches PyTorch baseline
  • Basic timing measurements done (will optimize later)

Claude's Role: Review generation loop logic, help with JAX array manipulation.


Phase 3: INT8 Quantization (COMPRESSED)

Duration: 2 DAYS (Days 8-9) Goal: Implement basic INT8 quantization (simpler approach) Strategy: Use symmetric quantization only, skip calibration complexity

Learning Objectives

  1. Quantization Theory

    • Symmetric vs asymmetric quantization
    • Per-tensor vs per-channel quantization
    • Calibration and scale computation
  2. JAX Quantization Patterns

    • Simulated quantization (fake quantization)
    • Custom quantized operations
  3. Resources

Implementation Tasks

Task 3.1: Simple Quantization (SKIP CALIBRATION)

Your Action: SIMPLIFIED - Skip calibration initially

FAST APPROACH: Weight-only quantization (simpler than full calibration)

# Just quantize weights based on their min/max values
# Skip activation calibration to save time
# This still gives ~4x memory reduction

Validation Checkpoint:

  • Can quantize model weights
  • Memory reduced significantly

Claude's Role: Provide simple weight quantization code (5-10 lines).

Task 3.2: Implement Simple Quantization

Your Action: Create src/quantization.py - MINIMAL version

Claude provides starter code (you fill in gaps):

import jax.numpy as jnp

def quantize_weights(weights: jnp.ndarray) -> tuple[jnp.ndarray, float]:
    """Simple symmetric INT8 quantization"""
    scale = jnp.abs(weights).max() / 127.0
    quantized = jnp.round(weights / scale).clip(-128, 127).astype(jnp.int8)
    return quantized, scale

def dequantize_weights(quantized: jnp.ndarray, scale: float) -> jnp.ndarray:
    """Dequantize back to FP32"""
    return quantized.astype(jnp.float32) * scale

def quantize_model_params(params: dict) -> tuple[dict, dict]:
    """Quantize all model parameters"""
    # YOU implement: Loop through params PyTree
    # Claude provides tree traversal template if needed
    pass

Validation Checkpoint:

  • Can quantize/dequantize weights
  • Model still runs (quality checked later)
  • Tested on GPT-2: 326MB → 163MB (2.00x reduction)
  • Tested on Mistral-7B: 14.48GB → 7.24GB (2.00x reduction)

TODO - FUTURE WORK:

  • Implement checkpoint saving for quantized models
    • Must save both quantized_params AND scales dictionary
    • Save scales as JSON alongside checkpoint
    • Add validation tests for save/load round-trip
    • Consider pickle vs Orbax trade-offs

Claude's Role: Provide nearly complete code, you adapt for model structure.

Task 3.3: Quantization-Aware Inference

Your Action: Modify src/generation.py

Add quantized inference path:

def generate_text_quantized(
    quantized_params: dict,
    scales: dict,
    zero_points: dict,
    input_ids: jnp.ndarray,
    max_length: int = 100
) -> jnp.ndarray:
    """Generate with quantized weights (dequantized on-the-fly)"""
    # YOUR CODE
    pass

Validation Checkpoint:

  • Quantized inference works
  • Memory usage reduced (measure with JAX memory profiler)
  • Output quality degradation measured (ROUGE scores)

Claude's Role: Help profile memory usage, suggest optimization strategies.

Task 3.4: Quality Evaluation

Your Action: Create notebooks/03_quantization_analysis.ipynb

Compare outputs:

# YOUR CODE:
# 1. Run FP32 model on 100 test prompts
# 2. Run INT8 model on same prompts
# 3. Compute ROUGE-L scores between outputs
# 4. Identify: accuracy loss, memory savings
# 5. Visualize: quality vs compression trade-off

Target Metrics:

  • Memory: 28GB → ~7GB (4x reduction)
  • Quality: ROUGE-L > 0.95 (95%+ similarity)

Validation Checkpoint:

  • Quality metrics computed
  • Memory savings verified
  • Decision made: acceptable accuracy loss or need mixed precision?

Claude's Role: Help interpret ROUGE scores, suggest mixed-precision strategies if needed.


Phase 4: Optimization (KV-Cache & JIT) (COMPRESSED)

Duration: 5 DAYS (Days 10-14) Goal: Implement KV-cache and JIT (these are the CRITICAL optimizations) Strategy: Focus on working implementation, optimize later if needed

Learning Objectives

  1. KV-Cache Mechanism

    • Why autoregressive generation is slow without caching
    • How to structure cache as PyTree
    • Cache initialization and updates
  2. JIT Compilation

    • What operations benefit from JIT
    • How to write JIT-friendly code (pure functions)
    • Debugging JIT compilation issues
  3. Resources

Implementation Tasks

Task 4.1: Implement KV-Cache Structure

Your Action: Modify src/generation.py

Implement cache:

from typing import NamedTuple

class KVCache(NamedTuple):
    """KV-cache for single layer"""
    k: jnp.ndarray  # [batch, num_heads, seq_len, head_dim]
    v: jnp.ndarray

def init_kv_cache(
    batch_size: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int
) -> dict:
    """Initialize empty KV-cache for all layers"""
    # YOUR CODE
    pass

def update_kv_cache(
    cache: dict,
    layer_idx: int,
    new_k: jnp.ndarray,
    new_v: jnp.ndarray,
    position: int
) -> dict:
    """Update cache with new key/value at position"""
    # YOUR CODE
    # Use JAX array updates: cache['k'].at[position].set(new_k)
    pass

Validation Checkpoint:

  • Cache initialization works
  • Cache updates correctly
  • Cache structure is JAX PyTree (can use with jax.tree_map)

Claude's Role: Review cache structure, help with JAX array update syntax.

Task 4.2: Modify Attention for KV-Cache

Your Action: Modify attention mechanism

Implement cached attention:

def attention_with_cache(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    cache: KVCache,
    position: int,
    use_cache: bool = True
) -> tuple[jnp.ndarray, KVCache]:
    """
    Attention with KV-cache support

    Args:
        query: Current query [batch, num_heads, 1, head_dim]
        key: Current key [batch, num_heads, 1, head_dim]
        value: Current value [batch, num_heads, 1, head_dim]
        cache: Previous K/V values
        position: Current position in sequence

    Returns:
        attention_output, updated_cache
    """
    # YOUR CODE:
    # 1. Concatenate cached_k + new_k
    # 2. Concatenate cached_v + new_v
    # 3. Compute attention(query, full_k, full_v)
    # 4. Update cache
    # 5. Return output + new_cache
    pass

Validation Checkpoint:

  • Cached attention produces same output as non-cached
  • Cache grows correctly with each token
  • No memory leaks (cache size bounded)

Claude's Role: Debug attention logic, help with cache concatenation.

Task 4.3: Implement Cached Generation Loop

Your Action: Create optimized generation function

def generate_with_cache(
    params: dict,
    input_ids: jnp.ndarray,
    max_length: int = 100,
    temperature: float = 1.0
) -> jnp.ndarray:
    """Optimized generation with KV-cache"""
    # YOUR CODE:
    # 1. Initialize cache
    # 2. Process prompt tokens (prefill phase)
    # 3. Generate new tokens (decode phase with cache)
    # 4. Return generated sequence
    pass

Validation Checkpoint:

  • Cached generation produces same output as non-cached
  • Speedup measured: should be ~5-10x faster
  • Memory usage increased but acceptable (cache overhead)

Claude's Role: Help optimize generation loop, suggest performance improvements.

Task 4.4: Apply JIT Compilation (PRAGMATIC)

Your Action: JIT-ify critical functions - START SIMPLE

# Start with simple JIT wrapper
@jax.jit
def forward_pass(params, input_ids, cache):
    """JIT the forward pass"""
    # YOUR CODE - keep it simple first
    pass

# If scan is too complex, use regular loop initially
# Optimize with scan only if needed for speed

PRAGMATIC APPROACH:

  • JIT the model forward pass first (easiest)
  • Use regular Python loop for generation initially
  • Convert to jax.lax.scan ONLY if speed insufficient

Validation Checkpoint:

  • JIT compilation works (no errors)
  • Generation is faster than without JIT
  • Speed targets met (if not, optimize further)

Claude's Role: Provide JIT-ready code structure, help debug tracing errors quickly.

Task 4.5: Benchmark Optimized Model

Your Action: Create notebooks/04_optimization_results.ipynb

Compare:

  1. PyTorch baseline (no optimization)
  2. JAX + INT8 (no cache, no JIT)
  3. JAX + INT8 + KV-cache (no JIT)
  4. JAX + INT8 + KV-cache + JIT (fully optimized)

Measure for each:

  • Tokens/sec
  • Latency (p50, p95, p99)
  • Memory usage
  • Time-to-first-token

Expected Results:

  • 2-3x speedup overall
  • <1s latency for typical queries
  • ~7GB memory (vs 28GB baseline)

Validation Checkpoint:

  • All configurations benchmarked
  • Speedup targets achieved (2-3x)
  • Memory targets achieved (4x reduction)
  • Results documented

Claude's Role: Help interpret benchmark results, suggest further optimizations if targets not met.


Phase 5: Benchmarking & Analysis (COMPRESSED)

Duration: 3-5 DAYS (Days 15-19, buffer 20-21) Goal: Essential benchmarks + good visualizations Strategy: Focus on key metrics, skip exhaustive testing

Learning Objectives

  1. Benchmarking Methodology

    • Statistical rigor (mean, std, percentiles)
    • Fair comparison practices
    • Reproducibility
  2. ML Evaluation Metrics

    • ROUGE scores, BERTScore
    • Perplexity (optional)
    • Human evaluation considerations

Implementation Tasks

Task 5.1: Simple Benchmark Script (MINIMAL)

Your Action: Create src/benchmarking.py - ESSENTIAL ONLY

MINIMAL VERSION (Claude provides template):

import time
import json

def benchmark_latency(model_fn, prompts, num_runs=50):  # Reduced from 100
    """Basic latency measurement"""
    times = []
    for prompt in prompts:
        start = time.time()
        output = model_fn(prompt)
        times.append(time.time() - start)
    return {
        "mean": np.mean(times),
        "p50": np.percentile(times, 50),
        "p95": np.percentile(times, 95)
    }

# YOU add: memory profiling, ROUGE scoring (basic)
# SKIP: Exhaustive stats, t-tests, multiple batch sizes

Validation Checkpoint:

  • Can measure latency and memory
  • Can compute basic ROUGE scores
  • Results saved to JSON

Claude's Role: Provide working benchmark template (80% complete).

Task 5.2: Run Benchmarks (TWO-STAGE APPROACH)

Stage 1: Local Development Benchmarks (Day 16) Your Action: Quick validation with small sample on local machine

# Quick validation on 4-bit quantized model (local machine)
python -m src.benchmarking \
    --num-samples 100 \
    --model-type quantized \
    --output benchmarks/results/dev_results.json

Purpose: Validate your code works, see preliminary results


Stage 2: Final Colab Benchmarks (Day 17 - OVERNIGHT RUN) Your Action: Create notebooks/06_colab_final_benchmark.ipynb

This notebook will:

  1. Load full Mistral-7B FP16 (no quantization for baseline)
  2. Load your optimized JAX model
  3. Run on 1000 Alpaca samples (takes 2-4 hours)
  4. Save comprehensive results
# In Colab notebook (with T4/V100 GPU)
# Set to use full dataset
NUM_SAMPLES = 1000  # Full Alpaca evaluation

# Run both models
pytorch_results = benchmark_pytorch(model_pt, alpaca_dataset, NUM_SAMPLES)
jax_results = benchmark_jax(model_jax, alpaca_dataset, NUM_SAMPLES)

# Save results
save_results("pytorch_baseline_full.json", pytorch_results)
save_results("jax_optimized_full.json", jax_results)

Strategy:

  • Start the Colab run in the evening (Day 17)
  • Let it run overnight (~3-4 hours)
  • Download results next morning (Day 18)
  • No time wasted waiting!

Colab Setup Tips:

  • Use Colab Pro if available (no timeouts)
  • Or split into batches if using free tier
  • Mount Google Drive to save results automatically
  • Use TPU for JAX model (even faster!)

Validation Checkpoint:

  • Local 100-sample benchmark works (Day 16)
  • Colab notebook created and tested (Day 17)
  • Full 1000-sample results obtained (Day 18 morning)
  • Results show 2-3x speedup, 4x memory reduction

Claude's Role:

  • Provide Colab notebook template with GPU/TPU setup
  • Help debug any Colab-specific issues
  • Guide on how to download and analyze results

Task 5.3: Create Visualizations

Your Action: Create notebooks/05_results_visualization.ipynb

Create plots:

# YOUR CODE:
# 1. Load benchmark results (PyTorch vs JAX optimized)
# 2. Create comparison plots:
#    - Bar chart: Tokens/sec comparison
#    - Bar chart: Memory usage comparison
#    - Line plot: Latency distribution (CDF)
#    - Heatmap: Quality metrics (ROUGE scores by prompt length)
#    - Cost analysis: $ per 1M tokens
# 3. Save publication-quality figures

Validation Checkpoint:

  • All visualizations created
  • Figures saved to benchmarks/results/plots/
  • Results clearly show 2-3x speedup, 4x memory reduction

Claude's Role: Suggest visualization improvements, help with matplotlib/seaborn syntax.

Task 5.4: Write Final Demo Notebook

Your Action: Create notebooks/06_demo.ipynb

Create interactive demo:

# Demo should include:
# 1. Side-by-side comparison (PyTorch vs JAX)
# 2. Interactive prompt input
# 3. Real-time timing display
# 4. Memory usage monitoring
# 5. Output quality comparison
# 6. Summary statistics

Validation Checkpoint:

  • Demo notebook runs end-to-end
  • Results are impressive and clearly visualized
  • Notebook is well-documented (ready to share)

Claude's Role: Review notebook, suggest presentation improvements.

Task 5.5: Update README with Results

Your Action: Update README.md

Add:

  • Actual Results section (replace "Expected Results")
  • Screenshots/plots from benchmarks
  • Usage instructions
  • Installation verification steps
  • Link to demo notebook

Validation Checkpoint:

  • README reflects actual implementation
  • Results section shows real benchmarks
  • Project is presentation-ready

Claude's Role: Review README, suggest clarity improvements.


Success Criteria

Phase-by-Phase Checkpoints

Phase 1 - Foundation:

  • Environment setup complete, all dependencies installed
  • Can run JAX code, understand JIT basics
  • Project structure created

Phase 2 - Conversion:

  • PyTorch model converted to JAX
  • Numerical validation passes (outputs match within 1e-5)
  • Basic JAX inference works

Phase 3 - Quantization:

  • INT8 quantization implemented
  • Memory reduced from 28GB → ~7GB
  • Quality degradation < 5% (ROUGE-L > 0.95)

Phase 4 - Optimization:

  • KV-cache implemented correctly
  • JIT compilation applied
  • 2-3x speedup achieved (2.5s → 0.9s)

Phase 5 - Benchmarking:

  • Comprehensive benchmarks completed
  • Visualizations created
  • Project is presentation-ready

Final Project Goals (from README)

Metric Target Status
Tokens/sec 20-25 (2.5-3x) ⬜ To be measured
Latency <1s ⬜ To be measured
Memory ~7.5GB (4x reduction) ⬜ To be measured
Quality 98%+ (ROUGE-L) ⬜ To be measured
Cost 65% reduction ⬜ To be calculated

Claude Code's Role Throughout

What Claude WILL Do (ACCELERATED MODE):

✅ Provide substantial starter code/templates (60-80% complete) ✅ Write boilerplate and setup code to save time ✅ Help you find and use existing implementations ✅ Debug proactively and suggest fast solutions ✅ Provide working examples for complex patterns ✅ Review and optimize your code quickly ✅ Keep you moving toward the goal

What Claude WILL NOT Do:

❌ Write 100% of the code (you still implement key logic) ❌ Skip teaching the "why" behind optimizations ❌ Let you skip validation checkpoints ❌ Let you merge broken code

How to Work with Claude:

When Starting a New Phase:

  1. YOU: "I'm starting Phase X. Can you explain [concept] before I implement?"
  2. Claude: Provides explanation and resources
  3. YOU: Implement the code
  4. YOU: "Can you review my implementation of [function]?"
  5. Claude: Reviews, suggests improvements

When Stuck:

  1. YOU: "I'm getting error X when implementing Y. Here's my code..."
  2. Claude: Analyzes error, explains issue, suggests fix
  3. YOU: Apply fix and test
  4. YOU: "It works! Why did that fix it?"
  5. Claude: Explains the underlying reason

When Uncertain:

  1. YOU: "I'm not sure if I should use approach A or B for [task]"
  2. Claude: Explains trade-offs of each approach
  3. YOU: Make decision based on your goals
  4. Claude: Supports your choice with implementation guidance

Troubleshooting Guide

Common Issues & Solutions

Issue: GPU out of memory during model loading Solution:

  • Use smaller batch size
  • Try model sharding (load layers one at a time)
  • Use gradient checkpointing
  • Try on cloud GPU with more VRAM

Issue: JAX JIT compilation is extremely slow Solution:

  • Check for dynamic shapes (use static batch size)
  • Avoid Python loops inside JIT functions
  • Profile with jax.profiler to find bottleneck

Issue: Converted model outputs don't match PyTorch Solution:

  • Check weight transpose (PyTorch vs Flax conventions)
  • Verify attention mask is applied correctly
  • Check for missing bias terms
  • Test layer-by-layer to isolate issue

Issue: Quantized model has poor quality Solution:

  • Use per-channel instead of per-tensor quantization
  • Try mixed precision (keep attention in FP32)
  • Increase calibration dataset size
  • Consider quantization-aware training (advanced)

Issue: No speedup from KV-cache Solution:

  • Verify cache is actually being used (add logging)
  • Check cache concatenation is efficient (no copies)
  • Profile to find bottleneck (may be elsewhere)

Additional Resources

Documentation

Papers

Tutorials

Tools


Timeline & Milestones (COMPRESSED)

Week 1: Setup + Conversion Foundation

Days 1-2: Environment setup + JAX basics (Phase 1)

  • Quick JAX tutorial (4 hours)
  • Environment setup and structure creation (2 hours)
  • Basic JAX operations and JIT understanding (2 hours)

Days 3-5: Model conversion development with GPT-2 (Phase 2 begins)

  • Baseline PyTorch benchmarking + load_pytorch_model() (Day 3)
  • Implement convert_pytorch_to_jax() - weight transposition logic (Day 4)
  • Implement build_flax_pytree() + load_flax_model_with_params() (Day 5)
  • Test all functions locally with GPT-2

Days 6-7: Mistral conversion on Colab + validation

  • Run conversion on Mistral-7B in Colab (Day 6 morning)
  • Save and download converted Mistral params (Day 6 afternoon)
  • Numerical validation tests (Day 7)
  • Basic JAX inference with Mistral working (Day 7)

Checkpoint Week 1: JAX model running, outputs match PyTorch


Week 2: Quantization + Optimization

Days 8-9: Quantization (Phase 3 - compressed)

  • Calibration data prep (Day 8 morning)
  • Implement quantization functions (Day 8 afternoon)
  • Quantized inference + evaluation (Day 9)

Days 10-12: KV-Cache implementation (Phase 4 begins)

  • Design and implement KV-cache structure (Day 10)
  • Modify attention mechanism for caching (Day 11)
  • Test cached generation (Day 12)

Days 13-14: JIT compilation

  • Apply JIT to generation loop (Day 13)
  • Debug and optimize JIT performance (Day 14)
  • Measure all optimizations combined

Checkpoint Week 2: INT8 + KV-cache + JIT working, hitting speed targets


Week 3: Benchmarking + Polish

Days 15-17: Comprehensive evaluation (Phase 5)

  • Build benchmark suite with configurable sample size (Day 15)
  • Local validation: 100-sample benchmark on laptop (Day 16)
  • Colab setup: Launch 1000-sample overnight run (Day 17 evening)

Days 18-19: Analysis + Documentation

  • Download Colab results, create visualizations (Day 18 morning)
  • Final demo notebook with impressive results (Day 18 afternoon)
  • Update README with actual metrics (Day 19)
  • Polish and prepare for presentation

Days 20-21: BUFFER for issues/refinements

Final Checkpoint: Project complete, presentation-ready


Total Duration: 2-3 weeks (aggressive but achievable) Key Success Factor: Work 3-4 hours/day consistently


Next Steps

Ready to start? Here's what to do now:

  1. Review this plan - Make sure you understand the overall structure
  2. Check prerequisites - Ensure you have the required knowledge
  3. Set up environment - Start with Phase 1, Task 1.1
  4. Ask questions - If anything is unclear, ask Claude before proceeding

When ready to begin Phase 1:

YOU: "I'm ready to start Phase 1. Can you help me create the requirements.txt file?"

Project Philosophy

"The goal is not just to build a faster model, but to deeply understand WHY these optimizations work. You'll learn by doing, make mistakes, debug issues, and gain intuition that will serve you in future ML projects."

Good luck on your learning journey! 🚀



Key Differences from 7-8 Week Plan

Aspect Original Plan Compressed Plan
Timeline 7-8 weeks 2-3 weeks
JAX Learning Deep dive (1 week) Crash course (2 days)
Conversion Manual implementation Manual (GPT-2 locally → Mistral on Colab)
Quantization Full calibration Weight-only (simpler)
Testing Comprehensive Essential only
Dataset Size 1000 samples 100 dev + 1000 final (Colab)
Claude's Role Reviewer/guide Active contributor + explainer
Daily Time 1-2 hours 3-4 hours required

Success Tips for Fast Execution

  1. Work consistently: 3-4 hours/day, no skipping days
  2. Don't overthink: Use existing implementations when possible
  3. Ask Claude early: Don't spend 2 hours stuck on something
  4. Parallelize: While code runs, work on documentation
  5. Use cloud GPU: Don't wait for local setup issues
  6. Skip perfection: Working > perfect for first iteration
  7. Leverage templates: Claude provides more scaffolding here
  8. Overnight runs: Use Colab for long benchmarks while you sleep

Benchmarking Strategy Summary

Development Phase (Days 1-16): Local + Small Samples

  • Hardware: Your RTX 4070 (8GB VRAM) + 4-bit quantization
  • Dataset: 10-100 samples for quick iteration
  • Purpose: Fast development, debugging, validation
  • Time per run: 5-10 minutes
  • Cost: $0 (local machine)

Final Evaluation (Day 17-18): Colab + Full Dataset

  • Hardware: Colab T4/V100 GPU (16GB VRAM) or TPU
  • Dataset: Full 1000 Alpaca samples
  • Purpose: Publication-quality results for portfolio
  • Time: 3-4 hours (overnight run)
  • Cost: Free (Colab free tier) or $10/month (Colab Pro - recommended)

Why This Works

✅ Fast local iteration (don't wait hours for results during dev) ✅ Professional-grade final metrics (1000 samples, full model) ✅ Cost-effective (only use Colab when needed) ✅ Timeline-friendly (overnight run doesn't block your work) ✅ Best results (can use TPU for JAX, which is faster than GPU!)

Colab TPU Advantage for JAX

JAX models can run on Colab TPU (not available locally):

  • GPU: Good performance (~20-30 tokens/sec)
  • TPU: Excellent performance (~40-60 tokens/sec) - JAX is optimized for TPU!
  • This makes your final JAX results even more impressive

Final deliverable: Side-by-side comparison showing massive speedup on real hardware with comprehensive 1000-sample evaluation!


Document Version: 2.0 (COMPRESSED TIMELINE) Last Updated: 2025-10-29 Status: Ready for FAST Phase 1 (2-3 week version)