Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,88 @@ void Encoder::forward_batch(const MelBatch& mels,
}
}

void Encoder::run_post_subsampling_batch(const std::vector<float>& x0_host,
int Tp, int B, const std::vector<int>& vout,
std::vector<std::vector<float>>& enc_outs, int& d_model, int& Tout,
std::vector<int>& valid_Tout) const {
// Mirror of forward_batch's post-subsampling body (steps 2-4 + per-item
// split), but the subsampled features arrive pre-computed in x0_host and are
// injected as a graph input instead of built via sub.build_graph_batched.
GraphInputPool pool;
std::vector<float> flat; // receives [d_model, Tp, B] (ne0=d_model fastest)
bool ok = pk::run_graph(/*mem_bytes*/0, /*n_threads*/0,
[&](ggml_context* ctx) -> ggml_tensor* {
// ---- 1. Inject pre-subsampled features x [d_model, Tp, B]. ----
int64_t x_ne[3] = { d_model_, Tp, B };
ggml_tensor* x = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 3, x_ne,
x0_host.data(), x0_host.size() * sizeof(float));

// ---- 2. xscaling (gated; off for this model). ----
if (xscaling_) x = ggml_scale(ctx, x, std::sqrt((float)d_model_));

// ---- 3. Positional encoding (local for long audio; see B=1 path). ----
const int att_w = local_attn_window(Tp);
const bool local = att_w > 0;
const int pos_len = local ? (2 * att_w + 1) : (2 * Tp - 1);
std::vector<float>& pe_host = pool.alloc_f32();
if (local) local_rel_pos_encoding(att_w, att_w, d_model_, pe_host);
else rel_pos_encoding(Tp, d_model_, pe_host); // [pos_len, d_model]
int64_t pe_ne[2] = {d_model_, pos_len};
ggml_tensor* pe = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, pe_ne,
pe_host.data(), pe_host.size() * sizeof(float));

// ---- 4. Conformer layer stack (all in-graph, shared pe). ----
for (int i = 0; i < n_layers_; ++i) {
ConformerLayer layer(ml_, i);
x = layer.build_graph_batched(ctx, x, Tp, B, pe, pos_len, vout, pool,
local ? att_w : -1, local ? att_w : -1);
}
return x; // [d_model, Tp, B]
}, flat);

assert(ok && "tiled post-subsampling encoder graph failed");
(void)ok;

d_model = d_model_;
Tout = Tp;
valid_Tout = vout;
enc_outs.assign(B, std::vector<float>());
for (int b = 0; b < B; ++b) {
const int tv = vout[b];
enc_outs[b].resize((size_t)d_model_ * tv);
for (int t = 0; t < tv; ++t)
for (int c = 0; c < d_model_; ++c)
enc_outs[b][(size_t)c * tv + t] =
flat[(((size_t)b * Tp) + t) * d_model_ + c];
}
}

void Encoder::forward_batch_tiled(const MelBatch& mels,
std::vector<std::vector<float>>& enc_outs, int& d_model, int& Tout,
std::vector<int>& valid_Tout, int tile_out_frames) const {
Subsampling sub(ml_);
const int B = mels.B;
std::vector<std::vector<float>> sub_b(B); // each [Tp_b, d_model] (t*dm+c)
std::vector<int> Tp_b(B), vout(B);
int Tp_max = 0, dm = 0;
for (int b = 0; b < B; ++b) {
// slice item b's real mel [n_mels, valid_T[b]] out of the padded batch buffer
std::vector<float> mel((size_t)mels.n_mels * mels.valid_T[b]);
for (int m = 0; m < mels.n_mels; ++m)
for (int t = 0; t < mels.valid_T[b]; ++t)
mel[(size_t)m*mels.valid_T[b]+t] =
mels.data[((size_t)b*mels.n_mels+m)*mels.T_max + t];
int Tp=0, vl=0;
sub.forward_tiled(mel, mels.n_mels, mels.valid_T[b], tile_out_frames,
sub_b[b], Tp, dm, vl);
Tp_b[b]=Tp; vout[b]=vl; if (Tp>Tp_max) Tp_max=Tp;
}
// assemble x0 [d_model, Tp_max, B], zero-padded per item.
// sub_b[b] is [Tp_b, d_model] (t*dm+c); x0 wants (c,t,b) at (b*Tp_max+t)*dm+c.
std::vector<float> x0((size_t)dm*Tp_max*B, 0.0f);
for (int b=0;b<B;++b) for (int t=0;t<Tp_b[b];++t) for (int c=0;c<dm;++c)
x0[((size_t)b*Tp_max+t)*dm+c] = sub_b[b][(size_t)t*dm+c];
run_post_subsampling_batch(x0, Tp_max, B, vout, enc_outs, d_model, Tout, valid_Tout);
}

} // namespace pk
24 changes: 24 additions & 0 deletions src/encoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ class Encoder {
int& d_model, int& Tout,
std::vector<int>& valid_Tout) const;

