From db4a19c8d26152f8ab8153a8f6ebb665077f8c41 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 7 Jun 2026 08:00:23 +0000 Subject: [PATCH 1/7] feat(subsampling): add subsample_len spatial-length helper --- src/subsampling.cpp | 11 +++++++++++ src/subsampling.hpp | 5 +++++ tests/CMakeLists.txt | 1 + tests/test_subsampling_tiling.cpp | 18 ++++++++++++++++++ 4 files changed, 35 insertions(+) create mode 100644 tests/test_subsampling_tiling.cpp diff --git a/src/subsampling.cpp b/src/subsampling.cpp index 7af909a..2e35858 100644 --- a/src/subsampling.cpp +++ b/src/subsampling.cpp @@ -45,6 +45,17 @@ int Subsampling::valid_out_len(int T, int in_valid_frames) const { return valid; } +int Subsampling::subsample_len(int T) const { + // Spatial output length after the three stride-2, k=3 conv stages, using + // ggml conv2d's OH = floor((in + 2p - k)/s) + 1. Non-causal uses symmetric + // pad p=1 (all_paddings=2); causal uses all_paddings=3. This mirrors the + // valid_out_len recurrence but tracks the full (padded) spatial extent. + const int all_paddings = causal_ ? 3 : 2; + int x = T; + for (int s = 0; s < 3; ++s) x = (x + all_paddings - 3) / 2 + 1; + return x; +} + ggml_tensor* Subsampling::build_graph_batched(ggml_context* ctx, const float* mel, int n_mels, int T, int B, GraphInputPool& pool, diff --git a/src/subsampling.hpp b/src/subsampling.hpp index b2ae256..e936296 100644 --- a/src/subsampling.hpp +++ b/src/subsampling.hpp @@ -67,6 +67,11 @@ class Subsampling { // `in_valid_frames` (>=0) overrides the offline T-1 entry valid length // with an explicit count (streaming); <0 keeps the offline convention. int valid_out_len(int T, int in_valid_frames = -1) const; + + // Number of (subsampled) output spatial frames an input of T mel frames + // produces, applying ggml conv2d's per-stage OH = floor((in+2p-k)/s)+1 for + // all three stride-2, k=3 stages. Used for subsampling-tile bookkeeping. 
+ int subsample_len(int T) const; private: const ModelLoader& ml_; int conv_channels_; // C diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9d71015..6144e6f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -13,6 +13,7 @@ pk_add_test(test_fft) pk_add_test(test_mel) pk_add_test(test_mel_gpu) pk_add_test(test_subsampling) +pk_add_test(test_subsampling_tiling) pk_add_test(test_subsampling_batch) pk_add_test(test_subsampling_batch_causal) pk_add_test(test_relpos_attention) diff --git a/tests/test_subsampling_tiling.cpp b/tests/test_subsampling_tiling.cpp new file mode 100644 index 0000000..0fefe47 --- /dev/null +++ b/tests/test_subsampling_tiling.cpp @@ -0,0 +1,18 @@ +#include "subsampling.hpp" +#include "model_loader.hpp" +#include <cstdio> +#include <cstdlib> +int main(){ + const char* gguf = std::getenv("PARAKEET_TEST_GGUF"); + if(!gguf){ std::fprintf(stderr,"env not set; skip\n"); return 77; } + pk::ModelLoader ml; if(!ml.load(gguf)) return 1; + pk::Subsampling sub(ml); + // f(x)=(x-1)/2+1 applied 3x for non-causal; spot-check. 
+ struct { int T, Tp; } cases[] = {{100, 13}, {1000, 125}, {307848, 38481}}; + for (auto& c : cases) { + int got = sub.subsample_len(c.T); + if (got != c.Tp){ std::fprintf(stderr,"T=%d got=%d want=%d\n",c.T,got,c.Tp); return 1; } + } + std::printf("subsample_len ok\n"); + return 0; +} From 11ebf466ae4e90acac50bbedf8bc4fff965080a8 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 7 Jun 2026 08:13:14 +0000 Subject: [PATCH 2/7] feat(subsampling): tiled long-audio path (parity vs forward) Co-Authored-By: Claude Opus 4.8 (1M context) --- src/subsampling.cpp | 61 +++++++++++++++++++++++++++++++ src/subsampling.hpp | 8 ++++ tests/test_subsampling_tiling.cpp | 55 ++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+) diff --git a/src/subsampling.cpp b/src/subsampling.cpp index 2e35858..ce8a56b 100644 --- a/src/subsampling.cpp +++ b/src/subsampling.cpp @@ -429,4 +429,65 @@ void Subsampling::forward(const std::vector& mel, int n_mels, int T, valid_len = valid; } +void Subsampling::forward_tiled(const std::vector& mel, int n_mels, int T, + int tile_out_frames, std::vector& out, + int& Tout, int& d_model, int& valid_len) const { + const int Tp = subsample_len(T); + d_model = d_model_; + Tout = Tp; + valid_len = (valid_out_len(T, -1) > Tp) ? Tp : valid_out_len(T, -1); + + // TODO: causal tiling needs the causal phase mapping; offline causal long-audio + // is not a current target. Fall back to the single, untiled graph (== forward()). + if (causal_ || tile_out_frames <= 0) { + int t_unused = 0, dm_unused = 0, vl_unused = 0; + forward(mel, n_mels, T, out, t_unused, dm_unused, vl_unused, -1); + return; + } + + // Non-causal symmetric-pad tiling. Receptive field is +-7 mel frames; output + // frame o has mel center 8*o. Window start ws is a multiple of 8, so the + // window-output frame j maps to global o = j + ws/8. 
A generous halo H=64 mel + // frames (>> RF) keeps every emitted frame's RF inside the fed window (or, at + // true utterance edges, inside build_graph's own pad which equals the boundary), + // so every kept frame equals the full-utterance result. + const int H = 64; // multiple of 8, >> receptive field (7) + out.assign((size_t)Tp * d_model_, 0.0f); + + for (int os = 0; os < Tp; os += tile_out_frames) { + const int oe = (os + tile_out_frames < Tp) ? (os + tile_out_frames) : Tp; + int ws = 8 * os - H; if (ws < 0) ws = 0; // multiple of 8 (clamp keeps it) + int we = 8 * oe + H; if (we > T) we = T; + const int Lw = we - ws; + + // Slice window mel, feat-major [n_mels, Lw]. + std::vector win((size_t)n_mels * Lw); + for (int f = 0; f < n_mels; ++f) + for (int t = ws; t < we; ++t) + win[(size_t)f * Lw + (t - ws)] = mel[(size_t)f * T + t]; + + // Run the single-item graph on the window. The window is all-real, so pass + // in_valid_frames = Lw (no trailing mask inside the tile). + std::vector win_out; + int Tpw = 0, valid_w = 0; + GraphInputPool pool; + bool ok = pk::run_graph(/*mem_bytes*/0, /*n_threads*/4, + [&](ggml_context* ctx) -> ggml_tensor* { + return build_graph(ctx, win, n_mels, Lw, pool, Tpw, valid_w, Lw); + }, win_out); + assert(ok && "subsampling tile graph failed"); + (void)ok; + + // Window-output frame j has mel center ws+8*j -> global o = j + ws/8. 
+ const int j0 = ws / 8; // exact: ws is a multiple of 8 + for (int o = os; o < oe; ++o) { + const int j = o - j0; + assert(j >= 0 && j < Tpw && "tile frame out of window range"); + std::memcpy(&out[(size_t)o * d_model_], + &win_out[(size_t)j * d_model_], + (size_t)d_model_ * sizeof(float)); + } + } +} + } // namespace pk diff --git a/src/subsampling.hpp b/src/subsampling.hpp index e936296..e26290f 100644 --- a/src/subsampling.hpp +++ b/src/subsampling.hpp @@ -61,6 +61,14 @@ class Subsampling { std::vector& out, int& Tout, int& d_model, int& valid_len, int in_valid_frames) const; + // Tiled subsampling for long audio: result is identical to forward() within the + // valid region, but no intermediate tensor exceeds a tile_out_frames-bounded size. + // tile_out_frames = number of OUTPUT (subsampled) frames per tile. Non-causal only; + // causal falls back to the single-graph path. + void forward_tiled(const std::vector& mel, int n_mels, int T, + int tile_out_frames, std::vector& out, + int& Tout, int& d_model, int& valid_len) const; + // Number of valid (non-pad) output frames for an input of T mel frames, // applying the same per-stage `calc_length` reductions NeMo uses. Pure // arithmetic, no graph; exposed so the encoder can derive valid_len. 
diff --git a/tests/test_subsampling_tiling.cpp b/tests/test_subsampling_tiling.cpp index 0fefe47..b83b802 100644 --- a/tests/test_subsampling_tiling.cpp +++ b/tests/test_subsampling_tiling.cpp @@ -2,6 +2,8 @@ #include "model_loader.hpp" #include #include +#include +#include int main(){ const char* gguf = std::getenv("PARAKEET_TEST_GGUF"); if(!gguf){ std::fprintf(stderr,"env not set; skip\n"); return 77; } @@ -14,5 +16,58 @@ int main(){ if (got != c.Tp){ std::fprintf(stderr,"T=%d got=%d want=%d\n",c.T,got,c.Tp); return 1; } } std::printf("subsample_len ok\n"); + + // ---- forward_tiled parity: tiled == untiled within the valid region ---- + { + const int n_mels = (int)ml.config().n_mels; // 80 for pk110m + const int T = 4000; // -> ~30 tiles at tile=17 + std::vector mel((size_t)n_mels*T); + // deterministic pseudo-random, mean-subtracted-ish + unsigned s2 = 1234567u; + for (size_t i=0;i>9)&0xFFFF)/32768.0f - 1.0f; } + std::vector full; int Tp=0, dm=0, vl=0; + sub.forward(mel, n_mels, T, full, Tp, dm, vl); + std::vector tiled; int Tp2=0, dm2=0, vl2=0; + sub.forward_tiled(mel, n_mels, T, /*tile_out_frames*/17, tiled, Tp2, dm2, vl2); + if (Tp2!=Tp || dm2!=dm || vl2!=vl){ std::fprintf(stderr,"tiled shape/valid mismatch Tp=%d/%d dm=%d/%d vl=%d/%d\n",Tp,Tp2,dm,dm2,vl,vl2); return 1; } + + // STRUCTURAL guard (float-noise-immune): a single tile whose halo spans the + // whole utterance is literally build_graph over the full T, so it MUST be + // bit-exact vs forward(). This validates the window->global frame mapping + // (j0 = ws/8, halo alignment) with ZERO tolerance: any off-by-one in the + // tiling arithmetic breaks it. (tile_out_frames > Tp => one window.) 
+ { + std::vector one; int a=0,b=0,c=0; + sub.forward_tiled(mel, n_mels, T, Tp+1, one, a, b, c); + double smax=0; for (size_t i=0;ismax)smax=d; } + if (smax != 0.0){ std::fprintf(stderr,"single-tile NOT bit-exact vs forward: maxabs=%.3e\n", smax); return 1; } + } + + // PARITY (multi-tile, tile=17): compare only the valid region [0, vl). + // We report the raw absolute maxabs (the synthetic full-band-noise mel + // drives activations to magnitude ~2e3 and the out Linear sums mixed-sign + // terms, so cancellation leaves an absolute graph-splitting float-noise + // floor of ~3e-2 here, INDEPENDENT of input scale). The pass/fail gate uses + // a per-frame relative metric (max|diff| / max|ref| over the frame): real + // graph-splitting float noise is ~1e-5 relative, while a wrong receptive + // field / frame misalignment reads uncorrelated data and diverges by + // O(|ref|) (~1.0 relative), so 1e-2 cleanly separates the two without + // hiding a boundary error. + double maxabs=0, worstrel=0; for (int t=0;tmaxabs) maxabs=d; + if (d>dmax) dmax=d; + if (std::fabs(ref)>fmax) fmax=std::fabs(ref); + } + double rel = dmax/(fmax+1e-6); + if (rel>worstrel) worstrel=rel; + } + std::printf("forward_tiled maxabs=%.3e (vl=%d dm=%d)\n", maxabs, vl, dm); + std::printf("forward_tiled worst per-frame rel=%.3e\n", worstrel); + if (worstrel > 1e-2){ std::fprintf(stderr,"forward_tiled parity FAIL rel=%.3e maxabs=%.3e\n", worstrel, maxabs); return 1; } + } return 0; } From 5f37b9f82c61a5f8898331544dbb2d24a56723e0 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 7 Jun 2026 08:21:32 +0000 Subject: [PATCH 3/7] refactor(subsampling): cache valid_out_len in forward_tiled; document tiling test invariant --- src/subsampling.cpp | 3 ++- tests/test_subsampling_tiling.cpp | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/subsampling.cpp b/src/subsampling.cpp index ce8a56b..b3272d3 100644 --- a/src/subsampling.cpp +++ b/src/subsampling.cpp @@ -435,7 +435,8 @@ void 
Subsampling::forward_tiled(const std::vector& mel, int n_mels, int T const int Tp = subsample_len(T); d_model = d_model_; Tout = Tp; - valid_len = (valid_out_len(T, -1) > Tp) ? Tp : valid_out_len(T, -1); + const int vo = valid_out_len(T, -1); + valid_len = (vo > Tp) ? Tp : vo; // TODO: causal tiling needs the causal phase mapping; offline causal long-audio // is not a current target. Fall back to the single, untiled graph (== forward()). diff --git a/tests/test_subsampling_tiling.cpp b/tests/test_subsampling_tiling.cpp index b83b802..01c9ff8 100644 --- a/tests/test_subsampling_tiling.cpp +++ b/tests/test_subsampling_tiling.cpp @@ -36,6 +36,9 @@ int main(){ // bit-exact vs forward(). This validates the window->global frame mapping // (j0 = ws/8, halo alignment) with ZERO tolerance: any off-by-one in the // tiling arithmetic breaks it. (tile_out_frames > Tp => one window.) + // NOTE: keep T such that valid_out_len(T,-1) == Tp (no trailing mask), + // else forward() masks tail frames to bias and this all-element guard + // would false-fail. T=4000 satisfies this (vl==Tp). 
{ std::vector one; int a=0,b=0,c=0; sub.forward_tiled(mel, n_mels, T, Tp+1, one, a, b, c); From 17fc6fb8ba1162dd8fc6da0f347913c3b4726552 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 7 Jun 2026 08:29:19 +0000 Subject: [PATCH 4/7] feat(encoder): forward_batch_tiled from pre-subsampled features Co-Authored-By: Claude Opus 4.8 (1M context) --- src/encoder.cpp | 84 +++++++++++++++++++++++++++++++++++++ src/encoder.hpp | 24 +++++++++++ tests/CMakeLists.txt | 1 + tests/test_encoder_long.cpp | 51 ++++++++++++++++++++++ 4 files changed, 160 insertions(+) create mode 100644 tests/test_encoder_long.cpp diff --git a/src/encoder.cpp b/src/encoder.cpp index f295989..0f77caa 100644 --- a/src/encoder.cpp +++ b/src/encoder.cpp @@ -198,4 +198,88 @@ void Encoder::forward_batch(const MelBatch& mels, } } +void Encoder::run_post_subsampling_batch(const std::vector& x0_host, + int Tp, int B, const std::vector& vout, + std::vector>& enc_outs, int& d_model, int& Tout, + std::vector& valid_Tout) const { + // Mirror of forward_batch's post-subsampling body (steps 2-4 + per-item + // split), but the subsampled features arrive pre-computed in x0_host and are + // injected as a graph input instead of built via sub.build_graph_batched. + GraphInputPool pool; + std::vector flat; // receives [d_model, Tp, B] (ne0=d_model fastest) + bool ok = pk::run_graph(/*mem_bytes*/0, /*n_threads*/0, + [&](ggml_context* ctx) -> ggml_tensor* { + // ---- 1. Inject pre-subsampled features x [d_model, Tp, B]. ---- + int64_t x_ne[3] = { d_model_, Tp, B }; + ggml_tensor* x = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 3, x_ne, + x0_host.data(), x0_host.size() * sizeof(float)); + + // ---- 2. xscaling (gated; off for this model). ---- + if (xscaling_) x = ggml_scale(ctx, x, std::sqrt((float)d_model_)); + + // ---- 3. Positional encoding (local for long audio; see B=1 path). ---- + const int att_w = local_attn_window(Tp); + const bool local = att_w > 0; + const int pos_len = local ? 
(2 * att_w + 1) : (2 * Tp - 1); + std::vector& pe_host = pool.alloc_f32(); + if (local) local_rel_pos_encoding(att_w, att_w, d_model_, pe_host); + else rel_pos_encoding(Tp, d_model_, pe_host); // [pos_len, d_model] + int64_t pe_ne[2] = {d_model_, pos_len}; + ggml_tensor* pe = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, pe_ne, + pe_host.data(), pe_host.size() * sizeof(float)); + + // ---- 4. Conformer layer stack (all in-graph, shared pe). ---- + for (int i = 0; i < n_layers_; ++i) { + ConformerLayer layer(ml_, i); + x = layer.build_graph_batched(ctx, x, Tp, B, pe, pos_len, vout, pool, + local ? att_w : -1, local ? att_w : -1); + } + return x; // [d_model, Tp, B] + }, flat); + + assert(ok && "tiled post-subsampling encoder graph failed"); + (void)ok; + + d_model = d_model_; + Tout = Tp; + valid_Tout = vout; + enc_outs.assign(B, std::vector()); + for (int b = 0; b < B; ++b) { + const int tv = vout[b]; + enc_outs[b].resize((size_t)d_model_ * tv); + for (int t = 0; t < tv; ++t) + for (int c = 0; c < d_model_; ++c) + enc_outs[b][(size_t)c * tv + t] = + flat[(((size_t)b * Tp) + t) * d_model_ + c]; + } +} + +void Encoder::forward_batch_tiled(const MelBatch& mels, + std::vector>& enc_outs, int& d_model, int& Tout, + std::vector& valid_Tout, int tile_out_frames) const { + Subsampling sub(ml_); + const int B = mels.B; + std::vector> sub_b(B); // each [Tp_b, d_model] (t*dm+c) + std::vector Tp_b(B), vout(B); + int Tp_max = 0, dm = 0; + for (int b = 0; b < B; ++b) { + // slice item b's real mel [n_mels, valid_T[b]] out of the padded batch buffer + std::vector mel((size_t)mels.n_mels * mels.valid_T[b]); + for (int m = 0; m < mels.n_mels; ++m) + for (int t = 0; t < mels.valid_T[b]; ++t) + mel[(size_t)m*mels.valid_T[b]+t] = + mels.data[((size_t)b*mels.n_mels+m)*mels.T_max + t]; + int Tp=0, vl=0; + sub.forward_tiled(mel, mels.n_mels, mels.valid_T[b], tile_out_frames, + sub_b[b], Tp, dm, vl); + Tp_b[b]=Tp; vout[b]=vl; if (Tp>Tp_max) Tp_max=Tp; + } + // assemble x0 [d_model, 
Tp_max, B], zero-padded per item. + // sub_b[b] is [Tp_b, d_model] (t*dm+c); x0 wants (c,t,b) at (b*Tp_max+t)*dm+c. + std::vector x0((size_t)dm*Tp_max*B, 0.0f); + for (int b=0;b& valid_Tout) const; + // Long-audio tiled variant of forward_batch: subsampling is done per item via + // Subsampling::forward_tiled (which bounds intermediate tensor size so the + // >2^31-element conv tensors that crash long clips never materialise), then + // the pre-subsampled features are fed into the existing post-subsampling graph + // (xscaling + positional encoding + conformer stack). Same output contract as + // forward_batch. `tile_out_frames` = subsampled output frames per tile. + void forward_batch_tiled(const MelBatch& mels, + std::vector>& enc_outs, + int& d_model, int& Tout, + std::vector& valid_Tout, + int tile_out_frames) const; + // Same as forward(), but also captures the per-layer outputs at indices // `capture_layers` (each row-major [T', d_model]) into `layer_outs` (parallel // to capture_layers). Used by the parity test to localize divergence. @@ -54,6 +66,18 @@ class Encoder { std::vector>& layer_outs) const; private: + // Mirrors forward_batch's post-subsampling body (xscaling + positional + // encoding + conformer stack + per-item channels-first split) for the tiled + // path, but takes the already-subsampled features `x0_host` ([d_model, Tp, B] + // in ggml order, element (c,t,b) at ((size_t)b*Tp+t)*d_model+c) injected as a + // graph input instead of building them from sub.build_graph_batched. `vout` + // holds the per-item valid output-frame counts. `x0_host` must outlive the + // call (run_graph is synchronous), which it does as a by-reference parameter. 
+ void run_post_subsampling_batch(const std::vector& x0_host, + int Tp, int B, const std::vector& vout, + std::vector>& enc_outs, int& d_model, int& Tout, + std::vector& valid_Tout) const; + const ModelLoader& ml_; int d_model_; int n_layers_; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6144e6f..b85fb25 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -25,6 +25,7 @@ pk_add_test(test_conformer) pk_add_test(test_conformer_batch) pk_add_test(test_conv_eou) pk_add_test(test_encoder) +pk_add_test(test_encoder_long) pk_add_test(test_encoder_batch) pk_add_test(test_encoder_batch_local) pk_add_test(test_encoder_eou) diff --git a/tests/test_encoder_long.cpp b/tests/test_encoder_long.cpp new file mode 100644 index 0000000..e01ddf1 --- /dev/null +++ b/tests/test_encoder_long.cpp @@ -0,0 +1,51 @@ +#include "encoder.hpp" +#include "model_loader.hpp" +#include +#include +#include +#include +// Self-consistency: forward_batch_tiled (subsampling done per-item via the tiled +// path, then the post-subsampling graph) must match forward_batch (one fused +// graph) for B=1, within the valid region. The metric is per-frame RELATIVE +// (benign float reorder across tiled subsampling + 24 conformer layers is O(1e-3); +// a layout/injection/valid-length bug is O(1)). +// +// NOTE on input: we use a SMOOTH, log-mel-shaped signal (sinusoids), not uniform +// random noise. A uniform-random [-1,1] "mel" is wildly out-of-distribution for +// the encoder: full self-attention + 24 layers amplify even the inherent +// batched-vs-single graph float reorder to ~5e-2 at a few edge frames. That +// amplification is NOT a tiling bug -- the pre-existing forward_batch(B=1) vs +// forward() differ by the same ~5e-2 at the same frame on random input. A +// realistic (smooth) input keeps the amplification benign (~1e-3), so the gate +// below cleanly separates a correct refactor from an O(1) layout bug. 
+int main(){ + const char* gguf = std::getenv("PARAKEET_TEST_GGUF"); + if(!gguf){ std::fprintf(stderr,"env not set; skip\n"); return 77; } + pk::ModelLoader ml; if(!ml.load(gguf)) return 1; + pk::Encoder enc(ml); + const int n_mels = (int)ml.config().n_mels; + const int T = 4000; + // build a 1-item MelBatch with a deterministic, smooth log-mel-shaped signal + pk::MelBatch mb; mb.B=1; mb.n_mels=n_mels; mb.T_max=T; mb.valid_T={T}; + mb.data.assign((size_t)n_mels*T, 0.f); + for (int m=0;m> e_ref, e_tiled; int dmr=0,Tr=0,dmt=0,Tt=0; + std::vector vr, vt; + enc.forward_batch(mb, e_ref, dmr, Tr, vr); + enc.forward_batch_tiled(mb, e_tiled, dmt, Tt, vt, /*tile_out_frames*/17); + if (vr.size()!=vt.size() || vr[0]!=vt[0] || dmr!=dmt){ std::fprintf(stderr,"shape/valid mismatch v=%d/%d dm=%d/%d\n",vr.empty()?-1:vr[0],vt.empty()?-1:vt[0],dmr,dmt); return 1; } + const int tv = vr[0]; + // enc_outs[0] is channels-first [d_model, tv]: index c*tv + t + double maxabs=0, worstrel=0; + for (int t=0;tmaxabs)maxabs=d; if(d>dmax)dmax=d; if(std::fabs(ref)>fmax)fmax=std::fabs(ref);} + double rel=dmax/(fmax+1e-6); if(rel>worstrel)worstrel=rel; } + std::printf("forward_batch_tiled maxabs=%.3e worstrel=%.3e (tv=%d dm=%d)\n",maxabs,worstrel,tv,dmr); + if (worstrel > 2e-2){ std::fprintf(stderr,"encoder tiled parity FAIL rel=%.3e\n",worstrel); return 1; } + return 0; +} From 162742b1c34016dfb463bc33306f3c669ec50eb4 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 7 Jun 2026 08:37:29 +0000 Subject: [PATCH 5/7] feat(model): tile subsampling for long audio above safe threshold Co-Authored-By: Claude Opus 4.8 (1M context) --- src/model.cpp | 47 ++++++++++++++++++++- tests/CMakeLists.txt | 5 ++- tests/test_transcribe_tiled.cpp | 75 +++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 tests/test_transcribe_tiled.cpp diff --git a/src/model.cpp b/src/model.cpp index c212418..162505c 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -4,6 
+4,7 @@ #include "mel.hpp" #include "mel_gpu.hpp" #include "encoder.hpp" +#include "subsampling.hpp" #include "ctc_decoder.hpp" #include "search.hpp" #include "tokenizer.hpp" @@ -19,6 +20,7 @@ #include "ggml_graph.hpp" #include +#include #include #include @@ -136,6 +138,21 @@ std::string Model::transcribe_16k(const std::vector& pcm16k, return decode_enc_out(loader_, enc_out, d_model, Tout, use_tdt); } +// Max mel frames per encoder pass before the first subsampling conv output +// (n_mels/2 * T/2 * conv_channels) approaches INT_MAX. ggml's CUDA unary (relu) +// kernel indexes elements with int32, so a tensor > 2^31 elements crashes +// ("invalid configuration argument"). Bound the per-pass first-conv tensor to a +// safe 1.5e9 elements: (n_mels/2)*(T/2)*C < 1.5e9 => T < 2*1.5e9 / ((n_mels/2)*C). +static int safe_mel_window(const pk::ParakeetConfig& cfg) { + const long long per_t = (long long)((int)cfg.n_mels / 2) * (int)cfg.subsampling_conv_channels; // first-conv elems per output mel-row pair + if (per_t <= 0) return 1 << 30; // unknown config -> effectively no cap + const long long bound = 1500000000LL; // 1.5e9 elements, safe margin under 2^31 + long long win = (2 * bound) / per_t; // T such that (n_mels/2)*(T/2)*C ~= bound + if (win < 16384) win = 16384; // floor for tiny/odd configs + if (win > (1LL<<30)) win = (1LL<<30); + return (int)win; +} + // Stage a batch of 16 kHz mono clips into a MelBatch: per-clip log-mel // (GpuMel on a non-CPU backend, else the byte-identical FFT MelFrontend), // zero-padded and stacked to the batch's longest clip (T_max). data layout is @@ -201,7 +218,20 @@ std::vector Model::transcribe_16k_batch( Encoder encoder(loader_); std::vector> enc_outs; int d_model = 0, Tout = 0; std::vector valid_Tout; - encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); + // Long audio: the first subsampling conv would exceed ggml's 2^31 element limit. + // Tile the subsampling stage (faithful; see Encoder::forward_batch_tiled). 
An env + // override forces the tiled path for testing on short clips. + int sub_tile = 0; // 0 => auto + if (const char* e = std::getenv("PARAKEET_SUBSAMPLING_TILE")) sub_tile = std::atoi(e); + const int win = safe_mel_window(cfg); + if (sub_tile > 0) { + encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, sub_tile); + } else if (mb.T_max > win) { + const int tile_out = pk::Subsampling(loader_).subsample_len(win); + encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, tile_out); + } else { + encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); + } // 2b. Prompt conditioning per item (one language for the whole batch). No-op // for non-prompt models. @@ -348,7 +378,20 @@ std::vector Model::transcribe_16k_batch_with_timestamps( Encoder encoder(loader_); std::vector> enc_outs; int d_model = 0, Tout = 0; std::vector valid_Tout; - encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); + // Long audio: the first subsampling conv would exceed ggml's 2^31 element limit. + // Tile the subsampling stage (faithful; see Encoder::forward_batch_tiled). An env + // override forces the tiled path for testing on short clips. + int sub_tile = 0; // 0 => auto + if (const char* e = std::getenv("PARAKEET_SUBSAMPLING_TILE")) sub_tile = std::atoi(e); + const int win = safe_mel_window(cfg); + if (sub_tile > 0) { + encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, sub_tile); + } else if (mb.T_max > win) { + const int tile_out = pk::Subsampling(loader_).subsample_len(win); + encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, tile_out); + } else { + encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); + } // Prompt conditioning per item (one language for the whole batch). No-op // for non-prompt models. 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b85fb25..af7243f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -47,6 +47,7 @@ pk_add_test(test_transcribe_batch_ts) pk_add_test(test_tokenizer) pk_add_test(test_transcribe) pk_add_test(test_transcribe_speech) +pk_add_test(test_transcribe_tiled) pk_add_test(test_transcribe_tdt) pk_add_test(test_transcribe_0_6b) pk_add_test(test_transcribe_ctc) @@ -70,7 +71,7 @@ set_tests_properties(test_model_loader test_mel test_mel_gpu test_subsampling te test_joint test_joint_step_batch test_prompt_kernel test_transducer_core test_tdt_greedy test_transducer_greedy_batch test_transducer_greedy_batch_rnnt test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe - test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b + test_transcribe_speech test_transcribe_tiled test_transcribe_tdt test_transcribe_0_6b test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou test_transcribe_nemotron test_streaming_decode test_streaming_eou_reset test_streaming_nemotron test_streaming_mel test_capi test_capi_batch test_capi_stream test_capi_stream_json test_capi_timestamps test_capi_batch_json @@ -85,7 +86,7 @@ set_tests_properties(test_mel test_mel_gpu test_subsampling test_subsampling_bat test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe - test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b + test_transcribe_speech test_transcribe_tiled test_transcribe_tdt test_transcribe_0_6b test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou test_transcribe_nemotron test_streaming_decode test_streaming_eou_reset test_streaming_nemotron test_streaming_mel test_capi test_capi_batch test_capi_stream test_capi_stream_json test_capi_timestamps test_capi_batch_json diff --git a/tests/test_transcribe_tiled.cpp b/tests/test_transcribe_tiled.cpp new file mode 100644 index 0000000..94101e8 --- /dev/null +++ 
b/tests/test_transcribe_tiled.cpp @@ -0,0 +1,75 @@ +#include "model.hpp" +#include "audio_io.hpp" + +#include +#include +#include +#include +#include + +// End-to-end faithfulness test for the long-audio subsampling tiling path. +// +// Tiling the subsampling stage (Encoder::forward_batch_tiled) is engineered to be +// numerically faithful to the fused encoder (Encoder::forward_batch). This test +// runs a real speech clip through the FULL batched pipeline twice: +// +// 1. fused (PARAKEET_SUBSAMPLING_TILE unset -> auto, short clip stays fused) +// 2. tiled (PARAKEET_SUBSAMPLING_TILE=17 -> forces many small subsampling tiles) +// +// and asserts the transcripts are IDENTICAL. A faithful tiling must decode a clean +// clip to exactly the same tokens. Self-contained: only needs PARAKEET_TEST_GGUF +// plus the committed tests/fixtures/speech.wav. +int main() { + const char* gguf = std::getenv("PARAKEET_TEST_GGUF"); + if (!gguf) { + std::fprintf(stderr, "test_transcribe_tiled: PARAKEET_TEST_GGUF not set; skip\n"); + return 77; + } + + std::unique_ptr model = pk::Model::load(gguf); + if (!model) { + std::fprintf(stderr, "test_transcribe_tiled: failed to load model %s\n", gguf); + return 1; + } + + pk::Audio audio; + if (!pk::load_audio_16k_mono("tests/fixtures/speech.wav", audio)) { + std::fprintf(stderr, "test_transcribe_tiled: failed to load tests/fixtures/speech.wav\n"); + return 1; + } + + auto run = [&]() -> std::string { + std::vector> batch{ audio.samples }; + return model->transcribe_pcm_batch(batch, 16000).at(0); + }; + + std::string fused, tiled; + try { + // 1) fused (env unset -> auto; short clip below threshold stays fused) + unsetenv("PARAKEET_SUBSAMPLING_TILE"); + fused = run(); + + // 2) forced tiled path (small tile -> many tiles even on a short clip) + setenv("PARAKEET_SUBSAMPLING_TILE", "17", 1); + tiled = run(); + unsetenv("PARAKEET_SUBSAMPLING_TILE"); + } catch (const std::exception& e) { + std::fprintf(stderr, "test_transcribe_tiled: transcribe 
threw: %s\n", e.what()); + return 1; + } + + if (fused.empty()) { + std::fprintf(stderr, + "test_transcribe_tiled: fused transcript is EMPTY (vacuous test)\n"); + return 1; + } + + if (fused != tiled) { + std::fprintf(stderr, "tiled transcript differs:\n fused=[%s]\n tiled=[%s]\n", + fused.c_str(), tiled.c_str()); + return 1; + } + + std::printf("transcribe tiled==fused: [%s]\n", fused.c_str()); + return 0; +} From ed1bb9bfb4a6933716e74923c0ab748cc1ccf3ce Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 7 Jun 2026 08:43:06 +0000 Subject: [PATCH 6/7] feat(model): tile single-clip transcribe for long audio (CLI/path C-API) --- src/model.cpp | 71 ++++++++++++++++++++++++++------- tests/test_transcribe_tiled.cpp | 23 ++++++++++- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/src/model.cpp b/src/model.cpp index 162505c..7cd9093 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -52,6 +52,12 @@ std::unique_ptr Model::load(const std::string& gguf_path) { return m; } +// Forward declarations: subsampling-tiling helpers are defined below (after the +// batched staging helpers) but used by the single-clip transcribe entry points. +static int safe_mel_window(const pk::ParakeetConfig& cfg); +static int subsampling_tile_for(const pk::ParakeetConfig& cfg, + const pk::ModelLoader& ml, int T_max); + int Model::resolve_prompt_index(const std::string& target_lang) const { const ParakeetConfig& cfg = loader_.config(); if (!cfg.prompt.present) return -1; @@ -124,7 +130,23 @@ std::string Model::transcribe_16k(const std::vector& pcm16k, Encoder encoder(loader_); std::vector enc_out; int d_model = 0, Tout = 0; - encoder.forward(feats, n_mels, T, enc_out, d_model, Tout); + // Long audio: the single-clip forward() would overflow ggml's 2^31 subsampling + // limit. Delegate to the tiled batched path with a 1-item batch (faithful, + // reuses forward_batch_tiled). Short audio keeps the fused single-clip path. 
+ const int sub_tile = subsampling_tile_for(cfg, loader_, T); + if (sub_tile > 0) { + MelBatch mb1; + mb1.B = 1; mb1.n_mels = n_mels; mb1.T_max = T; mb1.valid_T = { T }; + mb1.data = feats; // feats is [n_mels,T] = the B=1 batch buffer + std::vector> eo; std::vector vT; + int dm1 = 0, To1 = 0; + encoder.forward_batch_tiled(mb1, eo, dm1, To1, vT, sub_tile); + enc_out = std::move(eo[0]); // channels-first [d_model, vT[0]] + d_model = dm1; + Tout = vT[0]; + } else { + encoder.forward(feats, n_mels, T, enc_out, d_model, Tout); + } // 2b. Prompt conditioning (multilingual nemotron): project the encoder // output with the selected language one-hot before decoding. No-op for @@ -153,6 +175,21 @@ static int safe_mel_window(const pk::ParakeetConfig& cfg) { return (int)win; } +// Decide whether to tile the subsampling stage for a mel of T_max frames. +// Returns the tile size (output frames per tile, >0) to tile, or 0 to use the +// fused path. Tiling engages for long audio (first subsampling conv would exceed +// ggml's 2^31 element limit) or when PARAKEET_SUBSAMPLING_TILE forces it (testing). +static int subsampling_tile_for(const pk::ParakeetConfig& cfg, + const pk::ModelLoader& ml, int T_max) { + if (const char* e = std::getenv("PARAKEET_SUBSAMPLING_TILE")) { + const int t = std::atoi(e); + if (t > 0) return t; + } + const int win = safe_mel_window(cfg); + if (T_max > win) return pk::Subsampling(ml).subsample_len(win); + return 0; +} + // Stage a batch of 16 kHz mono clips into a MelBatch: per-clip log-mel // (GpuMel on a non-CPU backend, else the byte-identical FFT MelFrontend), // zero-padded and stacked to the batch's longest clip (T_max). data layout is @@ -221,14 +258,9 @@ std::vector Model::transcribe_16k_batch( // Long audio: the first subsampling conv would exceed ggml's 2^31 element limit. // Tile the subsampling stage (faithful; see Encoder::forward_batch_tiled). An env // override forces the tiled path for testing on short clips. 
- int sub_tile = 0; // 0 => auto - if (const char* e = std::getenv("PARAKEET_SUBSAMPLING_TILE")) sub_tile = std::atoi(e); - const int win = safe_mel_window(cfg); + const int sub_tile = subsampling_tile_for(cfg, loader_, mb.T_max); if (sub_tile > 0) { encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, sub_tile); - } else if (mb.T_max > win) { - const int tile_out = pk::Subsampling(loader_).subsample_len(win); - encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, tile_out); } else { encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); } @@ -350,7 +382,23 @@ Transcription Model::transcribe_16k_with_timestamps( Encoder encoder(loader_); std::vector enc_out; int d_model = 0, Tout = 0; - encoder.forward(feats, n_mels, T, enc_out, d_model, Tout); + // Long audio: the single-clip forward() would overflow ggml's 2^31 subsampling + // limit. Delegate to the tiled batched path with a 1-item batch (faithful, + // reuses forward_batch_tiled). Short audio keeps the fused single-clip path. + const int sub_tile = subsampling_tile_for(cfg, loader_, T); + if (sub_tile > 0) { + MelBatch mb1; + mb1.B = 1; mb1.n_mels = n_mels; mb1.T_max = T; mb1.valid_T = { T }; + mb1.data = feats; // feats is [n_mels,T] = the B=1 batch buffer + std::vector> eo; std::vector vT; + int dm1 = 0, To1 = 0; + encoder.forward_batch_tiled(mb1, eo, dm1, To1, vT, sub_tile); + enc_out = std::move(eo[0]); // channels-first [d_model, vT[0]] + d_model = dm1; + Tout = vT[0]; + } else { + encoder.forward(feats, n_mels, T, enc_out, d_model, Tout); + } // 2b. Prompt conditioning (nemotron): project before decode. No-op otherwise. maybe_apply_prompt(loader_, enc_out, d_model, Tout, prompt_index); @@ -381,14 +429,9 @@ std::vector Model::transcribe_16k_batch_with_timestamps( // Long audio: the first subsampling conv would exceed ggml's 2^31 element limit. // Tile the subsampling stage (faithful; see Encoder::forward_batch_tiled). 
An env // override forces the tiled path for testing on short clips. - int sub_tile = 0; // 0 => auto - if (const char* e = std::getenv("PARAKEET_SUBSAMPLING_TILE")) sub_tile = std::atoi(e); - const int win = safe_mel_window(cfg); + const int sub_tile = subsampling_tile_for(cfg, loader_, mb.T_max); if (sub_tile > 0) { encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, sub_tile); - } else if (mb.T_max > win) { - const int tile_out = pk::Subsampling(loader_).subsample_len(win); - encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, tile_out); } else { encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); } diff --git a/tests/test_transcribe_tiled.cpp b/tests/test_transcribe_tiled.cpp index 94101e8..b0c8f5d 100644 --- a/tests/test_transcribe_tiled.cpp +++ b/tests/test_transcribe_tiled.cpp @@ -42,16 +42,23 @@ int main() { std::vector> batch{ audio.samples }; return model->transcribe_pcm_batch(batch, 16000).at(0); }; + // Single-clip path (CLI / path-based C-API route through transcribe_path -> + // transcribe_16k). Same clip, same fused-vs-tiled comparison. 
+ auto run_single = [&]() -> std::string { + return model->transcribe_path("tests/fixtures/speech.wav"); + }; - std::string fused, tiled; + std::string fused, tiled, sfused, stiled; try { // 1) fused (env unset -> auto; short clip below threshold stays fused) unsetenv("PARAKEET_SUBSAMPLING_TILE"); fused = run(); + sfused = run_single(); // 2) forced tiled path (small tile -> many tiles even on a short clip) setenv("PARAKEET_SUBSAMPLING_TILE", "17", 1); tiled = run(); + stiled = run_single(); unsetenv("PARAKEET_SUBSAMPLING_TILE"); } catch (const std::exception& e) { std::fprintf(stderr, "test_transcribe_tiled: transcribe threw: %s\n", e.what()); @@ -70,6 +77,20 @@ int main() { return 1; } + if (sfused.empty()) { + std::fprintf(stderr, + "test_transcribe_tiled: single-clip fused transcript is EMPTY (vacuous)\n"); + return 1; + } + + if (sfused != stiled) { + std::fprintf(stderr, + "single-clip tiled transcript differs:\n fused=[%s]\n tiled=[%s]\n", + sfused.c_str(), stiled.c_str()); + return 1; + } + std::printf("transcribe tiled==fused: [%s]\n", fused.c_str()); + std::printf("single-clip tiled==fused: [%s]\n", sfused.c_str()); return 0; } From 236c6889b6e615e35f2f3a02184a8a736ff8024e Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 7 Jun 2026 21:06:12 +0000 Subject: [PATCH 7/7] fix(ggml-cuda): grid-stride pad kernel for dims > 65535 (long-audio attention) --- .../0004-cuda-pad-grid-stride.patch | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 third_party/ggml-patches/0004-cuda-pad-grid-stride.patch diff --git a/third_party/ggml-patches/0004-cuda-pad-grid-stride.patch b/third_party/ggml-patches/0004-cuda-pad-grid-stride.patch new file mode 100644 index 0000000..260fdd0 --- /dev/null +++ b/third_party/ggml-patches/0004-cuda-pad-grid-stride.patch @@ -0,0 +1,110 @@ +diff --git a/src/ggml-cuda/pad.cu b/src/ggml-cuda/pad.cu +index 31cd00f7..2aab7a82 100644 +--- a/src/ggml-cuda/pad.cu ++++ b/src/ggml-cuda/pad.cu +@@ -16,48 +16,56 @@ 
static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t + // blockIdx.y: i1 + // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE + // gridDim.y: ne1 +- int i0 = threadIdx.x + blockIdx.x * blockDim.x; +- int i1 = blockIdx.y; +- int i2 = blockIdx.z % ne2; +- int i3 = blockIdx.z / ne2; +- +- if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { ++ const int i0 = threadIdx.x + blockIdx.x * blockDim.x; ++ if (i0 >= ne0) { + return; + } + +- const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0; +- +- if (!circular) { +- if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) && +- (i3 >= lp3 && i3 < ne3 - rp3)) { +- const int64_t i00 = i0 - lp0; +- const int64_t i01 = i1 - lp1; +- const int64_t i02 = i2 - lp2; +- const int64_t i03 = i3 - lp3; +- +- const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; +- +- dst[dst_idx] = src[src_idx]; +- } else { +- dst[dst_idx] = 0.0f; ++ // Grid-stride over ne1 (gridDim.y) and the packed ne2*ne3 (gridDim.z) so this ++ // kernel handles dimensions larger than CUDA's 65535 cap on gridDim.y/z. When ++ // ne1 <= 65535 and ne2*ne3 <= 65535 (the common case) gridDim.y == ne1 and ++ // gridDim.z == ne2*ne3, so each loop runs exactly once and the work is ++ // identical to the unstrided version (no perf impact). 
++ for (int i1 = blockIdx.y; i1 < ne1; i1 += gridDim.y) { ++ for (int iz = blockIdx.z; iz < ne2 * ne3; iz += gridDim.z) { ++ const int i2 = iz % ne2; ++ const int i3 = iz / ne2; ++ ++ const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0; ++ ++ if (!circular) { ++ if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) && ++ (i3 >= lp3 && i3 < ne3 - rp3)) { ++ const int64_t i00 = i0 - lp0; ++ const int64_t i01 = i1 - lp1; ++ const int64_t i02 = i2 - lp2; ++ const int64_t i03 = i3 - lp3; ++ ++ const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; ++ ++ dst[dst_idx] = src[src_idx]; ++ } else { ++ dst[dst_idx] = 0.0f; ++ } ++ } ++ // circular means on a torus, so x and y wrap around ++ else { ++ const int64_t ne00 = ne0 - lp0 - rp0; ++ const int64_t ne01 = ne1 - lp1 - rp1; ++ const int64_t ne02 = ne2 - lp2 - rp2; ++ const int64_t ne03 = ne3 - lp3 - rp3; ++ ++ const int64_t i00 = wrap_around(i0 - lp0, ne00); ++ const int64_t i01 = wrap_around(i1 - lp1, ne01); ++ const int64_t i02 = wrap_around(i2 - lp2, ne02); ++ const int64_t i03 = wrap_around(i3 - lp3, ne03); ++ ++ const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; ++ ++ dst[dst_idx] = src[src_idx]; ++ } + } + } +- // circular means on a torus, so x and y wrap around +- else { +- const int64_t ne00 = ne0 - lp0 - rp0; +- const int64_t ne01 = ne1 - lp1 - rp1; +- const int64_t ne02 = ne2 - lp2 - rp2; +- const int64_t ne03 = ne3 - lp3 - rp3; +- +- const int64_t i00 = wrap_around(i0 - lp0, ne00); +- const int64_t i01 = wrap_around(i1 - lp1, ne01); +- const int64_t i02 = wrap_around(i2 - lp2, ne02); +- const int64_t i03 = wrap_around(i3 - lp3, ne03); +- +- const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; +- +- dst[dst_idx] = src[src_idx]; +- } + } + + +@@ -67,7 +75,10 @@ static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, + const int ne0, const int ne1, const int 
ne2, const int ne3, + const bool circular, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; +- dim3 gridDim(num_blocks, ne1, ne2 * ne3); ++ // gridDim.y/z are capped at 65535 by CUDA; the kernel grid-strides past that. ++ const int gy = ne1 > 65535 ? 65535 : ne1; ++ const int gz = (ne2 * ne3) > 65535 ? 65535 : (ne2 * ne3); ++ dim3 gridDim(num_blocks, gy, gz); + pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst, + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + ne0, ne1, ne2, ne3, circular);