Optimizing Mistral-7B inference for 3x faster response times using JAX.
This project demonstrates production-level optimization of large language models by converting Mistral-7B from PyTorch to JAX and applying quantization, efficient caching, and JIT compilation. The goal is to achieve 2-3x speedup while maintaining output quality.
- Speed: Reduce response time from ~2.5s to <1s per query (2-3x speedup target, achieved 16x on GPT-2!)
- Memory: Decrease GPU memory usage from ~14GB (FP16) to ~7GB (INT8) via quantization
- Quality: Maintain 98%+ output similarity to original model
- Cost: Lower inference costs through faster generation and reduced memory
- JAX: For automatic differentiation and XLA compilation
- Flax: Neural network library built on JAX
- Transformers: Hugging Face library for model loading
- jax.jit: Just-in-time compilation for performance
- Optax: For any additional fine-tuning (optional)
- datasets: For loading evaluation datasets (Alpaca)
- rouge-score: For output quality evaluation
- matplotlib/seaborn: For performance visualizations
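The core of the stack is `jax.jit`: the first call traces the function and compiles it with XLA, and subsequent calls with the same input shapes reuse the cached program. A minimal sketch (the function and sizes here are illustrative, not taken from this repo):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a transformer feed-forward block; after tracing,
# XLA fuses the matmul and activation into one compiled program.
@jax.jit
def mlp_block(x, w1, w2):
    h = jax.nn.gelu(x @ w1)
    return h @ w2

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 256))
w1 = jax.random.normal(key, (256, 1024)) * 0.02
w2 = jax.random.normal(key, (1024, 256)) * 0.02

# First call compiles for this input shape; later calls run the cached binary.
out = mlp_block(x, w1, w2)
print(out.shape)  # (8, 256)
```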
- Model Conversion: PyTorch → JAX/Flax for XLA optimization
- INT8 Quantization: 2x memory reduction on model weights (FP32 → INT8)
- KV-Cache: Avoid redundant computation during generation
- JIT Compilation: Fuse operations for faster execution
- Batched Inference: Process multiple requests efficiently
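As a rough sketch of the quantization step, here is per-tensor symmetric INT8 quantization of a single weight matrix (a minimal illustration; the project's actual scheme, e.g. per-channel scales, may differ):

```python
import jax.numpy as jnp

def quantize_int8(w):
    # Symmetric quantization: map [-max|w|, max|w|] onto the int8 range [-127, 127].
    scale = jnp.max(jnp.abs(w)) / 127.0
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize(q, scale):
    # Reconstruct an approximate float tensor for use in matmuls.
    return q.astype(jnp.float32) * scale

w = jnp.array([[0.5, -1.2], [0.03, 0.9]])
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)  # close to w, within ~scale/2 per element
```

Each stored element shrinks from 4 bytes to 1, at the cost of a small rounding error bounded by half the scale.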
Evaluation on 1,000 instructions from the Alpaca dataset measuring:
- Tokens per second
- Latency (p50, p95, p99)
- Memory usage
- Output quality (ROUGE scores, exact match rate)
- Cost per 1M tokens
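The latency and throughput bookkeeping behind these metrics can be sketched as follows (the latency values below are placeholder data for illustration, not benchmark results):

```python
import numpy as np

# Hypothetical per-query end-to-end latencies in seconds.
latencies = np.array([0.61, 0.63, 0.60, 0.95, 0.62, 0.64, 1.10, 0.63])
tokens_per_query = 15

# Tail percentiles capture worst-case behavior that the mean hides.
p50, p95, p99 = np.percentile(latencies, [50, 95, 99])
tokens_per_sec = tokens_per_query / latencies.mean()
print(f"p50={p50:.2f}s p95={p95:.2f}s p99={p99:.2f}s, {tokens_per_sec:.1f} tok/s")
```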
| Metric | Non-Cached | Cached + JIT | Improvement |
|---|---|---|---|
| Tokens/sec | 1.50 | 24.45 | 16.32x faster |
| Time (15 tokens) | 10.03s | 0.63s | 16.32x faster |
| Model Memory (INT8) | 163MB | 163MB | Same (quantized) |
| KV-Cache Overhead | - | ~38MB | Minimal |
| Output Match | Identical | Identical | Perfect |
| Quality | Correct text | Correct text | 100% |
Memory Note: INT8 quantization reduces GPT-2 from 326MB (FP32) to 163MB (INT8) = 2.00x reduction. KV-cache adds ~38MB overhead for storing attention keys/values.
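The ~38MB figure is consistent with back-of-the-envelope KV-cache sizing for GPT-2 small, assuming an fp16 cache and a 1024-token context (assumptions of this sketch, not stated by the benchmarks):

```python
# KV-cache size = layers * (keys + values) * max_seq_len * hidden_dim * bytes/elem
n_layers, hidden, max_len, bytes_per = 12, 768, 1024, 2  # GPT-2 small, fp16 cache
kv_bytes = n_layers * 2 * max_len * hidden * bytes_per
print(kv_bytes / 2**20)  # ~36 MiB, in line with the ~38MB measured overhead
```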
Optimization Breakdown:
- KV-Cache + JIT combined: 16.32x speedup (measured)
- Note: the two contributions cannot be isolated, because JIT is applied via decorators on the same cached-generation functions
Status: Phase 4 (KV-Cache + JIT) COMPLETE for GPT-2!
- Achieved 16.32x speedup on GPT-2 (far exceeds 2-3x target!)
- Fixed critical bug in attention output projection
- Applied JIT compilation to 8 core functions with `@jax.jit` decorators
- Text generation quality verified (identical to PyTorch GPT-2)
- All benchmarks passing on GPT-2
- Performance measured after proper JIT warmup (all sequence lengths pre-compiled)
Important: JAX JIT compiles separately for each input shape. Warmup must generate at least as many tokens as the actual test to ensure all shapes are pre-compiled for accurate benchmarking.
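A minimal warmup loop illustrating this (the function is a toy stand-in for a decode step, not the project's generation code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(tokens):
    # Stand-in for one decode step over the current token sequence.
    return tokens.sum()

# Each distinct input shape triggers a separate compilation. Warming up every
# sequence length the benchmark will see keeps compile time out of the timings.
for seq_len in range(1, 16):
    step(jnp.zeros((1, seq_len), jnp.int32)).block_until_ready()
```

`block_until_ready()` matters here because JAX dispatches asynchronously; without it, the warmup loop could finish before compilation does.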
Note: All optimizations tested and validated on GPT-2. Mistral-7B implementation pending.
Phase 4 Complete! Ready for Phase 5: Mistral-7B Implementation & Benchmarking.
| Metric | Baseline (PyTorch FP16) | Target (JAX INT8) | Status |
|---|---|---|---|
| Tokens/sec | 8-10 | 20-25 (2.5-3x faster) | ⏳ Not Started |
| Latency | 2.5s | <1s | ⏳ Not Started |
| Memory | ~14GB | ~7GB (2x reduction) | ⏳ Not Started |
| Quality | 100% | 98%+ | ⏳ Not Started |
Note: Mistral-7B support not yet implemented. Based on GPT-2 results (16.32x speedup), we expect similar or better performance for Mistral-7B when implemented.
.
├── src/
│ ├── model_conversion.py # PyTorch to JAX conversion
│ ├── quantization.py # INT8 quantization
│ ├── cached_generation.py # Optimized generation with KV-cache + JIT
│ └── kv_cache.py # KV-cache utilities
├── tests/
│ ├── test_generation.py # Generation quality tests
│ └── test_optimization_comparison.py # Performance benchmarks
├── OPTIMIZATION_GUIDE.md # Detailed optimization documentation
├── BENCHMARKING.md # Benchmarking best practices
├── Execution_Plan.md # Project roadmap and progress
└── README.md
# Create virtual environment
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install jax[cuda12] flax transformers datasets rouge-score matplotlib seaborn
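After installing, a quick sanity check (a suggestion, not a script shipped with this repo) confirms that JAX can see the accelerator:

```python
import jax

# Lists available devices; expect CUDA devices on a GPU machine,
# or CPU devices if the CUDA wheels did not match the installed driver.
print(jax.devices())
print(jax.default_backend())  # "gpu" when CUDA is active, else "cpu"
```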