// Long-audio tiled variant of forward_batch: subsampling is done per item via
// Subsampling::forward_tiled (which bounds intermediate tensor size so the
// >2^31-element conv tensors that crash long clips never materialise), then
// the pre-subsampled features are fed into the existing post-subsampling graph
// (xscaling + positional encoding + conformer stack). Same output contract as
// forward_batch. `tile_out_frames` = subsampled output frames per tile.
void forward_batch_tiled(const MelBatch& mels,
std::vector<std::vector<float>>& enc_outs,
int& d_model, int& Tout,
std::vector<int>& valid_Tout,
int tile_out_frames) const;

// Same as forward(), but also captures the per-layer outputs at indices
// `capture_layers` (each row-major [T', d_model]) into `layer_outs` (parallel
// to capture_layers). Used by the parity test to localize divergence.
Expand All @@ -54,6 +66,18 @@ class Encoder {
std::vector<std::vector<float>>& layer_outs) const;

private:
// Mirrors forward_batch's post-subsampling body (xscaling + positional
// encoding + conformer stack + per-item channels-first split) for the tiled
// path, but takes the already-subsampled features `x0_host` ([d_model, Tp, B]
// in ggml order, element (c,t,b) at ((size_t)b*Tp+t)*d_model+c) injected as a
// graph input instead of building them from sub.build_graph_batched. `vout`
// holds the per-item valid output-frame counts. `x0_host` must outlive the
// call (run_graph is synchronous), which it does as a by-reference parameter.
void run_post_subsampling_batch(const std::vector<float>& x0_host,
int Tp, int B, const std::vector<int>& vout,
std::vector<std::vector<float>>& enc_outs, int& d_model, int& Tout,
std::vector<int>& valid_Tout) const;

const ModelLoader& ml_;
int d_model_;
int n_layers_;
Expand Down
94 changes: 90 additions & 4 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mel.hpp"
#include "mel_gpu.hpp"
#include "encoder.hpp"
#include "subsampling.hpp"
#include "ctc_decoder.hpp"
#include "search.hpp"
#include "tokenizer.hpp"
Expand All @@ -19,6 +20,7 @@
#include "ggml_graph.hpp"

#include <algorithm>
#include <cstdlib>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -50,6 +52,12 @@ std::unique_ptr<Model> Model::load(const std::string& gguf_path) {
return m;
}

// Forward declarations: subsampling-tiling helpers are defined below (after the
// batched staging helpers) but used by the single-clip transcribe entry points.
static int safe_mel_window(const pk::ParakeetConfig& cfg);
static int subsampling_tile_for(const pk::ParakeetConfig& cfg,
const pk::ModelLoader& ml, int T_max);

