Skip to content

Latest commit

 

History

History
556 lines (428 loc) · 20.1 KB

File metadata and controls

556 lines (428 loc) · 20.1 KB

Efficient Attention in Vision Transformers

This project implements Vision Transformers from scratch and provides a framework for comparing different attention mechanisms. The goal is to empirically evaluate the trade-offs between computational efficiency and model performance on image classification tasks, specifically comparing standard attention with efficient alternatives like Linformer, Performer, and Nyströmformer.

Current Implementation Status

Completed Features

  • Complete ViT Implementation: Full Vision Transformer implementation with patch embeddings, multi-head attention, and classification head
  • Optimized Attention: Faster multi-head attention with merged QKV projections
  • CIFAR-10 Training Pipeline: Complete training and evaluation system
  • Comprehensive Benchmarking: Detailed performance metrics (FLOPs, memory usage, inference latency)
  • Attention Visualization: Tools for visualizing attention maps and model behavior
  • Modular Architecture: Easy to extend with new attention mechanisms

🚧 In Progress / To Be Implemented

  • Linformer Attention: Linear attention with low-rank projections (O(n) complexity)
  • Performer Attention: Kernel-based linear attention using random features
  • Nyströmformer Attention: Matrix approximation for efficient attention
  • Hybrid Attention: Combining atrous (dilated) attention with efficient mechanisms
  • ImageNet-100 Dataset: Scaling up from CIFAR-10 to more complex dataset
  • Comparative Analysis: Head-to-head efficiency vs. accuracy trade-offson in Vision Transformers

A comprehensive implementation and comparative study of efficient attention mechanisms for Vision Transformers (ViTs). This project explores various attention optimization techniques including standard multi-head attention, optimized implementations, and future extensions for linear attention methods like Linformer, Performer, and Nyströmformer.

CS 5787 – Deep Learning
Authors: Pranav Dhingra, Shashank Ramachandran
NetIDs: pd453, sr2433

Project Overview

This project implements Vision Transformers from scratch and provides a framework for comparing different attention mechanisms. The goal is to empirically evaluate the trade-offs between computational efficiency and model performance on image classification tasks.

Key Features

  • Complete ViT Implementation: Full Vision Transformer implementation with patch embeddings, multi-head attention, and classification head
  • Optimized Attention: Faster multi-head attention with merged QKV projections
  • Comprehensive Benchmarking: Training pipeline with detailed performance metrics (FLOPs, memory usage, inference latency)
  • Attention Visualization: Tools for visualizing attention maps and model behavior
  • Modular Architecture: Easy to extend with new attention mechanisms

📁 Project Structure

efficient-attention-vit/
├── VIT/code/                    # Core implementation
│   ├── vit.py                  # Vision Transformer models
│   ├── train.py                # Training pipeline and trainer class
│   ├── data.py                 # CIFAR-10 data loading and preprocessing
│   └── utils.py                # Utility functions and evaluation metrics
├── Literature-Review/           # Research papers and documentation
│   ├── How the code works.pdf
│   └── Image-is-worth-16words.pdf
├── data/                       # Dataset storage (created automatically)
├── experiments/                # Saved models and training logs
├── results/                    # Experiment results and summaries
├── proposal.md                 # Project proposal
├── plan.md                     # Implementation plan
├── requirements.txt            # Python dependencies
└── README.md                   # This file

🚀 Quick Start

Installation

  1. Clone the repository

    git clone https://github.com/pcatattacks/efficient-attention-vit.git
    cd efficient-attention-vit
  2. Install dependencies

    pip install -r requirements.txt
  3. Install optional dependencies (for FLOPs computation)

    pip install ptflops pandas

Basic Usage

Train a Vision Transformer on CIFAR-10:

cd VIT/code
python train.py --exp-name "vit_baseline" --batch-size 256 --epochs 100 --lr 1e-2

Train with different configurations:

# Quick test run
python train.py --exp-name "quick_test" --batch-size 64 --epochs 10 --lr 1e-3

# High-performance run
python train.py --exp-name "vit_large" --batch-size 512 --epochs 200 --lr 5e-3

Advanced Usage

Using the models programmatically:

from VIT.code.vit import ViTForClassfication
from VIT.code.data import prepare_data
import torch

# Configure the model
config = {
    "patch_size": 4,
    "hidden_size": 48,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "intermediate_size": 192,
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 10,
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True  # Enable optimized attention
}

# Create model
model = ViTForClassfication(config)

# Load data
trainloader, testloader, classes = prepare_data(batch_size=256)

