
feat: banded local (Longformer) attention - fix O(T^2) long-audio OOM (#9)

Merged: mudler merged 7 commits into master from feat/banded-local-attn on Jun 6, 2026

@localai-bot (Collaborator)

Problem

Offline transcription ran the FastConformer encoder with global relative-position attention over the whole clip — O(T²) memory. A ~17-min / 40 MB audio file drove ~100 GB of attention activations and hard-OOM'd a unified-memory GPU (DGX Spark), taking every loaded model down with it.

Fix

Port NeMo's rel_pos_local_attn (RelPositionMultiHeadAttentionLongformer) as memory-bounded banded attention and wire it into the offline encoder. Each query attends only to keys in [t-W, t+W]; peak memory is O(T·window) instead of O(T²).

  • RelPosAttention::build_graph_local / forward_local: banded attention via pad-and-shift. The positional term (q_v · p^T over the 2W+1 local positions) is added 1:1 to the banded content scores, exactly as NeMo combines them.
  • local_rel_pos_encoding: port of NeMo's LocalAttRelPositionalEncoding (positions +att_left..-att_right).
  • ConformerLayer::build_graph / Encoder::forward: route to banded when local mode is active. PARAKEET_ATT_CONTEXT=W forces rel_pos_local_attn [W,W]; otherwise audio > ~11 min (>8192 encoder frames) auto-switches to W=128. Short audio keeps full attention unchanged (encoder parity test untouched).
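The pad-and-shift idea in the bullets above can be illustrated with a small pure-Python sketch. It is not the C++ kernel from this PR: it covers only the content scores (the relative-position term is left out), uses a single head, and every function name below is hypothetical. It pads K and V by W frames on each side, so window column p of query t reads padded row t+p, and checks the result against a brute-force band.

```python
import math
import random

NEG_INF = float("-inf")

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def band_reference(q, k, v, w):
    """Brute force: query t attends to the existing keys in [t-w, t+w]."""
    T, d = len(q), len(q[0])
    out = []
    for t in range(T):
        keys = [j for j in range(t - w, t + w + 1) if 0 <= j < T]
        s = [dot(q[t], k[j]) / math.sqrt(d) for j in keys]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        z = sum(e)
        out.append([sum(e[i] * v[j][c] for i, j in enumerate(keys)) / z
                    for c in range(d)])
    return out

def band_pad_and_shift(q, k, v, w):
    """Pad K/V by w frames per side; column p of query t is padded row t+p."""
    T, d = len(q), len(q[0])
    P = 2 * w + 1
    zero = [0.0] * d
    k_pad = [zero] * w + k + [zero] * w
    v_pad = [zero] * w + v + [zero] * w
    # T*P scores in total; columns that fall outside the clip are masked.
    scores = [[dot(q[t], k_pad[t + p]) / math.sqrt(d)
               if 0 <= t - w + p < T else NEG_INF
               for p in range(P)] for t in range(T)]
    out = []
    for t in range(T):
        m = max(scores[t])  # the centre column is always valid
        e = [math.exp(x - m) for x in scores[t]]
        z = sum(e)
        out.append([sum(e[p] * v_pad[t + p][c] for p in range(P)) / z
                    for c in range(d)])
    return out, T * P

random.seed(0)
T, d, w = 9, 4, 2
q = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(T)]
k = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(T)]
v = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(T)]

ref = band_reference(q, k, v, w)
out, n_scores = band_pad_and_shift(q, k, v, w)
max_diff = max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb))
print(max_diff, n_scores)
```

The score buffer holds T·(2W+1) values rather than T², which is where the O(T·window) bound comes from.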

Validation

The kernel is verified against NeMo's own sliding_chunks_matmul_qk/pv (band column c maps to key t-w+c, matching to 1e-6; scores match to 1e-4) and against a deterministic band reference (to 1.4e-3). Memory test: graph allocation grows linearly with T (ratio 1.98 at 2× T).

End-to-end on a 16.6-min clip with tdt-0.6b-v3 (CPU, 16 threads):

| | full attention (O(T²)) | banded W=16 (O(T·w)) |
|---|---|---|
| peak RSS | 55.4 GB | 9.1 GB |
| time | 151 s | 41 s |

~6× less memory, ~3.7× faster. Short-clip transcripts: W=128 == full byte-for-byte; W=16 essentially identical.
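As a rough back-of-the-envelope check on the scaling, the sketch below counts attention-score elements per head and per layer for the full and banded layouts. The frame rate of 12.5 encoder frames per second is an assumption: it is consistent with the "~11 min (>8192 encoder frames)" figure above but is not stated in the PR. The helper names are hypothetical.

```python
FRAMES_PER_SECOND = 12.5  # assumption: one encoder frame per 80 ms

def full_score_elems(T):
    """Scores for global attention: every query sees every key."""
    return T * T

def banded_score_elems(T, W):
    """Scores for banded attention: every query sees 2W+1 keys."""
    return T * (2 * W + 1)

T = round(16.6 * 60 * FRAMES_PER_SECOND)  # frames in the 16.6-min clip
for W in (16, 128):
    ratio = full_score_elems(T) / banded_score_elems(T, W)
    print(f"T={T} W={W}: full/banded score elements = {ratio:.1f}")

# Doubling T quadruples the full layout but only doubles the banded one,
# in line with the measured allocation ratio of 1.98 at 2x T.
full_growth = full_score_elems(2 * T) / full_score_elems(T)
banded_growth = banded_score_elems(2 * T, 16) / banded_score_elems(T, 16)
print(full_growth, banded_growth)
```

The per-tensor ratio is much larger than the ~6× end-to-end RSS ratio, presumably because model weights and the remaining activations do not shrink with the window.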

Tests

  • test_relpos_attention_local — banded parity vs the deterministic reference.
  • test_relpos_attention_local_memory — O(T·window) memory scaling.
  • gen_nemo_baseline.py --att-context-size + gen_band_ref.py reproduce the fixtures.

Note: NeMo's Longformer is non-deterministic on short sequences (sliding_chunks_matmul_pv reads uninitialized memory at boundaries — two identical forward() calls differ by >1e3), so kernel parity uses the deterministic band reference; NeMo quality is anchored by the e2e above.

Follow-ups (not in this PR)

  • Efficient chunk-matmul construction (avoid O(window) graph nodes + O(window²) incremental concat) so W=128 is fast on CPU. Today the pad-and-shift is simple and correct but node-heavy.
  • Batched encoder path (build_graph_batched) still uses full attention.

mudler added 7 commits June 5, 2026 20:27
…aithful

Adds NeMo rel_pos_local_attn (RelPositionMultiHeadAttentionLongformer) as a
memory-bounded banded attention. This is the kernel for fixing the O(T^2)
attention blowup that OOM'd long-audio offline transcription on unified-memory
GPUs (a ~20-min clip allocated ~100GB and took the node down).

- RelPosAttention::build_graph_local / forward_local: banded attention via
  pad-and-shift, peak memory O(T*window) instead of O(T*T). Each query attends
  only to keys in [t-att_left, t+att_right]; the positional term (q_v . p^T over
  the 2W+1 local pos) is added 1:1 to the banded content scores, exactly as NeMo
  combines them. Verified against NeMo's own sliding_chunks_matmul_qk/pv
  (col->key t-w+c to 1e-6) and a deterministic band reference (1.4e-3).
- local_rel_pos_encoding: NeMo LocalAttRelPositionalEncoding (positions
  +att_left..-att_right), bit-identical to the centre rows of the full table.
- pk::last_graph_alloc_bytes(): gallocr high-water accessor for the memory test.
- gen_nemo_baseline.py --att-context-size (local-attention baseline); and
  gen_band_ref.py for the deterministic band reference. NOTE: NeMo's longformer
  is non-deterministic on short clips (sliding_chunks_matmul_pv reads
  uninitialized memory at boundaries via F.pad value=-1 + as_strided — two
  identical forward() calls differ by >1e3), so kernel parity must use the
  deterministic reference; end-to-end NeMo quality is anchored by long-audio WER.
- Tests: test_relpos_attention_local (parity 1.4e-3) and
  test_relpos_attention_local_memory (alloc grows ~linearly, ratio 1.98).

Not yet wired into the offline encoder path — follow-up.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
…) OOM)

Wires the banded rel_pos_local_attn kernel into the offline encoder so long
audio no longer allocates O(T^2) attention, which OOM'd unified-memory GPUs (a
~17-min clip drove ~100GB and took the node down).

- ConformerLayer::build_graph gains optional att_left/att_right; when set it
  routes self-attention to RelPosAttention::build_graph_local with a LOCAL
  positional encoding, else keeps full attention unchanged.
- Encoder::forward picks the window via local_attn_window(Tp): env
  PARAKEET_ATT_CONTEXT=W forces NeMo rel_pos_local_attn [W,W]; otherwise audio
  longer than ~11 min (>8192 encoder frames) auto-switches to W=128. Short audio
  keeps full attention (NeMo-exact; the encoder parity test is unchanged).
- backend.cpp: bump kGraphSize 16384->65536 — the pad-and-shift kernel adds
  O(window) graph-node descriptors per layer.
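
The window selection described above could look roughly like the following sketch. The real local_attn_window is C++ and is not shown in this PR text, so the signature, the -1 "use full attention" sentinel, and the handling of invalid values are all assumptions. Later commits additionally clamp the window (first to 32, then to 128), which this sketch leaves out.

```python
import os

AUTO_LOCAL_FRAMES = 8192  # threshold quoted in the commit message
AUTO_WINDOW = 128         # window used when the threshold is exceeded

def local_attn_window(num_frames, env=None):
    """Return W for banded attention, or -1 to keep full attention."""
    env = os.environ if env is None else env
    forced = env.get("PARAKEET_ATT_CONTEXT", "").strip()
    if forced:
        try:
            w = int(forced)
        except ValueError:
            w = -1  # assumption: ignore values that are not integers
        if w > 0:
            return w  # forced [W, W] context
    if num_frames > AUTO_LOCAL_FRAMES:
        return AUTO_WINDOW  # long audio: switch to banded attention
    return -1  # short audio: full attention, unchanged

print(local_attn_window(1000, env={}))
print(local_attn_window(9000, env={}))
print(local_attn_window(1000, env={"PARAKEET_ATT_CONTEXT": "16"}))
```

Passing `env` explicitly keeps the sketch testable without touching the process environment.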

Verified end-to-end on a 16.6-min clip with tdt-0.6b-v3 (CPU, 16 threads):
  full attention:  151 s, 55.4 GB peak RSS
  banded (W=16):    41 s,  9.1 GB peak RSS  (coherent transcript)
~6x less memory and ~3.7x faster; the full-attention path is what hit ~100GB and
OOM'd. Short-clip transcripts: W=128 == full byte-for-byte; W=16 essentially
identical.

Note: pad-and-shift creates O(window) nodes and an O(window^2) incremental
concat — fine for small windows but slow for W=128 on CPU; an efficient
chunk-matmul construction (like NeMo's sliding_chunks) is a follow-up.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
…-model regression)

Bumping kGraphSize 16384->65536 to fit a W=128 banded graph regressed small
models ~+22% (tdt_ctc-110m): the per-compute metadata context and graph hash-set
scale with kGraphSize. Revert to 16384 and instead cap the local-attention
window at W=32 — the pad-and-shift kernel adds ~6*(2W+1) graph nodes/layer, and
W<=32 fits every shipped model's encoder within the budget. PARAKEET_ATT_CONTEXT
is clamped to 32.
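
The node budget can be checked with simple arithmetic. The sketch below uses the ~6*(2W+1) nodes/layer estimate and the two kGraphSize values from the commit messages; the layer count of 24 is a hypothetical example value, and nodes from the rest of the encoder are ignored, so the real headroom is smaller than shown.

```python
K_GRAPH_SIZE = 16384        # budget after the revert
K_GRAPH_SIZE_BUMPED = 65536 # budget tried in the previous commit

def pad_and_shift_nodes_per_layer(w):
    """Approximate attention graph nodes per layer for window w."""
    return 6 * (2 * w + 1)

def fits(w, num_layers, budget=K_GRAPH_SIZE):
    """True if the banded attention nodes alone stay within the budget."""
    return pad_and_shift_nodes_per_layer(w) * num_layers <= budget

NUM_LAYERS = 24  # hypothetical layer count, for illustration only
for w in (16, 32, 128):
    total = pad_and_shift_nodes_per_layer(w) * NUM_LAYERS
    print(w, pad_and_shift_nodes_per_layer(w), total, fits(w, NUM_LAYERS))
```

Under these assumptions W=32 stays well inside 16384 nodes, while W=128 only fits the larger 65536 budget.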

Regression bench (librispeech, 100 files, CPU, back-to-back):
  tdt_ctc-110m: master 19.5s vs banded 19.4s (within noise), 0/100 text diffs
  tdt-0.6b-v3:  0/100 text diffs
Long-audio fix intact: 16.6-min clip + tdt-0.6b-v3 auto-uses W=32 -> 48s,
9.4 GB peak RSS (vs full attention 151s / 55.4 GB).

Lifting the window cap to NeMo's [128,128] needs the efficient chunk-matmul
construction (O(1) graph nodes) — follow-up.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
Mirror the B=1 banded path (build_graph_local) into the fused batched
encoder so long-audio batches also use NeMo rel_pos_local_attn
(O(T*window)) instead of full O(T^2) attention.

RelPosAttention::build_graph_batched_local builds the 4D ([dk,T,H,B])
pad-and-shift band: K/V padded on the time axis, per-window-column views,
sum_rows content scores + mul_mat positional scores (shared pos broadcast
over B), a per-item band mask [P,T,1,B] keyed on each item's valid_len,
soft_max over the window, then the context gather and head merge. Conformer
build_graph_batched and the batched encoder forward route to it when
att_left/att_right >= 0, with the shared LOCAL positional encoding.
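
The per-item band mask can be sketched as follows. This is an illustration, not the ggml graph: the mask is built as nested lists indexed [item][t][p] rather than a [P,T,1,B] tensor, and the function name is hypothetical. A column is open (0.0) when its key t-W+p lies inside that item's valid length and closed (-inf) otherwise. How the real kernel treats padded query rows whose whole window is closed is not stated in the commit message, so the sketch only reports them.

```python
NEG_INF = float("-inf")

def band_mask(T, w, valid_lens):
    """Additive mask: 0.0 where key t-w+p is a valid frame, else -inf."""
    P = 2 * w + 1
    return [[[0.0 if 0 <= t - w + p < n else NEG_INF for p in range(P)]
             for t in range(T)]
            for n in valid_lens]

T, w = 6, 1
mask = band_mask(T, w, valid_lens=[6, 4])  # item 1 is padded from 4 to 6

print(mask[0][0])  # first query of the full clip: left neighbour is missing
print(mask[1][3])  # last valid query of the padded clip
fully_masked = [t for t in range(T) if all(x == NEG_INF for x in mask[1][t])]
print(fully_masked)  # padded query rows with no valid key at all
```

Because each item gets its own mask, a shorter clip never reads keys from the padding that a longer neighbour forced onto the batch.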

Verified on dgx (tdt_ctc-110m): the new test_encoder_batch_local exercises
the path at the production window (W=32 = kMaxLocalWindow). item0 (the full
clip) is bit-exact beside its shorter padded neighbour (no cross-item leak),
and the padded item1 matches its standalone run within 5e-2/5e-2 - the same
tolerance the full-attention batch test uses. Tighter-than-production windows
only amplify float noise on near-zero activations of the padded clip (item0
stays exact, mean|d| ~1e-2); not pad leakage.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
The pad-and-shift banded path (build_graph_local) is correct but emits
O(window) graph nodes per layer (a P-iteration view+mul+sum_rows+concat
loop), which is why the window was capped at W=32. build_graph_local_chunked
computes the exact same NeMo rel_pos_local_attn output with a fixed, O(1)
number of nodes regardless of window, lifting the cap toward NeMo's full
[128,128].

Construction: time is tiled into chunks of C frames; each chunk carries its
own C+P-1 keys/values (the P-1 halo overlaps the neighbour), so a query
attends only within its chunk. K/V are gathered as OVERLAPPING strided chunk
views - which ggml's view-bounds check (ggml.c: data_size = dense product of
ne, ignoring nb) rejects unless the source is OVER-padded to (C+P-1)*G frames;
with that pad the view is legal and a single batched ggml_mul_mat produces the
per-chunk q.k blocks [C+P-1, C, G, H]. A diagonal "skew" view (nb1 walking C+P
on a [C+P-1,...] tensor, which passes the bounds check since P <= C+P-1)
extracts the [P,T] band. The PV side inverse-skews the softmaxed band back to a
[C+P-1, C] banded matrix (pad ne0 by C, skew-view, mask the lower off-band),
then one batched matmul against the transposed V chunks gathers the context.
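
The QK half of this construction can be illustrated in pure Python. The sketch below is not the ggml kernel: it handles a single head, covers only the content scores, and uses hypothetical function names. Each chunk of C queries is multiplied against its own C+P-1 padded keys, and the band is then read out of the flattened block with a stride of C+P, which mirrors the "skew" view described above. Plain Python lists need only G*C+P-1 padded key frames; the larger over-padding in the commit message exists to satisfy ggml's view-bounds check.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def band_scores_bruteforce(q, k, w):
    """band[t][p] = q[t] . k[t-w+p], or 0 where that key does not exist."""
    T, P = len(q), 2 * w + 1
    return [[dot(q[t], k[t - w + p]) if 0 <= t - w + p < T else 0
             for p in range(P)] for t in range(T)]

def band_scores_chunked(q, k, w, C):
    """Same band, computed one chunk of C queries at a time."""
    T, d = len(q), len(q[0])
    P = 2 * w + 1
    G = -(-T // C)  # number of chunks, rounded up
    zero = [0] * d
    q_pad = q + [zero] * (G * C - T)
    k_pad = [zero] * w + k + [zero] * (G * C - T + w)
    band = []
    for g in range(G):
        keys = k_pad[g * C: g * C + C + P - 1]  # chunk keys plus halo
        # Dense [C, C+P-1] block, flattened row-major.
        flat = [dot(q_pad[g * C + c], key) for c in range(C) for key in keys]
        # Skew: row c starts at c*(C+P), one element later per row.
        for c in range(C):
            band.append([flat[c * (C + P) + p] for p in range(P)])
    return band[:T]

random.seed(0)
T, d, w = 11, 3, 3
q = [[random.randint(-3, 3) for _ in range(d)] for _ in range(T)]
k = [[random.randint(-3, 3) for _ in range(d)] for _ in range(T)]

expected = band_scores_bruteforce(q, k, w)
ok = all(band_scores_chunked(q, k, w, C) == expected for C in (2, 3, 4))
print(ok)
```

Integer inputs make the comparison exact. The chunk sizes cover chunk < W, chunk == W, and chunk > W.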

Verified against the trusted pad-and-shift path (forward_local, itself 1.4e-3
vs a deterministic brute-force band reference): new test
test_relpos_attention_local_chunked runs synthetic x/pos through the real
layer-0 weights for T up to 333 and W up to 128 (chunk < W and chunk == W),
matching forward_local to <1e-3 (max|d| ~6e-4). Existing pad-and-shift path
and all encoder/conformer regressions unchanged. Encoder wiring (raise the cap
and route long audio to this kernel) follows.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
… to 128

Wire the O(1)-node chunk-matmul kernel into the encoder and raise the local
window cap from 32 to NeMo's full 128. Both conformer attention paths now use
it: build_graph (B=1) -> build_graph_local_chunked, build_graph_batched (B>1)
-> build_graph_batched_local_chunked. The batched wrapper runs the 4D chunk
kernel once per item and stacks the [D,T] outputs into [D,T,B] (the chunk graph
is already 4D, so it can't also carry a batch dim); that is O(B) nodes, still
O(1) in the window, and B is small.

local_attn_window's cap (kMaxLocalWindow) goes 32 -> 128: the pad-and-shift
path emitted ~6*(2W+1) nodes/layer (hence the 32 cap to fit kGraphSize), but the
chunk-matmul path is window-independent in node count, so long audio now runs at
NeMo's full [128,128] window. The pad-and-shift build_graph_local /
build_graph_batched_local are kept as the verification oracle for
test_relpos_attention_local{,_chunked}.

Verified on dgx: full ctest green (51/51). test_encoder_batch_local passes at
every forced window W=8..128 (now through the chunked path). e2e on a 16.6-min
clip (tdt-0.6b-v3, CPU/16t), auto-local W=128: 36.8s / 9.8GB peak RSS, coherent
transcript. That is faster than the earlier pad-and-shift runs (W=16: 41s /
9.1GB; W=32: 48s / 9.4GB) at a window 4x wider than W=32 that matches NeMo's,
and ~5.6x below the peak RSS of the full-attention path that OOM'd the node
(151s / 55GB).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
Document the banded local attention (rel_pos_local_attn) memory/speed win that
the chunk-matmul kernel enables. 16.6-min clip, tdt-0.6b-v3, GB10 CPU/16t:
global O(T^2) attention 148.3s / 54.0GB vs banded W=128 36.9s / 9.4GB (~4x
faster, ~5.7x less peak RAM) at NeMo's full window, with the chunk-matmul making
W=128 as cheap as W=32. Notes that short clips stay on the global path and are
unchanged.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
@mudler merged commit 8436005 into master on Jun 6, 2026
8 checks passed