int Model::resolve_prompt_index(const std::string& target_lang) const {
const ParakeetConfig& cfg = loader_.config();
if (!cfg.prompt.present) return -1;
Expand Down Expand Up @@ -122,7 +130,23 @@ std::string Model::transcribe_16k(const std::vector<float>& pcm16k,
Encoder encoder(loader_);
std::vector<float> enc_out;
int d_model = 0, Tout = 0;
encoder.forward(feats, n_mels, T, enc_out, d_model, Tout);
// Long audio: the single-clip forward() would overflow ggml's 2^31 subsampling
// limit. Delegate to the tiled batched path with a 1-item batch (faithful,
// reuses forward_batch_tiled). Short audio keeps the fused single-clip path.
const int sub_tile = subsampling_tile_for(cfg, loader_, T);
if (sub_tile > 0) {
MelBatch mb1;
mb1.B = 1; mb1.n_mels = n_mels; mb1.T_max = T; mb1.valid_T = { T };
mb1.data = feats; // feats is [n_mels,T] = the B=1 batch buffer
std::vector<std::vector<float>> eo; std::vector<int> vT;
int dm1 = 0, To1 = 0;
encoder.forward_batch_tiled(mb1, eo, dm1, To1, vT, sub_tile);
enc_out = std::move(eo[0]); // channels-first [d_model, vT[0]]
d_model = dm1;
Tout = vT[0];
} else {
encoder.forward(feats, n_mels, T, enc_out, d_model, Tout);
}

// 2b. Prompt conditioning (multilingual nemotron): project the encoder
// output with the selected language one-hot before decoding. No-op for
Expand All @@ -136,6 +160,36 @@ std::string Model::transcribe_16k(const std::vector<float>& pcm16k,
return decode_enc_out(loader_, enc_out, d_model, Tout, use_tdt);
}

// Max mel frames per encoder pass before the first subsampling conv output
// (n_mels/2 * T/2 * conv_channels) approaches INT_MAX. ggml's CUDA unary (relu)
// kernel indexes elements with int32, so a tensor > 2^31 elements crashes
// ("invalid configuration argument"). Bound the per-pass first-conv tensor to a
// safe 1.5e9 elements: (n_mels/2)*(T/2)*C < 1.5e9 => T < 2*1.5e9 / ((n_mels/2)*C).
static int safe_mel_window(const pk::ParakeetConfig& cfg) {
const long long per_t = (long long)((int)cfg.n_mels / 2) * (int)cfg.subsampling_conv_channels; // first-conv elems per output mel-row pair
if (per_t <= 0) return 1 << 30; // unknown config -> effectively no cap
const long long bound = 1500000000LL; // 1.5e9 elements, safe margin under 2^31
long long win = (2 * bound) / per_t; // T such that (n_mels/2)*(T/2)*C ~= bound
if (win < 16384) win = 16384; // floor for tiny/odd configs
if (win > (1LL<<30)) win = (1LL<<30);
return (int)win;
}

// Decide whether to tile the subsampling stage for a mel of T_max frames.
// Returns the tile size (output frames per tile, >0) to tile, or 0 to use the
// fused path. Tiling engages for long audio (first subsampling conv would exceed
// ggml's 2^31 element limit) or when PARAKEET_SUBSAMPLING_TILE forces it (testing).
static int subsampling_tile_for(const pk::ParakeetConfig& cfg,
const pk::ModelLoader& ml, int T_max) {
if (const char* e = std::getenv("PARAKEET_SUBSAMPLING_TILE")) {
const int t = std::atoi(e);
if (t > 0) return t;
}
const int win = safe_mel_window(cfg);
if (T_max > win) return pk::Subsampling(ml).subsample_len(win);
return 0;
}

// Stage a batch of 16 kHz mono clips into a MelBatch: per-clip log-mel
// (GpuMel on a non-CPU backend, else the byte-identical FFT MelFrontend),
// zero-padded and stacked to the batch's longest clip (T_max). data layout is
Expand Down Expand Up @@ -201,7 +255,15 @@ std::vector<std::string> Model::transcribe_16k_batch(
Encoder encoder(loader_);
std::vector<std::vector<float>> enc_outs; int d_model = 0, Tout = 0;
std::vector<int> valid_Tout;
encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout);
// Long audio: the first subsampling conv would exceed ggml's 2^31 element limit.
// Tile the subsampling stage (faithful; see Encoder::forward_batch_tiled). An env
// override forces the tiled path for testing on short clips.
const int sub_tile = subsampling_tile_for(cfg, loader_, mb.T_max);
if (sub_tile > 0) {
encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, sub_tile);
} else {
encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout);
}

