# pallas-kernel

JAX/Flax Pallas kernels for Gated Linear Attention (GLA).
## Installation

Install directly from the repository:

```bash
pip install git+https://github.com/primatrix/pallas-kernel.git
```

Or build and install locally:

```bash
git clone https://github.com/primatrix/pallas-kernel.git
cd pallas-kernel
pip install .
```

For GPU support (CUDA 12):

```bash
pip install "tops[gpu] @ git+https://github.com/primatrix/pallas-kernel.git"
```

For TPU support:

```bash
pip install "tops[tpu] @ git+https://github.com/primatrix/pallas-kernel.git"
```

For development:

```bash
pip install -e ".[dev]"
```

## Building

Use the provided build script to create distributable packages:
```bash
./scripts/build.sh        # Build sdist and wheel into dist/
./scripts/build.sh clean  # Remove build artifacts
```

## Usage

```python
from tops.ops.gla import chunk_gla, fused_recurrent_gla, fused_chunk_gla
from tops.layers.gla import GatedLinearAttention
from tops.modules.layernorm import RMSNorm
```

## License

Apache License 2.0
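For reference, the recurrence that the GLA kernels above compute can be sketched in plain JAX. This is an illustrative sketch only, independent of this package: the function name `gla_reference`, the tensor shapes, and the per-dimension sigmoid decay gates are assumptions, not this library's API.

```python
import jax
import jax.numpy as jnp

def gla_reference(q, k, v, g):
    """Naive per-step gated linear attention recurrence (illustrative).

    q, k, g: (T, d_k); v: (T, d_v). g holds per-dimension decay gates
    in (0, 1). The running state S has shape (d_k, d_v).
    """
    d_k, d_v = q.shape[-1], v.shape[-1]

    def step(S, inputs):
        q_t, k_t, v_t, g_t = inputs
        # Decay the state per key dimension, then add the rank-1
        # update k_t v_t^T; read out with the current query.
        S = g_t[:, None] * S + jnp.outer(k_t, v_t)
        o_t = q_t @ S
        return S, o_t

    S0 = jnp.zeros((d_k, d_v))
    _, o = jax.lax.scan(step, S0, (q, k, v, g))
    return o

# Toy inputs (shapes are illustrative).
key = jax.random.PRNGKey(0)
T, d_k, d_v = 8, 4, 4
q = jax.random.normal(jax.random.fold_in(key, 0), (T, d_k))
k = jax.random.normal(jax.random.fold_in(key, 1), (T, d_k))
v = jax.random.normal(jax.random.fold_in(key, 2), (T, d_v))
g = jax.nn.sigmoid(jax.random.normal(jax.random.fold_in(key, 3), (T, d_k)))
out = gla_reference(q, k, v, g)
print(out.shape)  # (8, 4)
```

The fused and chunked kernels exist precisely to avoid this O(T) sequential scan: chunked formulations process blocks of the sequence with matrix multiplies while carrying the decayed state between chunks.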