tops

JAX/Flax Pallas kernels for Gated Linear Attention (GLA).

Installation

Install directly from the repository:

pip install git+https://github.com/primatrix/pallas-kernel.git

Or build and install locally:

git clone https://github.com/primatrix/pallas-kernel.git
cd pallas-kernel
pip install .

Optional dependencies

For GPU support (CUDA 12):

pip install "tops[gpu] @ git+https://github.com/primatrix/pallas-kernel.git"

For TPU support:

pip install "tops[tpu] @ git+https://github.com/primatrix/pallas-kernel.git"

For development:

pip install -e ".[dev]"

Building packages

Use the provided build script to create distributable packages:

./scripts/build.sh        # Build sdist and wheel into dist/
./scripts/build.sh clean  # Remove build artifacts

Usage

from tops.ops.gla import chunk_gla, fused_recurrent_gla, fused_chunk_gla
from tops.layers.gla import GatedLinearAttention
from tops.modules.layernorm import RMSNorm
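The kernels above implement gated linear attention, which maintains a running state updated as S_t = diag(α_t) S_{t-1} + k_t v_tᵀ with output o_t = q_tᵀ S_t. As a rough illustration of what the recurrent and chunk-based kernels compute (this is a plain NumPy sketch based on the standard GLA formulation, not this library's API), the O(T) recurrence can be checked against the equivalent materialized O(T²) form:

```python
import numpy as np

def gla_recurrent(q, k, v, alpha):
    """O(T) form: S_t = diag(alpha_t) @ S_{t-1} + outer(k_t, v_t); o_t = q_t @ S_t."""
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))          # running (key x value) state
    out = np.empty((T, d_v))
    for t in range(T):
        S = alpha[t][:, None] * S + np.outer(k[t], v[t])  # per-channel decay, then write
        out[t] = q[t] @ S
    return out

def gla_quadratic(q, k, v, alpha):
    """O(T^2) materialized form; decay(t, s) = prod_{i=s+1..t} alpha_i."""
    T, d_v = v.shape
    cum = np.cumsum(np.log(alpha), axis=0)  # cumulative log-gates for stable products
    out = np.zeros((T, d_v))
    for t in range(T):
        for s in range(t + 1):
            decay = np.exp(cum[t] - cum[s])        # elementwise gate product over (s, t]
            out[t] += ((q[t] * decay) @ k[s]) * v[s]
    return out

rng = np.random.default_rng(0)
T, d_k, d_v = 8, 4, 5
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
v = rng.standard_normal((T, d_v))
alpha = rng.uniform(0.5, 1.0, (T, d_k))  # per-channel forget gates in (0, 1)
assert np.allclose(gla_recurrent(q, k, v, alpha), gla_quadratic(q, k, v, alpha))
```

The chunked kernels (`chunk_gla`, `fused_chunk_gla`) exploit this same equivalence: within a chunk they use the matrix form, across chunks the recurrent state, trading the O(T²) cost for hardware-friendly matmuls.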

License

Apache License 2.0

About

A set of Pallas kernels for learning and tutorials.