// 2b. Prompt conditioning per item (one language for the whole batch). No-op
// for non-prompt models.
Expand Down Expand Up @@ -320,7 +382,23 @@ Transcription Model::transcribe_16k_with_timestamps(
Encoder encoder(loader_);
std::vector<float> enc_out;
int d_model = 0, Tout = 0;
encoder.forward(feats, n_mels, T, enc_out, d_model, Tout);
// Long audio: the single-clip forward() would overflow ggml's 2^31 subsampling
// limit. Delegate to the tiled batched path with a 1-item batch (faithful,
// reuses forward_batch_tiled). Short audio keeps the fused single-clip path.
const int sub_tile = subsampling_tile_for(cfg, loader_, T);
if (sub_tile > 0) {
MelBatch mb1;
mb1.B = 1; mb1.n_mels = n_mels; mb1.T_max = T; mb1.valid_T = { T };
mb1.data = feats; // feats is [n_mels,T] = the B=1 batch buffer
std::vector<std::vector<float>> eo; std::vector<int> vT;
int dm1 = 0, To1 = 0;
encoder.forward_batch_tiled(mb1, eo, dm1, To1, vT, sub_tile);
enc_out = std::move(eo[0]); // channels-first [d_model, vT[0]]
d_model = dm1;
Tout = vT[0];
} else {
encoder.forward(feats, n_mels, T, enc_out, d_model, Tout);
}

// 2b. Prompt conditioning (nemotron): project before decode. No-op otherwise.
maybe_apply_prompt(loader_, enc_out, d_model, Tout, prompt_index);
Expand Down Expand Up @@ -348,7 +426,15 @@ std::vector<Transcription> Model::transcribe_16k_batch_with_timestamps(
Encoder encoder(loader_);
std::vector<std::vector<float>> enc_outs; int d_model = 0, Tout = 0;
std::vector<int> valid_Tout;
encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout);
// Long audio: the first subsampling conv would exceed ggml's 2^31 element limit.
// Tile the subsampling stage (faithful; see Encoder::forward_batch_tiled). An env
// override forces the tiled path for testing on short clips.
const int sub_tile = subsampling_tile_for(cfg, loader_, mb.T_max);
if (sub_tile > 0) {
encoder.forward_batch_tiled(mb, enc_outs, d_model, Tout, valid_Tout, sub_tile);
} else {
encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout);
}

// Prompt conditioning per item (one language for the whole batch). No-op
// for non-prompt models.
Expand Down
73 changes: 73 additions & 0 deletions src/subsampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ int Subsampling::valid_out_len(int T, int in_valid_frames) const {
return valid;
}

int Subsampling::subsample_len(int T) const {
// Spatial output length after the three stride-2, k=3 conv stages, using
// ggml conv2d's OH = floor((in + 2p - k)/s) + 1. Non-causal uses symmetric
// pad p=1 (all_paddings=2); causal uses all_paddings=3. This mirrors the
// valid_out_len recurrence but tracks the full (padded) spatial extent.
const int all_paddings = causal_ ? 3 : 2;
int x = T;
for (int s = 0; s < 3; ++s) x = (x + all_paddings - 3) / 2 + 1;
return x;
}

