Skip to content

datthinh1801/CKA-pytorch

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

37 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Centered Kernel Alignment (CKA) - PyTorch Implementation

A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support for fast and efficient computation.

Warning

This project is for educational and academic purposes (and for fun 🤷🏻).

Features

  • GPU Accelerated: Leverages the power of GPUs for significantly faster CKA calculations compared to NumPy-based implementations.
  • On-the-Fly Calculation: Computes CKA on-the-fly using mini-batches, avoiding the need to cache large intermediate feature representations.
  • Easy to Use: Simple and intuitive API for calculating the CKA matrix between two models.
  • Flexible: Can be used with any PyTorch models and dataloaders.

Installation

pip install cka-pytorch

Usage

import torch

from torchvision.models import resnet18
from torch.utils.data import DataLoader

from cka_pytorch.cka import CKACalculator


# 1. Define your models and dataloader
model1 = resnet18(pretrained=True).cuda()
model2 = resnet18(pretrained=True).cuda() # Or a different model

# Create a dummy dataloader for demonstration
dummy_data = torch.randn(100, 3, 224, 224)
dummy_labels = torch.randint(0, 10, (100,))
dummy_dataset = torch.utils.data.TensorDataset(dummy_data, dummy_labels)
dataloader = DataLoader(dummy_dataset, batch_size=32)

# 2. Initialize the CKACalculator
# By default, we will calculate CKA across all layers of the two models
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    model1_name="ResNet18",
    model2_name="ResNet18",
    batched_feature_size=256,
    verbose=True,
)

# 3. Calculate the CKA matrix
cka_matrix = calculator.calculate_cka_matrix(dataloader)

# 4. Plot the CKA Matrix as heatmap
calculator.plot_cka_matrix(title="CKA between ResNet18 and ResNet18")

Contributing

  • If you find this repository helpful, please give it a ⭐.
  • If you encounter any bugs or have suggestions for improvements, feel free to open an issue.
  • This implementation has been primarily tested with ResNet architectures.

Acknowledgement

This project is based on:

About

A Centered Kernel Alignment implementation with distributed GPU support.

Topics

Resources

License

Stars

Watchers

Forks

Packages

 
 
 

Contributors