feat(Part2): Fused Add + RMSNorm Triton kernel with MASE transform pass#313

Closed
Dorijan9 wants to merge 1 commit into DeepWok:main from aahaidar01:feature/fused-rmsnorm-residual

Conversation

@Dorijan9

Fused Add + RMSNorm Triton Kernel with MASE Transform Pass

Part 2 of the ADLS kernel-fusion-aware optimisation pipeline.

Problem

Every transformer decoder layer runs residual addition and RMSNorm as two separate CUDA kernels, writing the intermediate tensor to HBM and reading it back. For a 32-layer Llama model, that's 64 wasted memory round-trips per forward pass.
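For reference, the computation being fused is the standard residual-add-then-RMSNorm tail of a decoder layer. A minimal pure-Python sketch of the math (Python floats stand in for CUDA tensors; function names are illustrative, not the PR's API):

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by weight
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

def unfused_layer_tail(x, residual, weight):
    # Kernel 1: residual add -- the intermediate `h` is written to HBM
    h = [a + b for a, b in zip(x, residual)]
    # Kernel 2: RMSNorm -- `h` is read back from HBM
    return rmsnorm(h, weight)
```

The fused kernel computes the same result while keeping `h` in registers/shared memory, eliminating the intermediate write and read.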

Solution

A hand-written Triton kernel fuses both operations into a single GPU kernel launch, paired with an FX graph transform pass that automatically pattern-matches add → RMSNorm and swaps in the fused module.
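The transform-pass idea can be sketched abstractly: walk the graph, find an `add` immediately consumed by an RMSNorm, and collapse the pair into one fused node. A toy sketch over a flat op list (the real pass operates on a `torch.fx` graph and must also check use counts and rewire arguments; names here are hypothetical):

```python
def fuse_add_rmsnorm(nodes):
    # Replace each adjacent ("add", "rmsnorm") pair with one fused op.
    fused = []
    i = 0
    while i < len(nodes):
        nxt = nodes[i + 1] if i + 1 < len(nodes) else None
        if nodes[i] == "add" and nxt == "rmsnorm":
            fused.append("fused_add_rmsnorm")  # one kernel launch
            i += 2
        else:
            fused.append(nodes[i])
            i += 1
    return fused
```

A real FX pass would additionally verify that the `add` output has no other consumers before fusing, since the fused module no longer materialises the intermediate.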

Files

| File | What |
| --- | --- |
| `src/chop/passes/graph/transforms/fused_rmsnorm/triton_fused_add_rmsnorm.py` | Triton fwd/bwd kernels, autograd function, `nn.Module` wrapper |
| `src/chop/passes/graph/transforms/fused_rmsnorm/fused_rmsnorm_transform.py` | FX graph transform pass |
| `src/chop/passes/graph/transforms/fused_rmsnorm/__init__.py` | Package exports |
| `test/passes/graph/transforms/test_fused_add_rmsnorm.py` | Correctness tests + benchmarks |

Tests

  • Forward correctness: 144/144 (8 shapes × 3 dtypes × 3 casting modes × 2 offsets)
  • Backward correctness: 36/36
  • nn.Module wrapper: ✅

Supports llama, gemma, and none casting modes. Recognises all major HuggingFace RMSNorm variants (Llama, Mistral, Gemma, Qwen2).
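Assuming the Liger-Kernel conventions the PR cites, the casting modes differ in where the fp32-to-input-dtype downcast happens and whether the weight carries a +1 offset (Gemma stores weights as residuals around 1). A hypothetical pure-Python sketch of the intended semantics (dtype casts shown only as comments, since Python floats are all double precision):

```python
import math

def fused_add_rmsnorm_ref(x, residual, weight, mode="llama", eps=1e-6):
    h = [a + b for a, b in zip(x, residual)]            # residual add
    rms = math.sqrt(sum(v * v for v in h) / len(h) + eps)
    n = [v / rms for v in h]                            # normalisation in fp32
    if mode == "gemma":
        # Gemma: +1 weight offset, multiply in fp32, downcast at the end
        return [v * (1.0 + w) for v, w in zip(n, weight)]
    if mode == "llama":
        # Llama: downcast normalised activations to input dtype
        # *before* the weight multiply
        return [v * w for v, w in zip(n, weight)]
    # mode == "none": compute everything in the input dtype
    return [v * w for v, w in zip(n, weight)]
```

The "2 offsets" in the test matrix presumably correspond to the with/without +1 weight-offset variants.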

- Triton forward/backward kernels fusing residual addition + RMSNorm
- torch.autograd.Function and nn.Module wrappers
- FX graph transform pass: pattern-matches add→RMSNorm, swaps in fused module
- 3 casting modes: llama, gemma, none (matches Liger-Kernel conventions)
- 144/144 forward + 36/36 backward correctness tests
- Benchmarks: up to 4.9x speedup (bf16 batch), 60% memory reduction
Dorijan9 closed this Mar 22, 2026