# Forward pass
for batch in trainloader:
    images, labels = batch
    logits, attention_maps = model(images, output_attentions=True)
    break

🏗️ Architecture Details

Vision Transformer Components

  1. Patch Embeddings: Converts 32×32 images into 8×8 patches (with patch_size=4)
  2. Position Embeddings: Learnable position encodings for spatial awareness
  3. Multi-Head Attention: Standard or optimized attention mechanisms
  4. Feed-Forward Network: MLP blocks with GELU activation
  5. Classification Head: Linear layer for CIFAR-10 classification

Model Configurations

Component Standard Optimized
Attention Separate Q, K, V projections Merged QKV projection
Memory Usage Higher Lower
Speed Slower Faster
Accuracy Baseline Comparable

Default Configuration (CIFAR-10)

config = {
    "patch_size": 4,           # 32×32 → 8×8 patches
    "hidden_size": 48,         # Model dimension
    "num_hidden_layers": 4,    # Transformer blocks
    "num_attention_heads": 4,  # Attention heads
    "intermediate_size": 192,  # FFN dimension (4×hidden_size)
    "image_size": 32,          # CIFAR-10 image size
    "num_classes": 10,         # CIFAR-10 classes
    "num_channels": 3,         # RGB channels
    "qkv_bias": True,          # Bias in attention projections
    "use_faster_attention": True  # Enable optimization
}

📊 Evaluation Metrics

The framework automatically tracks comprehensive performance metrics:

Accuracy Metrics

  • Top-1 Accuracy: Primary classification accuracy
  • Top-5 Accuracy: Top-5 classification accuracy

Efficiency Metrics

  • Parameter Count: Total trainable parameters
  • FLOPs/MACs: Floating-point operations (requires ptflops)
  • Peak Memory Usage: GPU memory consumption during training
  • Inference Latency: Average forward pass time per image
  • Training Time: Time per epoch and total training time

Example Output

Final metrics for vit_baseline:
  Params: 42,826
  FLOPs (MACs): 1.234e+07
  Inference latency: 2.145 ± 0.123 ms / image
  Final Top-1 Accuracy: 0.8234
  Final Top-5 Accuracy: 0.9567

🎨 Visualization Features

Attention Map Visualization

from VIT.code.utils import visualize_attention

# Load trained model
model = load_trained_model("experiments/vit_baseline/model_final.pt")

# Visualize attention patterns
visualize_attention(model, output="attention_maps.png", device="cuda")

Dataset Visualization

from VIT.code.utils import visualize_images

# Display sample CIFAR-10 images
visualize_images()

🔧 Training Pipeline

Command Line Interface

python train.py [OPTIONS]

Options:
  --exp-name TEXT          Experiment name (required)
  --batch-size INTEGER     Batch size [default: 256]
  --epochs INTEGER         Number of epochs [default: 100]
  --lr FLOAT              Learning rate [default: 0.01]
  --device TEXT           Device (cuda/cpu) [default: auto-detect]
  --save-model-every INT  Save checkpoints every N epochs [default: 0]
  --output-dir TEXT       Output directory [default: outputs]

Trainer Class

The Trainer class provides a clean interface for model training:

from VIT.code.train import Trainer

trainer = Trainer(model, optimizer, loss_fn, exp_name, device)
trainer.train(trainloader, testloader, epochs=100)

Automatic Experiment Tracking

  • Model Checkpoints: Saved in experiments/{exp_name}/
  • Training Logs: JSON format with all metrics
  • Configuration: Model config saved for reproducibility
  • Summary DataFrames: CSV summaries for easy comparison

📈 Results and Analysis

Performance Benchmarks

Model Variant Params FLOPs Top-1 Acc Inference (ms)
Standard ViT 42.8K 12.3M 82.3% 2.15 ± 0.12
Optimized ViT 42.8K 12.3M 82.1% 1.87 ± 0.08

Results on CIFAR-10 with 100 epochs of training

Attention Pattern Analysis

The visualization tools reveal that the model learns to:

  • Focus on object boundaries and distinctive features
  • Develop hierarchical attention patterns across layers
  • Adapt attention based on object complexity

🔬 Research Implementation Plan

Based on our project proposal, the following efficient attention mechanisms need to be implemented and compared:

🎯 Core Research Objectives

  1. Empirical Comparison: Compare standard ViT attention with efficient variants on computational cost vs. accuracy
  2. Scalability Analysis: Test how each mechanism scales with input resolution (CIFAR-10 → ImageNet-100)
  3. Hybrid Innovation: Develop novel hybrid attention combining dilated/sparse patterns with linear attention