ggml_tensor* Subsampling::build_graph_batched(ggml_context* ctx,
const float* mel,
int n_mels, int T, int B, GraphInputPool& pool,
Expand Down Expand Up @@ -418,4 +429,66 @@ void Subsampling::forward(const std::vector<float>& mel, int n_mels, int T,
valid_len = valid;
}

void Subsampling::forward_tiled(const std::vector<float>& mel, int n_mels, int T,
int tile_out_frames, std::vector<float>& out,
int& Tout, int& d_model, int& valid_len) const {
const int Tp = subsample_len(T);
d_model = d_model_;
Tout = Tp;
const int vo = valid_out_len(T, -1);
valid_len = (vo > Tp) ? Tp : vo;

// TODO: causal tiling needs the causal phase mapping; offline causal long-audio
// is not a current target. Fall back to the single, untiled graph (== forward()).
if (causal_ || tile_out_frames <= 0) {
int t_unused = 0, dm_unused = 0, vl_unused = 0;
forward(mel, n_mels, T, out, t_unused, dm_unused, vl_unused, -1);
return;
}

// Non-causal symmetric-pad tiling. Receptive field is +-7 mel frames; output
// frame o has mel center 8*o. Window start ws is a multiple of 8, so the
// window-output frame j maps to global o = j + ws/8. A generous halo H=64 mel
// frames (>> RF) keeps every emitted frame's RF inside the fed window (or, at
// true utterance edges, inside build_graph's own pad which equals the boundary),
// so every kept frame equals the full-utterance result.
const int H = 64; // multiple of 8, >> receptive field (7)
out.assign((size_t)Tp * d_model_, 0.0f);

for (int os = 0; os < Tp; os += tile_out_frames) {
const int oe = (os + tile_out_frames < Tp) ? (os + tile_out_frames) : Tp;
int ws = 8 * os - H; if (ws < 0) ws = 0; // multiple of 8 (clamp keeps it)
int we = 8 * oe + H; if (we > T) we = T;
const int Lw = we - ws;

// Slice window mel, feat-major [n_mels, Lw].
std::vector<float> win((size_t)n_mels * Lw);
for (int f = 0; f < n_mels; ++f)
for (int t = ws; t < we; ++t)
win[(size_t)f * Lw + (t - ws)] = mel[(size_t)f * T + t];

// Run the single-item graph on the window. The window is all-real, so pass
// in_valid_frames = Lw (no trailing mask inside the tile).
std::vector<float> win_out;
int Tpw = 0, valid_w = 0;
GraphInputPool pool;
bool ok = pk::run_graph(/*mem_bytes*/0, /*n_threads*/4,
[&](ggml_context* ctx) -> ggml_tensor* {
return build_graph(ctx, win, n_mels, Lw, pool, Tpw, valid_w, Lw);
}, win_out);
assert(ok && "subsampling tile graph failed");
(void)ok;

// Window-output frame j has mel center ws+8*j -> global o = j + ws/8.
const int j0 = ws / 8; // exact: ws is a multiple of 8
for (int o = os; o < oe; ++o) {
const int j = o - j0;
assert(j >= 0 && j < Tpw && "tile frame out of window range");
std::memcpy(&out[(size_t)o * d_model_],
&win_out[(size_t)j * d_model_],
(size_t)d_model_ * sizeof(float));
}
}
}

} // namespace pk
13 changes: 13 additions & 0 deletions src/subsampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,25 @@ class Subsampling {
std::vector<float>& out, int& Tout, int& d_model,
int& valid_len, int in_valid_frames) const;

// Tiled subsampling for long audio: result is identical to forward() within the
// valid region, but no intermediate tensor exceeds a tile_out_frames-bounded size.
// tile_out_frames = number of OUTPUT (subsampled) frames per tile. Non-causal only;
// causal falls back to the single-graph path.
void forward_tiled(const std::vector<float>& mel, int n_mels, int T,
int tile_out_frames, std::vector<float>& out,
int& Tout, int& d_model, int& valid_len) const;

// Number of valid (non-pad) output frames for an input of T mel frames,
// applying the same per-stage `calc_length` reductions NeMo uses. Pure
// arithmetic, no graph; exposed so the encoder can derive valid_len.
// `in_valid_frames` (>=0) overrides the offline T-1 entry valid length
// with an explicit count (streaming); <0 keeps the offline convention.
int valid_out_len(int T, int in_valid_frames = -1) const;

// Number of (subsampled) output spatial frames an input of T mel frames
// produces, applying ggml conv2d's per-stage OH = floor((in+2p-k)/s)+1 for
// all three stride-2, k=3 stages. Used for subsampling-tile bookkeeping.
int subsample_len(int T) const;
private:
const ModelLoader& ml_;
int conv_channels_; // C
Expand Down
Loading
Loading