📋 Implementation Roadmap

Phase 1: Efficient Attention Mechanisms ⏳

# Target implementations needed in vit.py:

class LinformerAttention(nn.Module):
    """Linear attention with low-rank projections - O(n) complexity"""
    # Projects K,V to lower dimensional space
    # Reduces quadratic attention to linear

class PerformerAttention(nn.Module):
    """Kernel-based linear attention using FAVOR+ algorithm"""
    # Uses random feature approximation
    # Maintains accuracy while achieving linear complexity

class NystromformerAttention(nn.Module):
    """Nyström method for attention matrix approximation"""
    # Approximates attention matrix using landmark points
    # Balances efficiency and approximation quality

class HybridAttention(nn.Module):
    """Custom hybrid combining dilated attention with linear methods"""
    # Integrates atrous (dilated) patterns for local efficiency
    # Combines with global linear attention mechanisms

Phase 2: Dataset Scaling 📈

  • Current: CIFAR-10 (32×32, 8×8 patches)
  • Target: ImageNet-100 (224×224, 14×14 patches)
  • Challenge: Where efficiency gains become meaningful

Phase 3: Comprehensive Evaluation 📊

  • Metrics: Training time, inference latency, memory usage, FLOPs`
  • Analysis: Trade-off curves between efficiency and accuracy
  • Visualization: Attention pattern analysis across mechanisms

🔍 Research Questions to Answer

  1. Efficiency vs. Accuracy: Which method provides the best trade-off?
  2. Scalability: How do efficiency gains change with input resolution?
  3. Attention Patterns: Do efficient methods learn different visual representations?
  4. Hybrid Benefits: Can dilated attention improve upon linear methods?
  5. Practical Deployment: Which methods are viable for resource-constrained scenarios?

🛠️ Development & Implementation Guide

Current Implementation Status

What's Working

  • Standard Vision Transformer with multi-head attention
  • Faster attention with merged QKV projections
  • CIFAR-10 training and evaluation pipeline
  • Comprehensive metrics collection and visualization

🔧 Next Development Steps

1. Implementing Efficient Attention Mechanisms

Each attention mechanism should follow this pattern in vit.py:

class LinformerAttention(nn.Module):
    """
    Linformer: Self-Attention with Linear Complexity
    Key insight: Project K,V to lower dimensional space (n×k instead of n×n)
    """
    def __init__(self, config):
        super().__init__()
        self.seq_len = (config["image_size"] // config["patch_size"]) ** 2 + 1  # +1 for CLS
        self.k = config.get("linformer_k", 64)  # Projection dimension
        # Standard Q projection
        self.query = nn.Linear(config["hidden_size"], config["hidden_size"])
        # Low-rank K,V projections
        self.key_proj = nn.Linear(self.seq_len, self.k)
        self.value_proj = nn.Linear(self.seq_len, self.k) 
        self.key = nn.Linear(config["hidden_size"], config["hidden_size"])
        self.value = nn.Linear(config["hidden_size"], config["hidden_size"])
        
    def forward(self, x, output_attentions=False):
        # Q: (batch, seq_len, hidden) -> (batch, seq_len, hidden)
        # K,V: (batch, seq_len, hidden) -> (batch, k, hidden) via projection
        # Attention: (batch, seq_len, hidden) @ (batch, hidden, k) = (batch, seq_len, k)
        pass  # Implementation needed

class PerformerAttention(nn.Module):
    """
    Performer: Rethinking Attention with Performers
    Key insight: Approximate softmax attention using random features
    """
    def __init__(self, config):
        super().__init__()
        self.num_features = config.get("performer_features", 64)
        # Random feature matrix for kernel approximation
        self.register_buffer("random_features", 
                           torch.randn(config["hidden_size"], self.num_features))
        
    def forward(self, x, output_attentions=False):
        # Use FAVOR+ algorithm for kernel approximation
        # φ(q)^T φ(k) ≈ exp(q^T k / √d) via random features
        pass  # Implementation needed

class NystromformerAttention(nn.Module):
    """
    Nyströmformer: Nyström method for approximating attention
    Key insight: Use landmark points to approximate full attention matrix
    """
    def __init__(self, config):
        super().__init__()
        self.num_landmarks = config.get("nystrom_landmarks", 32)
        
    def forward(self, x, output_attentions=False):
        # Select landmark points and approximate attention matrix
        # A ≈ A[:,L] @ pinv(A[L,L]) @ A[L,:]
        pass  # Implementation needed

2. Update Block Class for Attention Selection

# In Block.__init__(), add mechanism selection:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        attention_type = config.get("attention_type", "standard")
        
        if attention_type == "linformer":
            self.attention = LinformerAttention(config)
        elif attention_type == "performer":
            self.attention = PerformerAttention(config)
        elif attention_type == "nystromformer":
            self.attention = NystromformerAttention(config)
        elif attention_type == "hybrid":
            self.attention = HybridAttention(config)  # To be implemented
        elif config.get("use_faster_attention", False):
            self.attention = FasterMultiHeadAttention(config)
        else:
            self.attention = MultiHeadAttention(config)

3. Configuration Updates

Add to config dictionary:

config = {
    # Existing parameters...
    "attention_type": "standard",  # Options: standard, linformer, performer, nystromformer, hybrid
    "linformer_k": 64,            # Linformer projection dimension
    "performer_features": 64,      # Performer random features
    "nystrom_landmarks": 32,       # Nyströmformer landmark points
}

4. Testing Framework

# Test each attention mechanism:
python train.py --exp-name "test_linformer" --epochs 5 --batch-size 64
# Modify config in train.py to set attention_type = "linformer"

# Compare all mechanisms:
python scripts/compare_attention.py  # To be created

📚 References

Core Papers (From Proposal)

  1. An Image Is Worth 16×16 Words: Transformers for Image Recognition at Scale
    Dosovitskiy et al., ICLR 2021
    Status: Implemented as baseline ViT architecture

  2. Linformer: Self-Attention with Linear Complexity
    Wang et al., NeurIPS 2020
    🔄 Status: To be implemented - linear attention via low-rank projections

  3. Performer: Rethinking Attention with Performers
    Choromanski et al., ICLR 2021
    🔄 Status: To be implemented - FAVOR+ algorithm for kernel approximation

  4. Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention
    Xiong et al., AAAI 2021
    🔄 Status: To be implemented - landmark-based matrix approximation

  5. Fusion of Regional and Sparse Attention in Vision Transformers
    Ibtehaz et al., 2024
    🔄 Status: To be implemented - inspiration for hybrid attention mechanism

Implementation Resources

  • Original ViT Paper: Foundation for our baseline implementation
  • Efficient Attention Survey: Tay et al., "Efficient Transformers: A Survey" (2020)
  • Linear Attention Methods: Katharopoulos et al., "Transformers are RNNs" (2020)

🤝 Contributing

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-attention)
  3. Commit your changes (git commit -m 'Add amazing attention mechanism')
  4. Push to the branch (git push origin feature/amazing-attention)
  5. Open a Pull Request

📄 License

This project is part of an academic research study. Please cite our work if you use this code in your research.

🐛 Troubleshooting

Common Issues

  1. CUDA out of memory: Reduce batch size or model dimensions
  2. Slow training: Enable use_faster_attention=True in config
  3. Import errors: Ensure all dependencies are installed
  4. Dataset download fails: Check internet connection and disk space

Performance Tips

  • Use use_faster_attention=True for better performance
  • Adjust batch size based on available GPU memory
  • Enable mixed precision training for faster convergence
  • Use multiple workers for data loading (num_workers > 0)

🎯 Immediate Next Steps (Implementation Priority)

Based on the project proposal, here's the implementation roadmap:

1. Implement Core Efficient Attention Mechanisms (High Priority)

# Files to modify:
- VIT/code/vit.py: Add LinformerAttention, PerformerAttention, NystromformerAttention
- VIT/code/train.py: Update config to support attention_type parameter

2. Scale to ImageNet-100 Dataset (Medium Priority)

# Files to create/modify:
- VIT/code/data.py: Add ImageNet-100 data loading
- Update image_size from 32 to 224, patch_size from 4 to 16

3. Implement Hybrid Attention Mechanism (Medium Priority)

class HybridAttention(nn.Module):
    """
    Combines dilated/atrous attention patterns with linear attention
    Inspired by "Fusion of Regional and Sparse Attention"
    """
    # Dilated convolution-like attention patterns
    # Combined with linear attention for global context

4. Comparative Evaluation Pipeline (High Priority)

# Files to create:
- scripts/compare_all_attention.py: Train all variants and compare
- scripts/generate_efficiency_plots.py: Create trade-off visualizations

5. Research Analysis (Final Phase)

  • Efficiency vs. accuracy trade-off curves
  • Attention pattern visualization comparisons
  • Scalability analysis (CIFAR-10 vs ImageNet-100)
  • Memory and computational cost analysis

📞 Contact

For questions about the implementation or research directions, feel free to open an issue or contact the authors directly.