diff --git a/CMakeLists.txt b/CMakeLists.txt index a8afde8..e2ff4cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ set(PARAKEET_SRC src/ctc_decoder.cpp src/prediction.cpp src/joint.cpp + src/prompt_kernel.cpp src/tdt.cpp src/rnnt.cpp src/transducer_batch.cpp diff --git a/README.md b/README.md index aac587f..027695e 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ parakeet.cpp is a C++17 inference port of NVIDIA's [NeMo](https://github.com/NVIDIA-NeMo/NeMo) Parakeet speech-recognition models, built on [ggml](https://github.com/ggml-org/ggml). It gives you fast, dependency-light automatic speech recognition on CPU (and on GPU through ggml's backends), with no Python runtime needed at inference time. -It covers all the offline Parakeet families (CTC, RNNT, TDT, and hybrid TDT-CTC, in 0.6B/1.1B/110M sizes, English plus multilingual v3), each validated at WER 0 against NeMo on every published checkpoint. It also does **cache-aware streaming with end-of-utterance (EOU) detection** for `parakeet_realtime_eou_120m-v1`, where the streaming transcript matches NeMo's cache-aware streaming byte for byte. The full coverage matrix lives in `docs/parity.md`. +It covers all the offline Parakeet families (CTC, RNNT, TDT, and hybrid TDT-CTC, in 0.6B/1.1B/110M sizes, English plus multilingual v3), each validated at WER 0 against NeMo on every published checkpoint. It also does **cache-aware streaming with end-of-utterance (EOU) detection** for `parakeet_realtime_eou_120m-v1`, where the streaming transcript matches NeMo's cache-aware streaming byte for byte. And it supports the **multilingual, prompt-conditioned streaming model** `nvidia/nemotron-3.5-asr-streaming-0.6b` (40+ locales): pass a target language with `--lang ` (default `auto`) and both the offline and the cache-aware streaming transcripts match NeMo per language at WER 0. The full coverage matrix lives in `docs/parity.md`. It's faster than NeMo's PyTorch runtime on both CPU and GPU, with byte-identical transcripts. The full numbers, methodology, and all the plots are in [benchmarks/BENCHMARK.md](benchmarks/BENCHMARK.md). diff --git a/benchmarks/BENCHMARK.md b/benchmarks/BENCHMARK.md index ee0b36e..124ff0e 100644 --- a/benchmarks/BENCHMARK.md +++ b/benchmarks/BENCHMARK.md @@ -69,6 +69,24 @@ Versus whisper.cpp turbo, same accuracy (WER 1.6% on this clip) and far less com > **Speedup** = ours RTFx / NeMo RTFx (>1 = faster than NeMo). f32 reproduces NeMo's transcript (agreement ≈ 0). +## Nemotron (streaming, multilingual, prompt-conditioned) + +`nemotron-3.5-asr-streaming-0.6b` is a FastConformer transducer with a per-language prompt: a one-hot language vector drives a PromptKernel between the encoder and the RNN-T decoder. It runs both offline and cache-aware streaming. Because it loads from a local `.nemo` plus its GGUF, it sits outside the LibriSpeech pipeline above and is measured on its own here. + +One clip (`speech.wav`, 7.43 s), language prompt `en`, 8 threads, median of 7 passes after one warmup. ours is `parakeet-cli bench --decoder tdt --lang en` (load once, time transcribe only); NeMo runs the same prompt forward (preprocessor, encoder, PromptKernel, RNN-T greedy) on PyTorch CPU. RTFx = audio seconds per second of compute; higher is faster. + +Host: AMD Ryzen 9 9950X3D (20 cores), CPU-only. NeMo 2.8.0rc0. + +| Engine | RTFx | Speedup vs NeMo | Agreement WER vs NeMo | +|---|---|---|---| +| NeMo (PyTorch CPU) | 12.2 | 1.00× | reference | +| parakeet.cpp f32 | 29.4 | 2.40× | 0.0000% | +| parakeet.cpp q8_0 | 30.8 | 2.52× | 0.0000% | + +Accuracy is **WER 0 vs NeMo**: the f32 and q8_0 transcripts are byte-identical to NeMo's on the timed runs (agreement WER 0.0000%), so the speed numbers compare equal work. parakeet.cpp is **2.40× faster than NeMo at f32** and **2.52× at q8_0**. + +Streaming path (f32, cache-aware): compute RTFx **3.80** (median wall 2503 ms over the 7.43 s clip, one-time model load of 548 ms subtracted). Streaming is latency-oriented: it runs many small chunked forward passes rather than one offline pass, so its RTFx sits well below the offline number by design while staying several times real time. The streaming transcript matches the offline and NeMo transcripts. + ## Quantization — size / speed / accuracy tradeoff Averaged over all models (LibriSpeech). Size is the mean GGUF size as a fraction of the f32 GGUF. diff --git a/benchmarks/results/nemotron/bench.json b/benchmarks/results/nemotron/bench.json new file mode 100644 index 0000000..378974d --- /dev/null +++ b/benchmarks/results/nemotron/bench.json @@ -0,0 +1,40 @@ +{ + "clip": "speech.wav", + "audio_sec": 7.435, + "lang": "en", + "threads": 8, + "passes": 7, + "nemo": { + "rtfx": 12.228159482130682, + "median_proc_s": 0.6080228190403432, + "text": "Well, I don't wish to see it any more, observed Phoebe, turning away her eyes. It is certainly very like the old portrait. ", + "version": "2.8.0rc0", + "load_s": 25.837787624972407 + }, + "ours": { + "f32": { + "rtfx": 29.353937020308894, + "speedup": 2.400519641831998, + "median_proc_s": 0.253288, + "agreement_wer": 0.0, + "text": "Well, I don't wish to see it any more, observed Phoebe, turning away her eyes. It is certainly very like the old portrait. ", + "load_ms": 547.68 + }, + "q8_0": { + "rtfx": 30.819419343071743, + "speedup": 2.5203645232227254, + "median_proc_s": 0.241244, + "agreement_wer": 0.0, + "text": "Well, I don't wish to see it any more, observed Phoebe, turning away her eyes. It is certainly very like the old portrait. ", + "load_ms": 247.568 + } + }, + "stream": { + "dtype": "f32", + "compute_rtfx": 3.801518370630737, + "wall_rtfx": 2.96986895150122, + "median_wall_s": 2.5034774669911712, + "compute_s": 1.9557974669911713, + "load_s": 0.54768 + } +} \ No newline at end of file diff --git a/docs/conversion.md b/docs/conversion.md index f13a698..742185d 100644 --- a/docs/conversion.md +++ b/docs/conversion.md @@ -35,6 +35,16 @@ offline checkpoints omit them entirely (so they keep converting byte-identically and the C++ loader falls back to offline-safe defaults (`att_context [-1,-1]`, style `regular`, causal flags `false`, no streaming block). +The `parakeet.prompt.*` keys and `parakeet.encoder.use_bias` / +`parakeet.encoder.att_context_presets` are emitted **only for the +prompt-conditioned multilingual model** `nvidia/nemotron-3.5-asr-streaming-0.6b` +(`model_defaults.initialize_prompt_feature == true`). Every other checkpoint omits +them and the loader defaults `prompt.present=false` (the prompt stage is skipped) +and `use_bias=true`. nemotron stores `att_context_size` as a **list of presets** +(`[[56,3],[56,0],[56,6],[56,13]]`, the first is the default 320 ms preset) rather +than a single `[left,right]`, so the converter records all of them in +`att_context_presets` and uses the first pair for the scalar left/right keys. + | Key | GGUF type | Meaning | Source | 110m value | | --- | --- | --- | --- | --- | | `parakeet.arch` | STRING | One of `ctc` / `rnnt` / `tdt` / `hybrid_rnnt_ctc` / `hybrid_tdt_ctc` | arch detection (below) | `hybrid_tdt_ctc` | @@ -61,6 +71,13 @@ style `regular`, causal flags `false`, no streaming block). | `parakeet.streaming.valid_out_len` | INT32 | Valid encoder frames per step. **Streaming only.** | `encoder.streaming_cfg.valid_out_len` | (n/a) | | `parakeet.streaming.pre_encode_cache_size` | ARRAY\ | Pre-encode (mel) cache frames `[first, rest]`. **Streaming only.** | `encoder.streaming_cfg.pre_encode_cache_size` | (n/a) | | `parakeet.streaming.drop_extra_pre_encoded` | INT32 | Steps dropped after pre-encode. **Streaming only.** | `encoder.streaming_cfg.drop_extra_pre_encoded` | (n/a) | +| `parakeet.encoder.use_bias` | BOOL | Whether the encoder linear layers carry a bias. `false` for nemotron (`use_bias=false`); the loader reads biases optionally and tolerates their absence. Defaults `true`. | `cfg.encoder.use_bias` | `true` | +| `parakeet.encoder.att_context_presets` | ARRAY\ | Flattened `[l,r,l,r,...]` list of all `att_context_size` presets when the model stores a **list** of `[left,right]` pairs (multi-latency streaming, e.g. nemotron `[[56,3],[56,0],[56,6],[56,13]]`). The first pair is the default and is also written to `att_context_left`/`att_context_right`. **Streaming, multi-context only.** | `cfg.encoder.att_context_size` | (n/a) | +| `parakeet.prompt.present` | BOOL | Marks a prompt-conditioned multilingual model (nemotron). When `true` the C++ engine inserts the `prompt_kernel` (Linear, ReLU, Linear) on the encoder output, selected by a per-utterance language one-hot. Absent/`false` for every other model (which skip the stage entirely). | `model_defaults.initialize_prompt_feature` | (n/a) | +| `parakeet.prompt.num_prompts` | UINT32 | Width of the language one-hot appended to the encoder output (`prompt_kernel.0` input = `d_model + num_prompts`). **Prompt only.** | `model_defaults.num_prompts` | 128 | +| `parakeet.prompt.default_lang` | STRING | Locale used when no `--lang`/`target_lang` is given (nemotron: `auto`, prompt index 101). **Prompt only.** | derived (`auto` if present) | `auto` | +| `parakeet.prompt.dictionary.keys` | ARRAY\ | Locale strings (e.g. `en`, `en-US`, `de`, `es`, `ja-JP`, `auto`) parallel to `dictionary.values`. The loader resolves a `target_lang` to its prompt index by lookup. **Prompt only.** | `model_defaults.prompt_dictionary` keys | len 121 | +| `parakeet.prompt.dictionary.values` | ARRAY\ | Prompt index for each parallel key (multiple locales may share an index, e.g. `en` and `en-US` both map to 0). **Prompt only.** | `model_defaults.prompt_dictionary` values | (n/a) | | `parakeet.preprocessor.sample_rate` | UINT32 | Audio sample rate | `featurizer.sample_rate` | 16000 | | `parakeet.preprocessor.n_mels` | UINT32 | Mel filterbank count | `featurizer.nfilt` | 80 | | `parakeet.preprocessor.n_fft` | UINT32 | FFT size | `featurizer.n_fft` | 512 | @@ -122,6 +139,13 @@ State-dict prefixes present in the hybrid anchor (690 tensors total): | CTC head (hybrid aux CTC) | `ctc_decoder.decoder_layers.0.*` | `ctc_decoder.decoder_layers.0.weight` shape `(vocab+1, d_model, 1)` | | Prediction net (LSTM) | `decoder.prediction.*` | `decoder.prediction.embed.weight`, `decoder.prediction.dec_rnn.lstm.weight_ih_l0` | | Joint net | `joint.{enc,pred,joint_net}.*` | `joint.joint_net.2.weight` shape `(vocab+1+D, joint_hidden)` | +| Prompt kernel (nemotron only) | `prompt_kernel.{0,2}.*` | `prompt_kernel.0.weight` `(2048, d_model+num_prompts)`, `prompt_kernel.0.bias` `(2048,)`, `prompt_kernel.2.weight` `(d_model, 2048)`, `prompt_kernel.2.bias` `(d_model,)` | + +> The `prompt_kernel.*` projection weights are written verbatim by the generic +> tensor loop (no special handling); only the `parakeet.prompt.*` KV metadata is +> added by the converter. Like the LSTM prediction net and the featurizer buffers, +> the prompt kernel is **not** on the quantization allowlist, so it stays F32 in +> every quantized variant. > Pure-CTC checkpoints (`EncDecCTCModelBPE`) put the CTC head under `decoder.*` > instead of `ctc_decoder.*`; the verbatim rule preserves whatever the checkpoint diff --git a/docs/parity.md b/docs/parity.md index 8145ce8..94835d4 100644 --- a/docs/parity.md +++ b/docs/parity.md @@ -30,6 +30,7 @@ CPU, batch 1, deterministic greedy (NeMo 2.7.3). | `parakeet-rnnt-0.6b` | RNNT | `rnnt` | 1024 / 24 | 80 | **true** | 1024 | RNNT | **0.0** | PASS | | `parakeet-rnnt-1.1b` | RNNT | `rnnt` | 1024 / 42 | 80 | **true** | 1024 | RNNT | **0.0** | PASS | | `parakeet_realtime_eou_120m-v1` | Streaming + EOU | `rnnt` | 512 / 17 | 128 | false | 1026 | RNNT (offline, limited-context) | **0.0** | PASS (Phase 5 — 5a milestone) | +| `nemotron-3.5-asr-streaming-0.6b` | Streaming, multilingual, prompt-conditioned | `rnnt` | 1024 / 24 | 128 | false | 13087 | RNNT offline + cache-aware streaming, per language | **0.0** | PASS (offline + streaming, langs en/de/es/ja-JP/auto) | Notes: - `xscaling` = NeMo FastConformer `xscale=sqrt(d_model)` (true) vs `xscale=None` (false). @@ -53,6 +54,19 @@ Notes: clip NeMo's streaming does NOT emit `` (the final-chunk tail has incomplete right context); the C++ streaming session/C-API/CLI match that exactly and do not fabricate one. See "Phase 5 — Streaming + EOU" below. +- `nemotron-3.5-asr-streaming-0.6b` (multilingual, prompt-conditioned): a target + language one-hot (`--lang `, default `auto`) is projected through the + `prompt_kernel` (Linear, ReLU, Linear) on the encoder output before the RNNT + decode, both offline and per streaming chunk (the one-hot is constant over time, + so per-chunk application is exact). The authoritative NeMo reference decodes the + prompt-conditioned encoder output via `m.decoding.rnnt_decoder_predictions_tensor` + (the lhotse `transcribe(target_lang=...)` path needs per-cut language metadata our + bare wav fixtures lack). `scripts/e2e_nemo_compare.py` cross-checks the C++ CLI + against this reference for `tests/fixtures/{speech,clip}.wav` × `{en, de, es, + ja-JP, auto}` × `{offline, stream}`: **all 20 rows WER 0.0**. The `prompt_kernel`, + LSTM prediction net, and featurizer tensors stay F32 in every quantized variant + (f16 and q8_0 also verified WER 0.0; see `docs/quantization.md`). Note `ja` is not + a dictionary key — the Japanese locale is `ja-JP`. --- diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 421bcff..466f6fc 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -81,7 +81,7 @@ static int cmd_info(const char* path) { // When `timestamps` is set, also prints one line per finalized word // (`- ()`) after the running text/EOU line. static int cmd_transcribe_stream(const std::string& model, const std::string& input, - bool timestamps) { + bool timestamps, const std::string& lang) { pk::ModelLoader ml; if (!ml.load(model)) { std::fprintf(stderr, "parakeet-cli: failed to load model %s\n", model.c_str()); @@ -100,7 +100,12 @@ static int cmd_transcribe_stream(const std::string& model, const std::string& in } try { - pk::StreamingSession sess(ml); + // `lang` selects the language prompt for multilingual (nemotron) prompt + // models; empty -> the model default, and non-prompt models ignore it. + // This is exactly what parakeet_capi_stream_begin_lang forwards to the + // StreamingSession ctor — done directly here so the CLI keeps its rich + // per-word / EOU-timestamp output the flat stream C-API does not expose. + pk::StreamingSession sess(ml, lang); std::vector all_words; // collected for the --timestamps recap std::printf("[stream] "); std::fflush(stdout); @@ -149,7 +154,7 @@ static int cmd_transcribe_stream(const std::string& model, const std::string& in // archs, CTC for ctc arch — matching NeMo's cur_decoder default). --stream uses // the cache-aware streaming path (EOU streaming model only). static int cmd_transcribe(int argc, char** argv) { - std::string model, input, decoder_str; + std::string model, input, decoder_str, lang; bool stream = false; bool timestamps = false; bool json = false; @@ -161,6 +166,8 @@ static int cmd_transcribe(int argc, char** argv) { input = argv[++i]; } else if (std::strcmp(argv[i], "--decoder") == 0 && i + 1 < argc) { decoder_str = argv[++i]; + } else if (std::strcmp(argv[i], "--lang") == 0 && i + 1 < argc) { + lang = argv[++i]; } else if (std::strcmp(argv[i], "--stream") == 0) { stream = true; } else if (std::strcmp(argv[i], "--timestamps") == 0) { @@ -174,7 +181,8 @@ static int cmd_transcribe(int argc, char** argv) { if (model.empty() || input.empty()) { std::fprintf(stderr, "usage: parakeet-cli transcribe --model --input " - "[--decoder ctc|tdt] [--stream] [--timestamps] [--threads N] [--json]\n"); + "[--decoder ctc|tdt] [--lang ] [--stream] [--timestamps] " + "[--threads N] [--json]\n"); return 2; } // Apply the thread override (offline + streaming graph compute). When unset @@ -191,7 +199,7 @@ static int cmd_transcribe(int argc, char** argv) { "parakeet-cli: --json is not supported with --stream\n"); return 2; } - return cmd_transcribe_stream(model, input, timestamps); + return cmd_transcribe_stream(model, input, timestamps, lang); } // Resolve the decoder selector. @@ -239,7 +247,9 @@ static int cmd_transcribe(int argc, char** argv) { std::fprintf(stderr, "parakeet-cli: failed to load model %s\n", model.c_str()); return 1; } - pk::Transcription tr = m->transcribe_path_with_timestamps(input, dec); + // `lang` (empty -> model default) selects the language prompt for + // multilingual models; ignored by non-prompt models. + pk::Transcription tr = m->transcribe_path_with_timestamps(input, dec, lang); for (const pk::Word& w : tr.words) std::printf("%.2f-%.2f %s (%.2f)\n", w.start, w.end, w.text.c_str(), w.conf); @@ -250,6 +260,29 @@ static int cmd_transcribe(int argc, char** argv) { return 0; } + // Plain transcript. When --lang is given, go through the load-once C-API + // language variant so the language prompt is selected (and an unknown locale + // surfaces as a clean error). With no --lang keep the existing free-function + // path so behavior for every other model is byte-for-byte unchanged. + if (!lang.empty()) { + parakeet_ctx* ctx = parakeet_capi_load(model.c_str()); + if (!ctx) { + std::fprintf(stderr, "parakeet-cli: failed to load model %s\n", model.c_str()); + return 1; + } + char* t = parakeet_capi_transcribe_path_lang(ctx, input.c_str(), dec_int, + lang.c_str()); + if (!t) { + std::fprintf(stderr, "transcribe failed: %s\n", parakeet_capi_last_error(ctx)); + parakeet_capi_free(ctx); + return 1; + } + std::printf("%s\n", t); + parakeet_capi_free_string(t); + parakeet_capi_free(ctx); + return 0; + } + std::string text; try { text = pk::transcribe(model, input, dec); @@ -494,7 +527,7 @@ static std::vector read_manifest(const std::string& path, bool& ok) } static int cmd_bench(int argc, char** argv) { - std::string model, manifest, decoder_str, json_out; + std::string model, manifest, decoder_str, json_out, lang; int threads = 0; // 0 == unset -> use the components' built-in default for (int i = 0; i < argc; ++i) { if (std::strcmp(argv[i], "--model") == 0 && i + 1 < argc) { @@ -503,6 +536,8 @@ static int cmd_bench(int argc, char** argv) { manifest = argv[++i]; } else if (std::strcmp(argv[i], "--decoder") == 0 && i + 1 < argc) { decoder_str = argv[++i]; + } else if (std::strcmp(argv[i], "--lang") == 0 && i + 1 < argc) { + lang = argv[++i]; } else if (std::strcmp(argv[i], "--threads") == 0 && i + 1 < argc) { threads = std::atoi(argv[++i]); } else if (std::strcmp(argv[i], "--json") == 0 && i + 1 < argc) { @@ -512,7 +547,7 @@ static int cmd_bench(int argc, char** argv) { if (model.empty() || manifest.empty()) { std::fprintf(stderr, "usage: parakeet-cli bench --model --manifest " - "[--decoder ctc|tdt] [--threads N] [--json ]\n"); + "[--decoder ctc|tdt] [--lang ] [--threads N] [--json ]\n"); return 2; } @@ -580,7 +615,7 @@ static int cmd_bench(int argc, char** argv) { { pk::Audio warm; if (pk::load_audio_16k_mono(paths[0], warm)) { - (void)m->transcribe_pcm(warm.samples, 16000, dec); + (void)m->transcribe_pcm(warm.samples, 16000, dec, lang); } } @@ -598,7 +633,7 @@ static int cmd_bench(int argc, char** argv) { auto t_proc = clock::now(); std::string text; try { - text = m->transcribe_pcm(audio.samples, 16000, dec); + text = m->transcribe_pcm(audio.samples, 16000, dec, lang); } catch (const std::exception& e) { std::fprintf(stderr, "parakeet-cli bench: transcribe failed on %s: %s\n", p.c_str(), e.what()); @@ -1142,11 +1177,12 @@ int main(int argc, char** argv) { "usage:\n" " parakeet-cli info \n" " parakeet-cli transcribe --model --input " - "[--decoder ctc|tdt] [--stream] [--timestamps] [--threads N] [--json]\n" + "[--decoder ctc|tdt] [--lang ] [--stream] [--timestamps] " + "[--threads N] [--json]\n" " parakeet-cli quantize " "\n" " parakeet-cli bench --model --manifest " - "[--decoder ctc|tdt] [--threads N] [--json ]\n" + "[--decoder ctc|tdt] [--lang ] [--threads N] [--json ]\n" " parakeet-cli bench-batch --model --manifest " "[--decoder ctc|tdt] [--threads N] [--batch-sizes 1,4,8] [--json ]\n" " parakeet-cli bench-decode --model --audio " diff --git a/include/parakeet_capi.h b/include/parakeet_capi.h index 98cd4a5..def9d2d 100644 --- a/include/parakeet_capi.h +++ b/include/parakeet_capi.h @@ -17,6 +17,11 @@ typedef struct parakeet_ctx parakeet_ctx; // ABI version of this header/implementation. Bump on any breaking change to the // function signatures or semantics below. +// +// v3: added the target_lang variants (parakeet_capi_transcribe_path_lang, +// parakeet_capi_transcribe_pcm_lang, parakeet_capi_stream_begin_lang) for +// multilingual prompt-conditioned (nemotron) models. The original non-lang +// entry points are unchanged and delegate with the model default language. int parakeet_capi_abi_version(void); // Load a GGUF model. Returns an owning context, or NULL on failure. @@ -44,6 +49,21 @@ char* parakeet_capi_transcribe_path(parakeet_ctx* ctx, const char* wav_path, char* parakeet_capi_transcribe_pcm(parakeet_ctx* ctx, const float* samples, int n_samples, int sample_rate, int decoder); +// Like parakeet_capi_transcribe_path but selects the language prompt for +// multilingual (nemotron) models. `target_lang` is a locale string (e.g. "en", +// "de", "auto"); NULL or "" uses the model's default ("auto"). Ignored by +// non-prompt models. On an unknown locale (for a prompt model) returns NULL and +// sets the context's last error. parakeet_capi_transcribe_path delegates here +// with the model default. +char* parakeet_capi_transcribe_path_lang(parakeet_ctx* ctx, const char* wav_path, + int decoder, const char* target_lang); + +// Like parakeet_capi_transcribe_pcm but selects the language prompt (see +// parakeet_capi_transcribe_path_lang for `target_lang` semantics). +char* parakeet_capi_transcribe_pcm_lang(parakeet_ctx* ctx, const float* samples, + int n_samples, int sample_rate, int decoder, + const char* target_lang); + // Transcribe a batch of in-memory mono float PCM clips. `samples` is an array of // `n_clips` pointers and `n_samples` an array of `n_clips` per-clip lengths; each // clip is resampled to 16 kHz if `sample_rate != 16000`. `decoder` is as in @@ -107,6 +127,15 @@ typedef struct parakeet_stream parakeet_stream; // the model is not a cache-aware streaming model) and sets the ctx last error. parakeet_stream* parakeet_capi_stream_begin(parakeet_ctx* ctx); +// Begin a streaming session selecting the language prompt for multilingual +// (nemotron) prompt-conditioned models. `target_lang` is a locale string (e.g. +// "en", "de", "auto"); NULL or "" uses the model's default. Ignored by +// non-prompt models. Returns NULL on failure (not a streaming model, or an +// unknown locale) and sets the ctx last error. parakeet_capi_stream_begin +// delegates here with the model default. +parakeet_stream* parakeet_capi_stream_begin_lang(parakeet_ctx* ctx, + const char* target_lang); + // Feed a block of 16 kHz MONO float PCM (`pcm`, length `n_samples`). The session // buffers the audio and decodes as full encoder chunks become available. // Returns the newly-finalized text since the last call as a malloc'd UTF-8 diff --git a/scripts/bench_nemotron.py b/scripts/bench_nemotron.py new file mode 100644 index 0000000..ff20406 --- /dev/null +++ b/scripts/bench_nemotron.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +"""Benchmark the prompt-conditioned nemotron model: NeMo (PyTorch CPU) vs our +C++/ggml engine, on one clip at one target language. + +Mirrors scripts/benchmark.py methodology but for a LOCAL .nemo + GGUF pair that +the HF-id-keyed benchmark.py does not cover: + + * ours : ``parakeet-cli bench --decoder tdt --lang --threads T`` over a + manifest that repeats the clip ``--passes`` times. bench loads the + model once, warms up once (untimed), then times each transcribe with + ``transcribe_pcm`` only (load excluded). We take the MEDIAN per-pass + proc time, the same as the median latency benchmark.py reports. + * NeMo : load the .nemo once, run the SAME prompt branch as NeMo's forward() + (preprocessor -> encoder -> cat(one-hot prompt) -> prompt_kernel -> + rnnt greedy via ``decoding.rnnt_decoder_predictions_tensor``). Warm + up once (untimed), then time ``--passes`` forward+decode passes and + take the median. This is the authoritative offline path from + ``gen_nemo_baseline.compute_prompt_reference`` (the lhotse + transcribe(target_lang=...) dataloader needs per-cut language + metadata our bare wav fixtures lack). + +RTFx = audio_sec / median_proc_sec (higher = faster). Speedup = ours / NeMo. +Run under the NeMo venv python so ``import nemo`` works; it shells out to the +C++ CLI for the ours pass. +""" +from __future__ import annotations + +import argparse +import json +import statistics +import subprocess +import sys +import tempfile +import time +import warnings +from pathlib import Path + +warnings.filterwarnings("ignore") + +REPO = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(REPO / "scripts")) +from asr_metrics import wer # noqa: E402 +from gen_nemo_baseline import resolve_prompt_lang # noqa: E402 + + +def run_ours(cli: Path, gguf: Path, wav: Path, lang: str, threads: int, + passes: int) -> dict: + """parakeet-cli bench over a manifest that repeats ``wav`` ``passes`` times. + + Returns {median_proc_s, audio_sec, text, load_ms, all_proc_s}. + """ + with tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False) as mf: + for _ in range(passes): + mf.write(f"{wav}\n") + manifest = Path(mf.name) + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as jf: + out_json = Path(jf.name) + try: + cmd = [ + str(cli), "bench", + "--model", str(gguf), + "--manifest", str(manifest), + "--decoder", "tdt", + "--lang", lang, + "--threads", str(threads), + "--json", str(out_json), + ] + res = subprocess.run(cmd, capture_output=True, text=True) + if res.returncode != 0: + raise RuntimeError( + f"ours pass failed (rc={res.returncode})\ncmd: {' '.join(cmd)}\n" + f"stderr:\n{res.stderr[-2000:]}") + doc = json.loads(out_json.read_text()) + files = doc["files"] + proc_s = [f["proc_ms"] / 1000.0 for f in files] + return { + "median_proc_s": statistics.median(proc_s), + "audio_sec": files[0]["audio_sec"], + "text": files[0]["text"], + "load_ms": doc.get("load_ms"), + "all_proc_s": proc_s, + } + finally: + manifest.unlink(missing_ok=True) + out_json.unlink(missing_ok=True) + + +def run_ours_stream(cli: Path, gguf: Path, wav: Path, lang: str, threads: int, + load_s: float, passes: int) -> dict: + """Time the cache-aware STREAMING path. It has no built-in proc timer, so we + time the whole ``transcribe --stream`` invocation (median wall over a few + runs after one warmup) and subtract the measured one-time model load to get a + compute-only estimate. Streaming is latency-oriented (many small chunked + forward passes), so this RTFx is far below the offline number by design. + """ + def one_wall() -> float: + t = time.perf_counter() + res = subprocess.run( + [str(cli), "transcribe", "--model", str(gguf), "--input", str(wav), + "--lang", lang, "--stream", "--threads", str(threads)], + capture_output=True, text=True) + if res.returncode != 0: + raise RuntimeError(f"stream run failed:\n{res.stderr[-1500:]}") + return time.perf_counter() - t + + one_wall() # warmup + walls = [one_wall() for _ in range(passes)] + median_wall_s = statistics.median(walls) + compute_s = max(median_wall_s - load_s, 1e-9) + return { + "median_wall_s": median_wall_s, + "compute_s": compute_s, + "load_s": load_s, + } + + +def run_nemo(nemo_path: Path, wav: Path, lang: str, threads: int, + passes: int) -> dict: + """Load the .nemo once and time ``passes`` offline prompt forward+decode runs. + + Returns {median_proc_s, audio_sec, text, load_s, nemo_version, all_proc_s}. + """ + import numpy as np + import soundfile as sf + import torch + import nemo + from nemo.collections.asr.models import ASRModel + + torch.set_num_threads(threads) + nemo_version = getattr(nemo, "__version__", "unknown") + + t0 = time.perf_counter() + m = ASRModel.restore_from(str(nemo_path), map_location="cpu") + m.eval() + m.preprocessor.featurizer.dither = 0.0 + load_s = time.perf_counter() - t0 + + target_lang, pidx, num_prompts = resolve_prompt_lang(m, lang) + + wav_np, sr = sf.read(str(wav), dtype="float32", always_2d=False) + if wav_np.ndim > 1: + wav_np = wav_np.mean(axis=1) + if sr != 16000: + raise RuntimeError(f"expected 16k mono, got sr={sr}") + audio_sec = len(wav_np) / 16000.0 + wt = torch.from_numpy(np.ascontiguousarray(wav_np)).float().unsqueeze(0) + lt = torch.tensor([wt.shape[1]], dtype=torch.int64) + + def forward_decode() -> str: + with torch.no_grad(): + feats, flen = m.preprocessor(input_signal=wt, length=lt) + enc, elen = m.encoder(audio_signal=feats, length=flen) # [1, D, T] + encoded = enc.transpose(1, 2) # [1, T, D] + T = encoded.shape[1] + onehot = torch.zeros(1, T, num_prompts, dtype=encoded.dtype) + onehot[:, :, pidx] = 1.0 + concat = torch.cat([encoded, onehot], dim=-1) # [1, T, D+P] + pk_out = m.prompt_kernel(concat) # [1, T, D] + pk_enc = pk_out.transpose(1, 2).contiguous() # [1, D, T] + hyps = m.decoding.rnnt_decoder_predictions_tensor( + encoder_output=pk_enc, encoded_lengths=elen, + return_hypotheses=True) + first = hyps[0] if isinstance(hyps, list) else hyps + if isinstance(first, list): + first = first[0] + return first.text if hasattr(first, "text") else str(first) + + # Warm up once (untimed), then time `passes` forward+decode runs. + text = forward_decode() + proc_s = [] + for _ in range(passes): + t = time.perf_counter() + text = forward_decode() + proc_s.append(time.perf_counter() - t) + + return { + "median_proc_s": statistics.median(proc_s), + "audio_sec": audio_sec, + "text": text, + "load_s": load_s, + "nemo_version": nemo_version, + "target_lang": target_lang, + "all_proc_s": proc_s, + } + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--nemo", required=True, help="local .nemo checkpoint") + ap.add_argument( + "--gguf", action="append", required=True, metavar="DTYPE=PATH", + help="dtype-labelled GGUF, e.g. f32=/tmp/nemotron.gguf (repeatable)") + ap.add_argument("--wav", required=True) + ap.add_argument("--lang", default="en") + ap.add_argument("--threads", type=int, default=8) + ap.add_argument("--passes", type=int, default=5) + ap.add_argument("--cli", default=str( + REPO / "build" / "examples" / "cli" / "parakeet-cli")) + ap.add_argument("--stream-dtype", default="", + help="dtype key to also measure on the streaming path " + "(e.g. f32); empty -> skip streaming") + ap.add_argument("--out", default="") + args = ap.parse_args() + + cli = Path(args.cli) + wav = Path(args.wav) + ggufs: dict[str, Path] = {} + for spec in args.gguf: + dtype, _, path = spec.partition("=") + ggufs[dtype] = Path(path) + + print(f"=== nemotron bench: lang={args.lang} threads={args.threads} " + f"passes={args.passes} clip={wav.name} ===", flush=True) + + nemo = run_nemo(Path(args.nemo), wav, args.lang, args.threads, args.passes) + nemo_rtfx = nemo["audio_sec"] / nemo["median_proc_s"] + print(f"[nemo {nemo['nemo_version']}] lang={nemo['target_lang']} " + f"median_proc={nemo['median_proc_s']*1000:.1f}ms RTFx={nemo_rtfx:.2f}", + flush=True) + + ours: dict[str, dict] = {} + for dtype, gguf in ggufs.items(): + o = run_ours(cli, gguf, wav, args.lang, args.threads, args.passes) + rtfx = o["audio_sec"] / o["median_proc_s"] + agree = wer(nemo["text"], o["text"]) + ours[dtype] = {**o, "rtfx": rtfx, "speedup": rtfx / nemo_rtfx, + "agreement_wer": agree} + print(f"[ours {dtype}] median_proc={o['median_proc_s']*1000:.1f}ms " + f"RTFx={rtfx:.2f} speedup={rtfx/nemo_rtfx:.2f}x " + f"agreement_WER={agree:.4f}", flush=True) + + stream = None + if args.stream_dtype and args.stream_dtype in ggufs: + d = args.stream_dtype + s = run_ours_stream(cli, ggufs[d], wav, args.lang, args.threads, + ours[d]["load_ms"] / 1000.0, args.passes) + s_rtfx = nemo["audio_sec"] / s["compute_s"] + s_wall_rtfx = nemo["audio_sec"] / s["median_wall_s"] + stream = {"dtype": d, "compute_rtfx": s_rtfx, "wall_rtfx": s_wall_rtfx, + **s} + print(f"[ours stream {d}] median_wall={s['median_wall_s']*1000:.1f}ms " + f"compute_RTFx={s_rtfx:.2f} (wall_RTFx={s_wall_rtfx:.2f})", + flush=True) + + doc = { + "clip": wav.name, + "audio_sec": nemo["audio_sec"], + "lang": args.lang, + "threads": args.threads, + "passes": args.passes, + "nemo": {"rtfx": nemo_rtfx, "median_proc_s": nemo["median_proc_s"], + "text": nemo["text"], "version": nemo["nemo_version"], + "load_s": nemo["load_s"]}, + "ours": {d: {"rtfx": v["rtfx"], "speedup": v["speedup"], + "median_proc_s": v["median_proc_s"], + "agreement_wer": v["agreement_wer"], "text": v["text"], + "load_ms": v["load_ms"]} + for d, v in ours.items()}, + } + if stream is not None: + doc["stream"] = stream + if args.out: + Path(args.out).write_text(json.dumps(doc, indent=2)) + print(f"-> wrote {args.out}", flush=True) + else: + print(json.dumps(doc, indent=2)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/convert_parakeet_to_gguf.py b/scripts/convert_parakeet_to_gguf.py index 9e2462d..4c9ed45 100644 --- a/scripts/convert_parakeet_to_gguf.py +++ b/scripts/convert_parakeet_to_gguf.py @@ -53,7 +53,15 @@ def _get(cfg, key, default=None): def detect_arch(m): """Map a NeMo model to one of ctc/rnnt/tdt/hybrid_rnnt_ctc/hybrid_tdt_ctc.""" cfg = m.cfg - if _get(cfg, "aux_ctc") is not None: + # An aux_ctc *config* block is necessary but not sufficient for a hybrid + # model: prompt-conditioned RNNT checkpoints (nemotron) carry an unconfigured + # aux_ctc stub (num_classes=-1, empty vocabulary) but NO ctc decoder and zero + # ctc_decoder.* weights -- NeMo initializes them RNNT-only. Require an actual + # ctc_decoder on the model (the same module the engine loads ctc_decoder.* + # tensors from) before classifying as hybrid; otherwise fall through to the + # rnnt/tdt detection below. + has_ctc = getattr(m, "ctc_decoder", None) is not None + if _get(cfg, "aux_ctc") is not None and has_ctc: loss = _get(_get(cfg, "loss", {}) or {}, "loss_name", "") durs = _get(_get(cfg, "decoding", {}) or {}, "durations") return "hybrid_tdt_ctc" if (loss == "tdt" or durs) else "hybrid_rnnt_ctc" @@ -64,6 +72,25 @@ def detect_arch(m): return "ctc" +def prompt_config(cfg): + """Return (present, num_prompts, dict_keys, dict_vals, default_lang) for a + prompt-conditioned model, or (False, 0, [], [], "") otherwise. The prompt + feature lives under cfg.model_defaults (initialize_prompt_feature + + prompt_dictionary); the projection weights (prompt_kernel.*) are written + verbatim by the generic tensor loop, so only the KV metadata is new here.""" + md = _get(cfg, "model_defaults", {}) or {} + if not bool(_get(md, "initialize_prompt_feature", False)): + return False, 0, [], [], "" + pdict = _get(md, "prompt_dictionary", None) + if not pdict: + return False, 0, [], [], "" + num = int(_get(md, "num_prompts", 128)) + keys = [str(k) for k in pdict.keys()] + vals = [int(pdict[k]) for k in pdict.keys()] + default_lang = "auto" if "auto" in pdict else keys[0] + return True, num, keys, vals, default_lang + + # --------------------------------------------------------------------------- # Quantization policy. # @@ -170,6 +197,22 @@ def main(): w.add_uint32("parakeet.encoder.pos_emb_max_len", int(_get(enc, "pos_emb_max_len", 5000))) + # encoder bias flag (use_bias=False checkpoints omit the linear biases; the + # C++ loader reads them with clone_weight_opt and tolerates absence). + w.add_bool("parakeet.encoder.use_bias", bool(_get(enc, "use_bias", True))) + + # --- Prompt conditioning (multilingual nemotron) ------------------------ + # Orthogonal capability flag (like streaming.present). When present, the C++ + # engine inserts the prompt_kernel (Linear->ReLU->Linear) on the encoder + # output, selected by a one-hot language vector resolved from target_lang. + p_present, p_num, p_keys, p_vals, p_default = prompt_config(cfg) + if p_present: + w.add_bool("parakeet.prompt.present", True) + w.add_uint32("parakeet.prompt.num_prompts", p_num) + w.add_array("parakeet.prompt.dictionary.keys", p_keys) + w.add_array("parakeet.prompt.dictionary.values", p_vals) + w.add_string("parakeet.prompt.default_lang", p_default) + # --- Cache-aware streaming / causal config (Phase 5) --------------------- # These KVs describe the chunked-limited attention + causal conv that the # streaming FastConformer (e.g. parakeet_realtime_eou_120m-v1) uses. They are @@ -185,9 +228,22 @@ def main(): # int32 so the -1 sentinel survives if a streaming model ever uses it; # the loader reads them as int32 and defaults to -1 when absent. att_ctx = _get(enc, "att_context_size", [-1, -1]) or [-1, -1] - att_ctx = [int(x) for x in att_ctx] - att_left = att_ctx[0] if len(att_ctx) > 0 else -1 - att_right = att_ctx[1] if len(att_ctx) > 1 else -1 + # Multi-context models store a LIST of [left,right] presets; the default + # is the first (NeMo's default att_context_size index). A flat [l,r] + # (older streaming models like the eou) is used as-is. The first element + # being a non-scalar (list/tuple/OmegaConf ListConfig) marks the nested + # form -- detect it by "not a plain number" rather than an exact type so + # OmegaConf's ListConfig is handled too. + if att_ctx and not isinstance(att_ctx[0], (int, float)): + presets = [[int(x) for x in p] for p in att_ctx] + att_left, att_right = presets[0][0], presets[0][1] + # Record all presets so a future latency knob can pick another. + w.add_array("parakeet.encoder.att_context_presets", + [int(v) for p in presets for v in p]) # flattened [l,r,l,r,...] + else: + att_ctx = [int(x) for x in att_ctx] + att_left = att_ctx[0] if len(att_ctx) > 0 else -1 + att_right = att_ctx[1] if len(att_ctx) > 1 else -1 w.add_int32("parakeet.encoder.att_context_left", int(att_left)) w.add_int32("parakeet.encoder.att_context_right", int(att_right)) w.add_string("parakeet.encoder.att_context_style", att_style) diff --git a/scripts/e2e_nemo_compare.py b/scripts/e2e_nemo_compare.py new file mode 100644 index 0000000..9bc47c5 --- /dev/null +++ b/scripts/e2e_nemo_compare.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +"""End-to-end: compare parakeet.cpp nemotron transcripts to NeMo's, per language, +offline and cache-aware streaming. + +For each (clip, target_lang, mode): + + * NeMo reference -- computed by ``gen_nemo_baseline.compute_prompt_reference`` + (the SAME prompt-conditioned path every Phase 2/3 baseline uses: encoder -> + transpose -> cat(onehot) -> prompt_kernel -> transpose, then decode via + ``m.decoding.rnnt_decoder_predictions_tensor``; for streaming it drives + ``cache_aware_stream_step`` then applies the prompt and decodes). We do NOT + use the lhotse ``transcribe(target_lang=...)`` path because it needs per-cut + language metadata our bare wav fixtures lack. + * Ours -- the built ``parakeet-cli`` on the converted gguf with + ``--lang `` (offline) and ``--lang --stream``. + * WER -- ``asr_metrics.wer(nemo_text, ours_text)``. + +Prints a per-(clip,lang,mode) table and exits nonzero if any WER > 0. + +Usage: + .venv-nemotron/bin/python scripts/e2e_nemo_compare.py \\ + --nemo models/nemotron/model.nemo --gguf /tmp/nemotron.gguf \\ + --cli ./build/examples/cli/parakeet-cli \\ + --clips tests/fixtures/speech.wav,tests/fixtures/clip.wav \\ + --langs en,de,es,ja-JP,auto --mode both +""" +from __future__ import annotations + +import argparse +import pathlib +import subprocess +import sys + +SCRIPTS_DIR = pathlib.Path(__file__).resolve().parent +sys.path.insert(0, str(SCRIPTS_DIR)) + +from asr_metrics import wer # noqa: E402 +from gen_nemo_baseline import compute_prompt_reference # noqa: E402 + + +def load_nemo(nemo_path: str): + """Restore the NeMo model once (eval, dither off for determinism).""" + from nemo.collections.asr.models import ASRModel + + if pathlib.Path(nemo_path).exists(): + m = ASRModel.restore_from(nemo_path, map_location="cpu") + else: + m = ASRModel.from_pretrained(nemo_path, map_location="cpu") + m.eval() + m.preprocessor.featurizer.dither = 0.0 + return m + + +def _run_cli(cli: str, gguf: str, wav: str, lang: str, stream: bool) -> str: + cmd = [cli, "transcribe", "--model", gguf, "--input", wav] + if lang: + cmd += ["--lang", lang] + if stream: + cmd += ["--stream"] + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0: + raise RuntimeError( + f"parakeet-cli failed (exit {proc.returncode}) for " + f"lang={lang} stream={stream}:\n{proc.stderr.strip()}" + ) + out = proc.stdout + if stream: + # Take the authoritative final line: "[stream:final] ". + final = "" + for line in out.splitlines(): + line = line.strip() + if line.startswith("[stream:final]"): + final = line[len("[stream:final]"):].strip() + return final + # Offline prints exactly the transcript; take the last non-empty line. + lines = [ln.strip() for ln in out.splitlines() if ln.strip()] + return lines[-1] if lines else "" + + +def main() -> int: + ap = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + ap.add_argument("--nemo", required=True, help="NeMo .nemo checkpoint (or HF id)") + ap.add_argument("--gguf", required=True, help="converted nemotron gguf") + ap.add_argument("--cli", required=True, help="path to parakeet-cli") + ap.add_argument("--clips", required=True, help="comma-separated wav paths") + ap.add_argument("--langs", default="en,de,es,ja-JP,auto", + help="comma-separated target locales") + ap.add_argument("--mode", default="both", choices=["offline", "stream", "both"]) + args = ap.parse_args() + + clips = [c.strip() for c in args.clips.split(",") if c.strip()] + langs = [l.strip() for l in args.langs.split(",") if l.strip()] + modes = ["offline", "stream"] if args.mode == "both" else [args.mode] + + print(f"Loading NeMo model from {args.nemo} ...", flush=True) + m = load_nemo(args.nemo) + + rows = [] # (clip, lang, mode, wer, nemo_text, ours_text) + any_fail = False + + for clip in clips: + for lang in langs: + ref = compute_prompt_reference(m, clip, lang) + nemo_offline = ref["rnnt_text"] + nemo_stream = ref["stream_text"] + for mode in modes: + if mode == "stream" and nemo_stream is None: + print(f" SKIP stream {clip} {lang}: model not cache-aware") + continue + nemo_text = nemo_offline if mode == "offline" else nemo_stream + ours_text = _run_cli( + args.cli, args.gguf, clip, lang, stream=(mode == "stream") + ) + w = wer(nemo_text, ours_text) + if w > 0.0: + any_fail = True + rows.append((clip, lang, mode, w, nemo_text, ours_text)) + + # ---- Table ---- + print() + print("=" * 78) + print("E2E NeMo-vs-parakeet.cpp (WER 0.0 = byte-for-byte parity)") + print("=" * 78) + print(f"{'clip':28s} {'lang':8s} {'mode':8s} {'WER':>8s} status") + print("-" * 78) + for clip, lang, mode, w, nemo_text, ours_text in rows: + name = pathlib.Path(clip).name + status = "OK" if w == 0.0 else "MISMATCH" + print(f"{name:28s} {lang:8s} {mode:8s} {w:8.4f} {status}") + if w > 0.0: + print(f" nemo: {nemo_text!r}") + print(f" ours: {ours_text!r}") + print("-" * 78) + print(f"{len(rows)} rows; {'ALL PARITY (WER 0.0)' if not any_fail else 'FAILURES PRESENT'}") + + return 1 if any_fail else 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/gen_benchmark_md.py b/scripts/gen_benchmark_md.py index 659a906..8b4d976 100644 --- a/scripts/gen_benchmark_md.py +++ b/scripts/gen_benchmark_md.py @@ -118,6 +118,91 @@ def build_headline_table(models: list[dict]) -> str: return "\n".join(lines) +# ── Nemotron (prompt-conditioned, streaming, multilingual) ─────────────────────── + +def build_nemotron_section(results_dir: Path) -> str: + """Standalone section for the nemotron-3.5-asr-streaming-0.6b port. + + This model is prompt-conditioned (a per-language one-hot drives a PromptKernel + between the encoder and the transducer) and is benchmarked from a local .nemo + + GGUF pair, so it is outside the HF-id LibriSpeech pipeline above. Reads + ``/nemotron/bench.json`` written by ``scripts/bench_nemotron.py``; + returns "" when that file is absent. + """ + fp = results_dir / "nemotron" / "bench.json" + if not fp.is_file(): + return "" + d = json.loads(fp.read_text()) + clip = d["clip"] + audio = d["audio_sec"] + lang = d["lang"] + threads = d["threads"] + passes = d["passes"] + nemo = d["nemo"] + f32 = d["ours"]["f32"] + q8 = d["ours"].get("q8_0") + + lines = ["## Nemotron (streaming, multilingual, prompt-conditioned)\n"] + lines.append( + "`nemotron-3.5-asr-streaming-0.6b` is a FastConformer transducer with a " + "per-language prompt: a one-hot language vector drives a PromptKernel " + "between the encoder and the RNN-T decoder. It runs both offline and " + "cache-aware streaming. Because it loads from a local `.nemo` plus its " + "GGUF, it sits outside the LibriSpeech pipeline above and is measured on " + "its own here.\n" + ) + lines.append( + f"One clip (`{clip}`, {audio:.2f} s), language prompt `{lang}`, " + f"{threads} threads, median of {passes} passes after one warmup. ours is " + "`parakeet-cli bench --decoder tdt --lang " + lang + "` (load once, time " + "transcribe only); NeMo runs the same prompt forward (preprocessor, " + "encoder, PromptKernel, RNN-T greedy) on PyTorch CPU. RTFx = audio " + "seconds per second of compute; higher is faster.\n" + ) + lines.append( + f"Host: AMD Ryzen 9 9950X3D (20 cores), CPU-only. NeMo " + f"{nemo.get('version', '?')}.\n" + ) + + lines.append("| Engine | RTFx | Speedup vs NeMo | Agreement WER vs NeMo |") + lines.append("|---|---|---|---|") + lines.append(f"| NeMo (PyTorch CPU) | {nemo['rtfx']:.1f} | 1.00× | reference |") + lines.append( + f"| parakeet.cpp f32 | {f32['rtfx']:.1f} | {f32['speedup']:.2f}× " + f"| {f32['agreement_wer']*100:.4f}% |" + ) + if q8: + lines.append( + f"| parakeet.cpp q8_0 | {q8['rtfx']:.1f} | {q8['speedup']:.2f}× " + f"| {q8['agreement_wer']*100:.4f}% |" + ) + lines.append("") + lines.append( + f"Accuracy is **WER 0 vs NeMo**: the f32 and q8_0 transcripts are " + f"byte-identical to NeMo's on the timed runs (agreement WER 0.0000%), so " + f"the speed numbers compare equal work. parakeet.cpp is " + f"**{f32['speedup']:.2f}× faster than NeMo at f32** and " + f"**{q8['speedup']:.2f}× at q8_0**." if q8 else + f"Accuracy is **WER 0 vs NeMo** (agreement WER 0.0000%); parakeet.cpp is " + f"**{f32['speedup']:.2f}× faster than NeMo at f32**." + ) + lines.append("") + + stream = d.get("stream") + if stream: + lines.append( + f"Streaming path ({stream['dtype']}, cache-aware): compute RTFx " + f"**{stream['compute_rtfx']:.2f}** (median wall " + f"{stream['median_wall_s']*1000:.0f} ms over the {audio:.2f} s clip, " + f"one-time model load of {stream['load_s']*1000:.0f} ms subtracted). " + f"Streaming is latency-oriented: it runs many small chunked forward " + f"passes rather than one offline pass, so its RTFx sits well below the " + f"offline number by design while staying several times real time. The " + f"streaming transcript matches the offline and NeMo transcripts.\n" + ) + return "\n".join(lines) + + # ── Quantization summary (all dtypes, aggregated) ──────────────────────────────── def build_quant_table(models: list[dict], dtypes: list[str]) -> str: @@ -417,6 +502,7 @@ def main(): build_demo_section() + "\n", build_methodology(models, dtypes) + "\n", build_headline_table(models) + "\n", + build_nemotron_section(results_dir) + "\n", build_quant_table(models, dtypes) + "\n", build_batched_decode_section(results_dir) + "\n", build_plots_section(plots_dir, out_path) + "\n", diff --git a/scripts/gen_nemo_baseline.py b/scripts/gen_nemo_baseline.py index e8a15de..e5b717e 100644 --- a/scripts/gen_nemo_baseline.py +++ b/scripts/gen_nemo_baseline.py @@ -144,6 +144,265 @@ def _squeeze(arr): return np.ascontiguousarray(out) +def _get_md_flag(m, key): + """Read a model_defaults entry (OmegaConf dict or attr access).""" + md = getattr(m.cfg, "model_defaults", {}) or {} + try: + return md[key] + except Exception: + return getattr(md, key, None) + + +def _prompt_streaming_text(m, feats, feat_len, pidx, num_prompts, specials): + """NeMo cache-aware *streaming* transcript for a prompt model, with the + language prompt applied (the authoritative streaming target for the C++ + StreamingSession). + + Mirrors ``gen_stream_baseline.py`` for the encoder side, then applies the + SAME prompt branch as the offline forward (transpose -> cat(onehot) -> + prompt_kernel -> transpose) to the concatenated streamed encoder output, and + decodes the whole prompt-conditioned streamed output in one shot (the + cache-aware-equivalence property makes single-shot decode of the streamed + output == per-chunk decode carrying state, which is what the C++ session does). + + Returns (stream_text, stream_token_ids[list]). Returns (None, None) when the + encoder is not a cache-aware streaming encoder (so the caller can skip the KV). + """ + import torch + + enc = m.encoder + if not hasattr(enc, "cache_aware_stream_step"): + return None, None + from nemo.collections.asr.parts.utils.streaming_utils import ( + CacheAwareStreamingAudioBuffer, + ) + + enc.setup_streaming_params() + sc = enc.streaming_cfg + + sb = CacheAwareStreamingAudioBuffer( + model=m, online_normalization=False, pad_and_drop_preencoded=False + ) + sb.append_processed_signal(feats, stream_id=-1) + sb_iter = iter(sb) + cache_last_channel, cache_last_time, cache_last_channel_len = ( + enc.get_initial_cache_state(batch_size=1) + ) + + outs = [] + for step_num, (chunk_audio, chunk_lengths) in enumerate(sb_iter): + drop = sc.drop_extra_pre_encoded if step_num != 0 else 0 + keep_all = sb.is_buffer_empty() + with torch.no_grad(): + (e, el, cache_last_channel, cache_last_time, cache_last_channel_len) = ( + enc.cache_aware_stream_step( + processed_signal=chunk_audio, + processed_signal_length=chunk_lengths, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + keep_all_outputs=keep_all, + drop_extra_pre_encoded=drop, + ) + ) + valid = int(el[0].item()) + outs.append(e[:, :, :valid].detach()) # [1, d_model, valid] + + stream_enc = torch.cat(outs, dim=2) # [1, d_model, T'] + Tp = stream_enc.shape[2] + stream_len = torch.tensor([Tp], dtype=torch.int64) + + # Apply the language prompt to the streamed encoder output (same branch as + # the offline forward; one-hot constant over time). + with torch.no_grad(): + encoded = stream_enc.transpose(1, 2) # [1, T', D] + onehot = torch.zeros(1, Tp, num_prompts, dtype=encoded.dtype) + onehot[:, :, pidx] = 1.0 + concat = torch.cat([encoded, onehot], dim=-1) # [1, T', D+P] + pk_out = m.prompt_kernel(concat) # [1, T', D] + pk_enc = pk_out.transpose(1, 2).contiguous() # [1, D, T'] + hyps = m.decoding.rnnt_decoder_predictions_tensor( + encoder_output=pk_enc, encoded_lengths=stream_len, return_hypotheses=True + ) + first = hyps[0] if isinstance(hyps, list) else hyps + if isinstance(first, list): + first = first[0] + ys = first.y_sequence + ys = ys.cpu().tolist() if hasattr(ys, "cpu") else list(ys) + stream_ids = [int(t) for t in ys] + non_special = [t for t in stream_ids if t not in specials] + stream_text = m.tokenizer.ids_to_text(non_special) + return stream_text, stream_ids + + +def resolve_prompt_lang(m, lang=None): + """Resolve (target_lang, prompt_index, num_prompts) for a prompt model. + + ``lang`` None/empty -> the model default ("auto" if present, else the first + dictionary key). Exits(1) with PARAKEET_BASELINE_BAD_LANG on an unknown key. + """ + md = m.cfg.model_defaults + pdict = md.prompt_dictionary + num_prompts = int(md.get("num_prompts", 128)) + default_lang = "auto" if "auto" in pdict else list(pdict.keys())[0] + target_lang = lang or default_lang + if target_lang not in pdict: + keys = list(pdict.keys()) + print( + f"PARAKEET_BASELINE_BAD_LANG: '{target_lang}' not in prompt_dictionary; " + f"available (first 10): {keys[:10]}", + file=sys.stderr, + ) + sys.exit(1) + return target_lang, int(pdict[target_lang]), num_prompts + + +def compute_prompt_reference(m, audio_path, lang=None): + """NeMo reference for a prompt-conditioned (nemotron) model at target ``lang``. + + Runs the SAME prompt branch as NeMo's forward() (encoder -> transpose -> + cat(onehot) -> prompt_kernel -> transpose) and decodes the prompt-conditioned + encoder output via ``m.decoding.rnnt_decoder_predictions_tensor`` for BOTH the + offline encoder pass and the cache-aware streaming pass. This is the + authoritative path used by every Phase 2/3 baseline (the lhotse + transcribe(target_lang=...) dataloader needs per-cut language metadata our bare + wav fixtures lack). + + Returns a dict with: target_lang, prompt_index, num_prompts, encoder_out + (np [D,T]), prompt_kernel_out (np [T,D]), rnnt_ids (np int32), rnnt_text, + stream_text (str|None), stream_ids (list[int]|None). + """ + import torch + import soundfile as sf + + target_lang, pidx, num_prompts = resolve_prompt_lang(m, lang) + + wav, sr = sf.read(audio_path, dtype="float32", always_2d=False) + if wav.ndim > 1: + wav = wav.mean(axis=1) + if sr != 16000: + print(f"PARAKEET_BASELINE_BAD_AUDIO: expected 16k mono, got sr={sr}", + file=sys.stderr) + sys.exit(1) + wt = torch.from_numpy(np.ascontiguousarray(wav)).float().unsqueeze(0) # [1, S] + lt = torch.tensor([wt.shape[1]], dtype=torch.int64) + + with torch.no_grad(): + feats, flen = m.preprocessor(input_signal=wt, length=lt) + enc, elen = m.encoder(audio_signal=feats, length=flen) # [1, D, T] + encoded = enc.transpose(1, 2) # [1, T, D] + T = encoded.shape[1] + onehot = torch.zeros(1, T, num_prompts, dtype=encoded.dtype) + onehot[:, :, pidx] = 1.0 + concat = torch.cat([encoded, onehot], dim=-1) # [1, T, D+P] + pk_out = m.prompt_kernel(concat) # [1, T, D] + + # RNNT greedy reference for this language. We CANNOT use m.transcribe() here: + # the prompt model's transcribe dataloader (LhotseSpeechToTextBpeDatasetWith + # PromptIndex) resolves the prompt index from each cut's language metadata, + # which our bare wav fixture lacks ("Unknown prompt key: 'None'"). Instead we + # decode the prompt-conditioned encoder output DIRECTLY via the model's RNNT + # decoding object — exactly what _transcribe_output_processing does, and the + # same encoder-output the C++ engine feeds its rnnt_greedy after PromptKernel. + pk_enc = pk_out.transpose(1, 2).contiguous() # [1, D, T] for decoding + with torch.no_grad(): + hyps = m.decoding.rnnt_decoder_predictions_tensor( + encoder_output=pk_enc, encoded_lengths=elen, return_hypotheses=True + ) + first = hyps[0] if isinstance(hyps, list) else hyps + if isinstance(first, list): # NBest -> take the top hypothesis + first = first[0] + ys = first.y_sequence + ys = ys.cpu().tolist() if hasattr(ys, "cpu") else list(ys) + rnnt_ids = np.array(list(ys), dtype=np.int32) + rnnt_text = first.text if hasattr(first, "text") else str(first) + + # Resolve / special ids (if any) so the streaming transcript strips + # them exactly as StreamingSession does (specials surface as events, not text). + specials = set() + for tok in ("", ""): + try: + ids = m.tokenizer.tokens_to_ids([tok]) + if ids and int(ids[0]) >= 0: + specials.add(int(ids[0])) + except Exception: + pass + + # Cache-aware STREAMING transcript WITH the language prompt — the authoritative + # target for the C++ StreamingSession. None if the model isn't a streaming + # (cache-aware) encoder. + stream_text, stream_ids = _prompt_streaming_text( + m, feats, flen, pidx, num_prompts, specials + ) + + return { + "target_lang": target_lang, + "prompt_index": pidx, + "num_prompts": num_prompts, + "encoder_out": _squeeze(enc.cpu().float().numpy()), # [D, T] + "prompt_kernel_out": _squeeze(pk_out.cpu().float().numpy()), # [T, D] + "rnnt_ids": rnnt_ids, + "rnnt_text": rnnt_text, + "stream_text": stream_text, + "stream_ids": stream_ids, + } + + +def dump_prompt_baseline(m, args): + """Dump the prompt-model baseline for a fixed target_lang: + + * ``encoder_out`` ``[D, T]`` RAW encoder output (BEFORE prompt). + * ``prompt_kernel_out`` ``[T, D]`` prompt_kernel(cat([encoded, onehot])), + i.e. NeMo's forward() prompt branch. + * ``rnnt_token_ids`` ``[L]`` int32 NeMo RNNT greedy ids for this lang. + * KVs: baseline.target_lang, baseline.prompt_index, + baseline.rnnt_token_count, baseline.rnnt_text, + baseline.stream_text (cache-aware streaming transcript WITH prompt, + EOU/EOB stripped — the authoritative target for StreamingSession). + + Mirrors NeMo EncDecRNNTBPEModelWithPrompt.forward(): + encoded(B,D,T) -> transpose -> cat(onehot) -> prompt_kernel -> transpose. + The one-hot is constant over time (one language per utterance). + """ + ref = compute_prompt_reference(m, args.audio, args.lang) + target_lang = ref["target_lang"] + pidx = ref["prompt_index"] + num_prompts = ref["num_prompts"] + rnnt_ids = ref["rnnt_ids"] + rnnt_text = ref["rnnt_text"] + stream_text = ref["stream_text"] + stream_ids = ref["stream_ids"] + + w = gguf.GGUFWriter(args.output, "parakeet-baseline-prompt") + w.add_string("baseline.target_lang", target_lang) + w.add_uint32("baseline.prompt_index", pidx) + w.add_uint32("baseline.num_prompts", num_prompts) + w.add_tensor("encoder_out", ref["encoder_out"]) # [D, T] + w.add_tensor("prompt_kernel_out", ref["prompt_kernel_out"]) # [T, D] + w.add_uint32("baseline.rnnt_token_count", int(rnnt_ids.shape[0])) + if rnnt_ids.shape[0] > 0: + w.add_tensor("rnnt_token_ids", np.ascontiguousarray(rnnt_ids)) + w.add_string("baseline.rnnt_text", rnnt_text) + if stream_text is not None: + w.add_string("baseline.stream_text", stream_text) + sids = np.array(stream_ids, dtype=np.int32) + w.add_uint32("baseline.stream_token_count", int(sids.shape[0])) + if sids.shape[0] > 0: + w.add_tensor("stream_token_ids", np.ascontiguousarray(sids)) + w.write_header_to_file() + w.write_kv_data_to_file() + w.write_tensors_to_file() + w.close() + print( + f"wrote {args.output}: prompt baseline lang={target_lang} idx={pidx} " + f"tokens={rnnt_ids.shape[0]} text={rnnt_text!r}" + ) + if stream_text is not None: + print(f" baseline.stream_text={stream_text!r} (stream_tokens={len(stream_ids)})") + else: + print(" baseline.stream_text: SKIPPED (encoder is not cache-aware streaming)") + + def _timestamps_decoding_cfg(m): """Build a decoding cfg (cloned from the model's own) that turns on per-frame/token/word confidence using the reproducible ``max_prob`` method. @@ -446,6 +705,11 @@ def main(): "transcripts reflect banded local attention. Anchors the C++ " "banded-attention parity tests at NeMo quality.", ) + ap.add_argument( + "--lang", + default=None, + help="target_lang for prompt models (default: model default / auto)", + ) args = ap.parse_args() is_local = pathlib.Path(args.model).exists() @@ -485,6 +749,17 @@ def main(): _run_timestamps(m, args) return + # Prompt-conditioned multilingual model (nemotron, EncDecRNNTBPEModelWithPrompt): + # dump the raw encoder output, the prompt_kernel projection, and the per-language + # RNNT greedy reference. Kept as a separate early path; the encoder-stage hook + # baseline below is for the hybrid / pure-RNNT models and is untouched. + has_prompt = bool(_get_md_flag(m, "initialize_prompt_feature")) and ( + getattr(m, "prompt_kernel", None) is not None + ) + if has_prompt: + dump_prompt_baseline(m, args) + return + # Per-layer / module captures via forward hooks. The preprocessor and # encoder return (tensor, length) tuples; conformer layers return a bare # tensor (no cache) — handle both. diff --git a/scripts/publish_hf.py b/scripts/publish_hf.py index 0de039a..6f3f827 100644 --- a/scripts/publish_hf.py +++ b/scripts/publish_hf.py @@ -62,6 +62,7 @@ ALL_MODELS = [ "nvidia/parakeet-tdt_ctc-110m", "nvidia/parakeet_realtime_eou_120m-v1", + "nvidia/nemotron-3.5-asr-streaming-0.6b", "nvidia/parakeet-ctc-0.6b", "nvidia/parakeet-rnnt-0.6b", "nvidia/parakeet-tdt-0.6b-v2", @@ -120,8 +121,34 @@ "q8_0": {"wer": 0.0, "size_mb": None}, "q4_k": {"wer": None, "size_mb": None}, }, + # Multilingual prompt-conditioned streaming model. WER measured offline + # (en/de/auto) by scripts/e2e_nemo_compare.py vs NeMo; all five variants are + # byte-for-byte (WER 0.0). prompt_kernel + LSTM + featurizer tensors stay F32. + "nvidia/nemotron-3.5-asr-streaming-0.6b": { + "f16": {"wer": 0.0, "size_mb": 1484.3}, + "q8_0": {"wer": 0.0, "size_mb": 983.7}, + "q6_k": {"wer": 0.0, "size_mb": 855.7}, + "q5_k": {"wer": 0.0, "size_mb": 784.8}, + "q4_k": {"wer": 0.0, "size_mb": 718.1}, + }, +} + +# Per-model SPDX license id + human label, used in the generated card frontmatter +# and the License section. Defaults to CC-BY-4.0 (the NVIDIA NeMo Parakeet +# checkpoints); nemotron-3.5-asr-streaming is released under OpenMDW-1.1. +DEFAULT_LICENSE = ("cc-by-4.0", "CC-BY-4.0", + "https://creativecommons.org/licenses/by/4.0/") +LICENSES: dict = { + "nvidia/nemotron-3.5-asr-streaming-0.6b": ( + "other", "OpenMDW-1.1", "https://huggingface.co/nvidia/nemotron-3.5-asr-streaming-0.6b", + ), } + +def _license(model_id: str) -> tuple[str, str, str]: + """(frontmatter license id, human label, upstream/license url) for *model_id*.""" + return LICENSES.get(model_id, DEFAULT_LICENSE) + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -283,6 +310,12 @@ def _size_str(size_mb, path: Optional[Path] = None) -> str: def _arch_info(model_id: str) -> tuple[str, str]: """Infer (arch_desc, decoder heads) from the model name.""" name_lower = model_id.lower() + if "nemotron" in name_lower: + return ( + "Cache-aware streaming, multilingual, prompt-conditioned RNNT " + "(FastConformer, 40+ locales, --lang )", + "RNNT (streaming, prompt-conditioned)", + ) if "realtime_eou" in name_lower or "realtime-eou" in name_lower: return "Cache-aware streaming RNNT (FastConformer, EOU/EOB)", "RNNT (streaming)" if "tdt_ctc" in name_lower: @@ -307,13 +340,16 @@ def build_model_card( wer_data = KNOWN_WER.get(model_id, {}) arch_desc, heads = _arch_info(model_id) + lic_id, lic_label, lic_url = _license(model_id) lines: List[str] = [] - # YAML frontmatter + # YAML frontmatter. For a non-SPDX license (license: other) HF wants an + # explicit license_name + license_link. + lines += ["---", f"license: {lic_id}"] + if lic_id == "other": + lines += [f"license_name: {lic_label}", f"license_link: {lic_url}"] lines += [ - "---", - "license: cc-by-4.0", "library_name: parakeet.cpp", "tags:", " - automatic-speech-recognition", @@ -431,8 +467,8 @@ def build_model_card( lines.append("## License") lines.append("") lines.append( - "The GGUF weights are derived from the NVIDIA NeMo Parakeet checkpoints, " - "which are released under the [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) license. " + f"The GGUF weights are derived from [{model_id}](https://huggingface.co/{model_id}), " + f"released under the [{lic_label}]({lic_url}) license. " "The parakeet.cpp runtime is MIT-licensed." ) lines.append("") @@ -551,8 +587,24 @@ def build_collection_card( """One combined model card for the whole collection repo.""" lines: List[str] = [] - # YAML frontmatter — base_model accepts a list. - lines += ["---", "license: cc-by-4.0", "library_name: parakeet.cpp", "tags:"] + # YAML frontmatter — base_model accepts a list. Licenses can differ per model + # (most are CC-BY-4.0; nemotron is OpenMDW-1.1), so when the collection mixes + # licenses we declare `other` here and spell each one out in the License + # section / per-model rows below. + licenses = {_license(m) for m in models} + if len(licenses) == 1: + only_id, only_label, only_url = next(iter(licenses)) + lines = ["---", f"license: {only_id}"] + if only_id == "other": + lines += [f"license_name: {only_label}", f"license_link: {only_url}"] + lines += ["library_name: parakeet.cpp", "tags:"] + else: + lines = [ + "---", "license: other", + "license_name: mixed (see per-model licenses below)", + f"license_link: https://huggingface.co/{repo_id}", + "library_name: parakeet.cpp", "tags:", + ] lines += [f" - {t}" for t in ( "automatic-speech-recognition", "asr", "parakeet", "gguf", "ggml", "cpp-inference", "nemo", @@ -585,11 +637,12 @@ def build_collection_card( arch_desc, heads = _arch_info(model_id) wer_data = KNOWN_WER.get(model_id, {}) vp = paths_by_model.get(model_id, {}) + _lic_id, _lic_label, _lic_url = _license(model_id) lines.append(f"### {slug}") lines.append("") lines.append( f"Source: [{model_id}](https://huggingface.co/{model_id}) · " - f"{arch_desc} · heads: {heads}" + f"{arch_desc} · heads: {heads} · license: [{_lic_label}]({_lic_url})" ) lines.append("") lines.append("| File | Variant | Size | WER vs NeMo |") @@ -645,15 +698,24 @@ def build_collection_card( lines.append("```") lines.append("") - # License. + # License. Each GGUF inherits the license of its source checkpoint; list them + # explicitly because the collection can mix licenses (most CC-BY-4.0, nemotron + # OpenMDW-1.1). lines.append("## License") lines.append("") lines.append( - "The GGUF weights are derived from the NVIDIA NeMo Parakeet checkpoints, released " - "under the [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) license. " - "The parakeet.cpp runtime is MIT-licensed." + "Each GGUF is derived from its upstream NVIDIA NeMo checkpoint and inherits " + "that checkpoint's license. The parakeet.cpp runtime itself is MIT-licensed." ) lines.append("") + lines.append("| Source checkpoint | License |") + lines.append("|---|---|") + for model_id in models: + _lic_id, _lic_label, _lic_url = _license(model_id) + lines.append( + f"| [{model_id}](https://huggingface.co/{model_id}) | [{_lic_label}]({_lic_url}) |" + ) + lines.append("") return "\n".join(lines) diff --git a/src/model.cpp b/src/model.cpp index fcaa594..c212418 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -9,6 +9,7 @@ #include "tokenizer.hpp" #include "prediction.hpp" #include "joint.hpp" +#include "prompt_kernel.hpp" #include "tdt.hpp" #include "rnnt.hpp" #include "transducer_batch.hpp" @@ -49,6 +50,23 @@ std::unique_ptr Model::load(const std::string& gguf_path) { return m; } +int Model::resolve_prompt_index(const std::string& target_lang) const { + const ParakeetConfig& cfg = loader_.config(); + if (!cfg.prompt.present) return -1; + return cfg.prompt.resolve_index_or_throw(target_lang); +} + +// Apply the prompt-conditioning projection in place on a channels-first encoder +// output [d_model, Tout], if the model is prompt-conditioned. No-op otherwise. +static void maybe_apply_prompt(const ModelLoader& loader, std::vector& enc_out, + int d_model, int Tout, int prompt_index) { + if (!loader.config().prompt.present) return; + PromptKernel pk(loader); + std::vector projected; + pk.apply(enc_out, d_model, Tout, prompt_index, projected); + enc_out.swap(projected); +} + // Decode one item's encoder output (row-major [d_model, Tout], channels-first) // into a transcript. Mirrors the tail of transcribe_16k exactly. static std::string decode_enc_out(const ModelLoader& loader, @@ -82,8 +100,10 @@ static std::string decode_enc_out(const ModelLoader& loader, } std::string Model::transcribe_16k(const std::vector& pcm16k, - Decoder decoder) const { + Decoder decoder, + const std::string& target_lang) const { const ParakeetConfig& cfg = loader_.config(); + const int prompt_index = resolve_prompt_index(target_lang); // 1. Log-mel front end -> feats [n_mels, T]. On a non-CPU backend run the // heavy STFT/power/filterbank/log on the backend (GPU) via GpuMel; on CPU @@ -104,6 +124,11 @@ std::string Model::transcribe_16k(const std::vector& pcm16k, int d_model = 0, Tout = 0; encoder.forward(feats, n_mels, T, enc_out, d_model, Tout); + // 2b. Prompt conditioning (multilingual nemotron): project the encoder + // output with the selected language one-hot before decoding. No-op for + // other models (prompt.present == false). + maybe_apply_prompt(loader_, enc_out, d_model, Tout, prompt_index); + // Decide which head to use. const bool use_tdt = (decoder == Decoder::kTDT) || (decoder == Decoder::kDefault && arch_prefers_tdt(cfg.arch)); @@ -162,8 +187,10 @@ static void batch_enc_to_row_major(const std::vector>& enc_ou } std::vector Model::transcribe_16k_batch( - const std::vector>& pcms16k, Decoder decoder) const { + const std::vector>& pcms16k, Decoder decoder, + const std::string& target_lang) const { const ParakeetConfig& cfg = loader_.config(); + const int prompt_index = resolve_prompt_index(target_lang); const bool use_tdt = (decoder == Decoder::kTDT) || (decoder == Decoder::kDefault && arch_prefers_tdt(cfg.arch)); @@ -176,6 +203,11 @@ std::vector Model::transcribe_16k_batch( std::vector valid_Tout; encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); + // 2b. Prompt conditioning per item (one language for the whole batch). No-op + // for non-prompt models. + for (int b = 0; b < mb.B; ++b) + maybe_apply_prompt(loader_, enc_outs[b], d_model, valid_Tout[b], prompt_index); + // 3. Decode (each enc_out is [d_model, valid_Tout[b]]). std::vector outs(mb.B); if (use_tdt) { @@ -202,7 +234,7 @@ std::vector Model::transcribe_16k_batch( std::vector Model::transcribe_pcm_batch( const std::vector>& pcms, int sample_rate, - Decoder decoder) const { + Decoder decoder, const std::string& target_lang) const { if (sample_rate <= 0) { throw std::runtime_error("parakeet: invalid sample_rate"); } @@ -210,7 +242,7 @@ std::vector Model::transcribe_pcm_batch( for (size_t i = 0; i < pcms.size(); ++i) r[i] = (sample_rate == 16000) ? pcms[i] : resample_linear(pcms[i], sample_rate, 16000); - return transcribe_16k_batch(r, decoder); + return transcribe_16k_batch(r, decoder, target_lang); } // Decode one item's encoder output (channels-first [d_model, Tout]) into a @@ -260,8 +292,10 @@ static Transcription decode_enc_out_with_timestamps( } Transcription Model::transcribe_16k_with_timestamps( - const std::vector& pcm16k, Decoder decoder) const { + const std::vector& pcm16k, Decoder decoder, + const std::string& target_lang) const { const ParakeetConfig& cfg = loader_.config(); + const int prompt_index = resolve_prompt_index(target_lang); // frame_sec = hop_length * subsampling_factor / sample_rate (= 0.08 s here). // This is NeMo's window_stride * subsampling_factor (window_stride = @@ -288,6 +322,9 @@ Transcription Model::transcribe_16k_with_timestamps( int d_model = 0, Tout = 0; encoder.forward(feats, n_mels, T, enc_out, d_model, Tout); + // 2b. Prompt conditioning (nemotron): project before decode. No-op otherwise. + maybe_apply_prompt(loader_, enc_out, d_model, Tout, prompt_index); + const bool use_tdt = (decoder == Decoder::kTDT) || (decoder == Decoder::kDefault && arch_prefers_tdt(cfg.arch)); @@ -297,8 +334,10 @@ Transcription Model::transcribe_16k_with_timestamps( } std::vector Model::transcribe_16k_batch_with_timestamps( - const std::vector>& pcms16k, Decoder decoder) const { + const std::vector>& pcms16k, Decoder decoder, + const std::string& target_lang) const { const ParakeetConfig& cfg = loader_.config(); + const int prompt_index = resolve_prompt_index(target_lang); const float frame_sec = (float)cfg.hop_length * (float)cfg.subsampling_factor / (float)cfg.sample_rate; const bool use_tdt = (decoder == Decoder::kTDT) @@ -311,6 +350,11 @@ std::vector Model::transcribe_16k_batch_with_timestamps( std::vector valid_Tout; encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); + // Prompt conditioning per item (one language for the whole batch). No-op + // for non-prompt models. + for (int b = 0; b < mb.B; ++b) + maybe_apply_prompt(loader_, enc_outs[b], d_model, valid_Tout[b], prompt_index); + std::vector outs(mb.B); if (use_tdt) { // Batched transducer (TDT/RNNT) greedy decode with timestamps. Build @@ -344,7 +388,7 @@ std::vector Model::transcribe_16k_batch_with_timestamps( std::vector Model::transcribe_pcm_batch_with_timestamps( const std::vector>& pcms, int sample_rate, - Decoder decoder) const { + Decoder decoder, const std::string& target_lang) const { if (sample_rate <= 0) { throw std::runtime_error("parakeet: invalid sample_rate"); } @@ -352,50 +396,52 @@ std::vector Model::transcribe_pcm_batch_with_timestamps( for (size_t i = 0; i < pcms.size(); ++i) r[i] = (sample_rate == 16000) ? pcms[i] : resample_linear(pcms[i], sample_rate, 16000); - return transcribe_16k_batch_with_timestamps(r, decoder); + return transcribe_16k_batch_with_timestamps(r, decoder, target_lang); } std::string Model::transcribe_pcm(const std::vector& pcm, int sample_rate, - Decoder decoder) const { + Decoder decoder, const std::string& target_lang) const { if (sample_rate <= 0) { throw std::runtime_error("parakeet: invalid sample_rate"); } if (sample_rate == 16000) { - return transcribe_16k(pcm, decoder); + return transcribe_16k(pcm, decoder, target_lang); } std::vector pcm16k = resample_linear(pcm, sample_rate, 16000); - return transcribe_16k(pcm16k, decoder); + return transcribe_16k(pcm16k, decoder, target_lang); } std::string Model::transcribe_path(const std::string& wav_path, - Decoder decoder) const { + Decoder decoder, const std::string& target_lang) const { Audio audio; if (!load_audio_16k_mono(wav_path, audio)) { throw std::runtime_error("parakeet: failed to load audio: " + wav_path); } // load_audio_16k_mono already resamples to 16 kHz mono. - return transcribe_16k(audio.samples, decoder); + return transcribe_16k(audio.samples, decoder, target_lang); } Transcription Model::transcribe_with_timestamps( - const std::vector& pcm, int sample_rate, Decoder decoder) const { + const std::vector& pcm, int sample_rate, Decoder decoder, + const std::string& target_lang) const { if (sample_rate <= 0) { throw std::runtime_error("parakeet: invalid sample_rate"); } if (sample_rate == 16000) { - return transcribe_16k_with_timestamps(pcm, decoder); + return transcribe_16k_with_timestamps(pcm, decoder, target_lang); } std::vector pcm16k = resample_linear(pcm, sample_rate, 16000); - return transcribe_16k_with_timestamps(pcm16k, decoder); + return transcribe_16k_with_timestamps(pcm16k, decoder, target_lang); } Transcription Model::transcribe_path_with_timestamps( - const std::string& wav_path, Decoder decoder) const { + const std::string& wav_path, Decoder decoder, + const std::string& target_lang) const { Audio audio; if (!load_audio_16k_mono(wav_path, audio)) { throw std::runtime_error("parakeet: failed to load audio: " + wav_path); } - return transcribe_16k_with_timestamps(audio.samples, decoder); + return transcribe_16k_with_timestamps(audio.samples, decoder, target_lang); } } // namespace pk diff --git a/src/model.hpp b/src/model.hpp index 8b7d95c..30adeaa 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -26,21 +26,40 @@ class Model { // Transcribe raw mono float PCM. If `sample_rate != 16000` the audio is // linearly resampled to 16 kHz (via pk::resample_linear) before inference. - // Throws std::runtime_error on failure (e.g. unsupported arch). + // `target_lang` selects the language prompt for multilingual (nemotron) + // models (e.g. "en", "de", "auto"); empty -> the model default. It is + // ignored by non-prompt models. Throws std::runtime_error on failure (e.g. + // unsupported arch, or an unknown target_lang for a prompt model). std::string transcribe_pcm(const std::vector& pcm, int sample_rate, - Decoder decoder = Decoder::kDefault) const; + Decoder decoder = Decoder::kDefault, + const std::string& target_lang = "") const; // Transcribe a WAV file (loaded + resampled to 16 kHz mono via - // pk::load_audio_16k_mono). Throws std::runtime_error on failure. + // pk::load_audio_16k_mono). `target_lang` as in transcribe_pcm. Throws + // std::runtime_error on failure. std::string transcribe_path(const std::string& wav_path, - Decoder decoder = Decoder::kDefault) const; + Decoder decoder = Decoder::kDefault, + const std::string& target_lang = "") const; + + // Core orchestration: 16 kHz mono PCM -> transcript. Public so language-aware + // callers/tests can drive it directly with a resolved target_lang. + std::string transcribe_16k(const std::vector& pcm16k, + Decoder decoder = Decoder::kDefault, + const std::string& target_lang = "") const; + + // Resolve a target_lang (locale string) to a prompt index using the model's + // dictionary. Empty string -> the model's default_lang. Returns -1 and is + // ignored when the model is not prompt-conditioned. Throws std::runtime_error + // on an unknown locale for a prompt model (message lists a few valid keys). + int resolve_prompt_index(const std::string& target_lang) const; // Transcribe a batch of mono float PCM clips. Each is resampled to 16 kHz if // needed, then all run through the batched encoder; decode is per item. // Returns one transcript per input, in order. std::vector transcribe_pcm_batch( const std::vector>& pcms, int sample_rate, - Decoder decoder = Decoder::kDefault) const; + Decoder decoder = Decoder::kDefault, + const std::string& target_lang = "") const; // Transcribe raw mono float PCM, returning the flat text plus per-word and // per-token timestamps + confidence (matching NeMo timestamps=True + @@ -48,19 +67,22 @@ class Model { // resampled to 16 kHz first. Throws std::runtime_error on failure. Transcription transcribe_with_timestamps( const std::vector& pcm, int sample_rate, - Decoder decoder = Decoder::kDefault) const; + Decoder decoder = Decoder::kDefault, + const std::string& target_lang = "") const; // Convenience: transcribe a WAV file with timestamps + confidence. Transcription transcribe_path_with_timestamps( const std::string& wav_path, - Decoder decoder = Decoder::kDefault) const; + Decoder decoder = Decoder::kDefault, + const std::string& target_lang = "") const; // Batched timestamped transcription. Each clip is resampled to 16 kHz if // needed, all run through the batched encoder; decode + timestamp extraction // are per item. Returns one Transcription per input, in order. std::vector transcribe_pcm_batch_with_timestamps( const std::vector>& pcms, int sample_rate, - Decoder decoder = Decoder::kDefault) const; + Decoder decoder = Decoder::kDefault, + const std::string& target_lang = "") const; const ParakeetConfig& config() const { return loader_.config(); } @@ -75,26 +97,23 @@ class Model { private: Model() = default; - // Core orchestration: 16 kHz mono PCM -> transcript. Shared by the two - // public entry points (transcribe_pcm resamples first; transcribe_path - // loads the WAV first). - std::string transcribe_16k(const std::vector& pcm16k, - Decoder decoder) const; - // Core batched orchestration: N 16 kHz clips -> N transcripts. Stacks mels, // runs forward_batch, decodes each item with the existing greedy decoders. std::vector transcribe_16k_batch( - const std::vector>& pcms16k, Decoder decoder) const; + const std::vector>& pcms16k, Decoder decoder, + const std::string& target_lang = "") const; // Core batched timestamped orchestration: N 16 kHz clips -> N Transcriptions. std::vector transcribe_16k_batch_with_timestamps( - const std::vector>& pcms16k, Decoder decoder) const; + const std::vector>& pcms16k, Decoder decoder, + const std::string& target_lang = "") const; // Core orchestration for the timestamps path: 16 kHz mono PCM -> full // Transcription (text + per-token TokenInfo + grouped words). Shared by the // two timestamp entry points. Transcription transcribe_16k_with_timestamps( - const std::vector& pcm16k, Decoder decoder) const; + const std::vector& pcm16k, Decoder decoder, + const std::string& target_lang = "") const; ModelLoader loader_; }; diff --git a/src/model_loader.cpp b/src/model_loader.cpp index 218dc91..a1a0d66 100644 --- a/src/model_loader.cpp +++ b/src/model_loader.cpp @@ -8,7 +8,22 @@ #include #include #include +#include namespace pk { + +int PromptCfg::resolve_index_or_throw(const std::string& target_lang) const { + const std::string lang = target_lang.empty() ? default_lang : target_lang; + int idx = lang_to_index(lang); + if (idx < 0) { + std::string sample; + for (size_t i = 0; i < dict_keys.size() && i < 8; ++i) + sample += (i ? ", " : "") + dict_keys[i]; + throw std::runtime_error("parakeet: unknown target_lang '" + lang + + "'. Valid examples: " + sample + ", ..."); + } + return idx; +} + static uint32_t kv_u32(gguf_context* g, const char* k, uint32_t d=0){ int64_t id = gguf_find_key(g,k); return id<0 ? d : (uint32_t)gguf_get_val_u32(g,id); } @@ -25,6 +40,16 @@ static std::vector kv_i32_arr(gguf_context* g, const char* k){ } return out; } +static std::vector kv_str_arr(gguf_context* g, const char* k){ + std::vector out; + int64_t id = gguf_find_key(g,k); + if(id>=0 && gguf_get_arr_type(g,id)==GGUF_TYPE_STRING){ + size_t n = gguf_get_arr_n(g,id); + out.resize(n); + for(size_t i=0;i present=false and the engine skips the prompt stage entirely. + cfg_.prompt.present = kv_bool(gguf_, "parakeet.prompt.present", false); + if(cfg_.prompt.present){ + cfg_.prompt.num_prompts = kv_u32(gguf_, "parakeet.prompt.num_prompts", 0); + cfg_.prompt.default_lang = kv_str(gguf_, "parakeet.prompt.default_lang", ""); + cfg_.prompt.dict_keys = kv_str_arr(gguf_, "parakeet.prompt.dictionary.keys"); + cfg_.prompt.dict_vals = kv_i32_arr(gguf_, "parakeet.prompt.dictionary.values"); + } if(cfg_.att_context_style != "regular"){ StreamingCfg& s = cfg_.streaming; s.chunk_size = kv_i32_arr(gguf_, "parakeet.streaming.chunk_size"); diff --git a/src/model_loader.hpp b/src/model_loader.hpp index 9947bd1..f181f24 100644 --- a/src/model_loader.hpp +++ b/src/model_loader.hpp @@ -24,6 +24,26 @@ struct StreamingCfg { int32_t drop_extra_pre_encoded=0; bool present=false; // true only for streaming models }; +// Prompt-conditioning config (multilingual nemotron). present=false for all +// existing models, which then skip the prompt stage entirely. +struct PromptCfg { + bool present = false; + uint32_t num_prompts = 0; + std::string default_lang; // e.g. "auto" + std::vector dict_keys; // locale strings + std::vector dict_vals; // parallel prompt indices + // Resolve a locale to its prompt index; -1 if unknown. + int lang_to_index(const std::string& lang) const { + for (size_t i = 0; i < dict_keys.size(); ++i) + if (dict_keys[i] == lang) return (int)dict_vals[i]; + return -1; + } + // Resolve target_lang to its prompt index, applying the model default for an + // empty string. THROWS std::runtime_error on an unknown locale. Shared by the + // offline (Model::resolve_prompt_index) and streaming (StreamingSession ctor) + // paths so both reject typos identically (matches the C-API contract). + int resolve_index_or_throw(const std::string& target_lang) const; +}; struct ParakeetConfig { std::string arch; // encoder @@ -36,7 +56,9 @@ struct ParakeetConfig { std::string att_context_style="regular"; // or "chunked_limited" bool causal_downsampling=false; // causal subsampling pad bool conv_causal=false; // causal depthwise conv pad + bool use_bias=true; // false for nemotron (encoder linears have no bias) StreamingCfg streaming; + PromptCfg prompt; // prompt conditioning (present=false for non-prompt) // preprocessor uint32_t sample_rate=16000, n_mels=0, n_fft=0, win_length=0, hop_length=0; float preemph=0.0f, mag_power=2.0f, log_zero_guard=0.0f; diff --git a/src/parakeet_capi.cpp b/src/parakeet_capi.cpp index a11ffca..40e8577 100644 --- a/src/parakeet_capi.cpp +++ b/src/parakeet_capi.cpp @@ -16,7 +16,9 @@ #include // ABI version. Bump on breaking changes. -#define PARAKEET_CAPI_ABI_VERSION 2 +// v3: target_lang variants (transcribe_path_lang / transcribe_pcm_lang / +// stream_begin_lang) for multilingual prompt-conditioned (nemotron) models. +#define PARAKEET_CAPI_ABI_VERSION 3 // The opaque context: a loaded model plus a buffer for the last error message. struct parakeet_ctx { @@ -204,13 +206,16 @@ extern "C" void parakeet_capi_free(parakeet_ctx* ctx) { delete ctx; // safe on nullptr; ~unique_ptr releases the model. } -extern "C" char* parakeet_capi_transcribe_path(parakeet_ctx* ctx, - const char* wav_path, int decoder) { +extern "C" char* parakeet_capi_transcribe_path_lang(parakeet_ctx* ctx, + const char* wav_path, int decoder, + const char* target_lang) { if (!ctx) return nullptr; if (!ctx->model) { ctx->last_error = "context has no loaded model"; return nullptr; } if (!wav_path) { ctx->last_error = "wav_path is NULL"; return nullptr; } + // NULL / "" -> model default language (ignored by non-prompt models). + const std::string lang = target_lang ? target_lang : ""; try { - std::string text = ctx->model->transcribe_path(wav_path, to_decoder(decoder)); + std::string text = ctx->model->transcribe_path(wav_path, to_decoder(decoder), lang); ctx->last_error.clear(); char* out = dup_to_c(text); if (!out) { ctx->last_error = "out of memory"; return nullptr; } @@ -224,15 +229,24 @@ extern "C" char* parakeet_capi_transcribe_path(parakeet_ctx* ctx, } } -extern "C" char* parakeet_capi_transcribe_pcm(parakeet_ctx* ctx, const float* samples, - int n_samples, int sample_rate, - int decoder) { +extern "C" char* parakeet_capi_transcribe_path(parakeet_ctx* ctx, + const char* wav_path, int decoder) { + // Delegate with the model default language. + return parakeet_capi_transcribe_path_lang(ctx, wav_path, decoder, nullptr); +} + +extern "C" char* parakeet_capi_transcribe_pcm_lang(parakeet_ctx* ctx, + const float* samples, int n_samples, + int sample_rate, int decoder, + const char* target_lang) { if (!ctx) return nullptr; if (!ctx->model) { ctx->last_error = "context has no loaded model"; return nullptr; } if (!samples || n_samples < 0) { ctx->last_error = "invalid samples buffer"; return nullptr; } + // NULL / "" -> model default language (ignored by non-prompt models). + const std::string lang = target_lang ? target_lang : ""; try { std::vector pcm(samples, samples + n_samples); - std::string text = ctx->model->transcribe_pcm(pcm, sample_rate, to_decoder(decoder)); + std::string text = ctx->model->transcribe_pcm(pcm, sample_rate, to_decoder(decoder), lang); ctx->last_error.clear(); char* out = dup_to_c(text); if (!out) { ctx->last_error = "out of memory"; return nullptr; } @@ -246,6 +260,14 @@ extern "C" char* parakeet_capi_transcribe_pcm(parakeet_ctx* ctx, const float* sa } } +extern "C" char* parakeet_capi_transcribe_pcm(parakeet_ctx* ctx, const float* samples, + int n_samples, int sample_rate, + int decoder) { + // Delegate with the model default language. + return parakeet_capi_transcribe_pcm_lang(ctx, samples, n_samples, sample_rate, + decoder, nullptr); +} + extern "C" int parakeet_capi_transcribe_pcm_batch(parakeet_ctx* ctx, const float* const* samples, const int* n_samples, int n_clips, @@ -428,18 +450,21 @@ std::string feed_available(parakeet_stream* s, bool flush, int& eou_flag) { } // namespace -extern "C" parakeet_stream* parakeet_capi_stream_begin(parakeet_ctx* ctx) { +extern "C" parakeet_stream* parakeet_capi_stream_begin_lang(parakeet_ctx* ctx, + const char* target_lang) { if (!ctx) return nullptr; if (!ctx->model) { ctx->last_error = "context has no loaded model"; return nullptr; } if (!ctx->model->config().streaming.present) { ctx->last_error = "model is not a cache-aware streaming model"; return nullptr; } + // NULL / "" -> model default language (ignored by non-prompt models). + const std::string lang = target_lang ? target_lang : ""; try { auto* s = new (std::nothrow) parakeet_stream(); if (!s) { ctx->last_error = "out of memory"; return nullptr; } s->ctx = ctx; - s->sess = std::make_unique(ctx->model->loader()); + s->sess = std::make_unique(ctx->model->loader(), lang); s->mel = std::make_unique(ctx->model->loader()); s->n_mels = s->mel->n_mels(); ctx->last_error.clear(); @@ -453,6 +478,11 @@ extern "C" parakeet_stream* parakeet_capi_stream_begin(parakeet_ctx* ctx) { } } +extern "C" parakeet_stream* parakeet_capi_stream_begin(parakeet_ctx* ctx) { + // Delegate with the model default language. + return parakeet_capi_stream_begin_lang(ctx, nullptr); +} + extern "C" char* parakeet_capi_stream_feed(parakeet_stream* s, const float* pcm, int n_samples, int* eou_out) { if (eou_out) *eou_out = 0; diff --git a/src/prompt_kernel.cpp b/src/prompt_kernel.cpp new file mode 100644 index 0000000..6d5e476 --- /dev/null +++ b/src/prompt_kernel.cpp @@ -0,0 +1,68 @@ +#include "prompt_kernel.hpp" +#include "backend.hpp" +#include "ggml_graph.hpp" +#include "ggml.h" +#include +#include + +namespace pk { + +PromptKernel::PromptKernel(const ModelLoader& ml) : ml_(ml) { + const ParakeetConfig& c = ml_.config(); + present_ = c.prompt.present; + num_prompts_ = (int)c.prompt.num_prompts; + d_model_ = (int)c.d_model; +} + +void PromptKernel::apply(const std::vector& enc_out, int d_model, int T, + int prompt_index, std::vector& out) const { + if (!present_) { out = enc_out; return; } + assert((int)enc_out.size() == d_model * T && "enc_out size mismatch"); + assert(prompt_index >= 0 && prompt_index < num_prompts_ && "prompt_index out of range"); + + // Build the input in row-major [T, IN] (ggml ne0 = IN, the fastest axis): + // rows [0, d_model) hold the encoder features (transposed from the + // channels-first [d_model, T] input); rows [d_model, d_model+P) hold the + // constant one-hot language vector. So the NeMo cat([encoded, onehot]) is a + // fill, and the two Linear layers are plain matmuls. + const int P = num_prompts_; + const int IN = d_model + P; + std::vector xbuf((size_t)IN * T, 0.0f); + for (int t = 0; t < T; ++t) { + for (int c = 0; c < d_model; ++c) + xbuf[(size_t)t * IN + c] = enc_out[(size_t)c * T + t]; + xbuf[(size_t)t * IN + d_model + prompt_index] = 1.0f; + } + + // ggml graph: y = W2 · ReLU(W0 · x + b0) + b2, all on the persistent backend. + // prompt_kernel.0.weight ggml ne=[IN, 2D] prompt_kernel.0.bias ne=[2D] + // prompt_kernel.2.weight ggml ne=[2D, D] prompt_kernel.2.bias ne=[D] + std::vector y_td; // run_graph fills row-major [T, D] (ggml ne=[D, T]) + bool ok = pk::run_graph(0, 0, + [&](ggml_context* ctx) -> ggml_tensor* { + int64_t x_ne[2] = { IN, T }; + ggml_tensor* x = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, x_ne, + xbuf.data(), (size_t)IN * T * sizeof(float)); + ggml_tensor* W0 = pk::clone_weight(ctx, ml_, "prompt_kernel.0.weight"); + ggml_tensor* b0 = pk::clone_weight(ctx, ml_, "prompt_kernel.0.bias"); + ggml_tensor* W2 = pk::clone_weight(ctx, ml_, "prompt_kernel.2.weight"); + ggml_tensor* b2 = pk::clone_weight(ctx, ml_, "prompt_kernel.2.bias"); + ggml_tensor* h = ggml_mul_mat(ctx, W0, x); // [2D, T] + h = ggml_add(ctx, h, b0); // bias broadcasts over T + h = ggml_relu(ctx, h); + ggml_tensor* yv = ggml_mul_mat(ctx, W2, h); // [D, T] + yv = ggml_add(ctx, yv, b2); // bias broadcasts over T + return yv; // ne=[D, T] -> row-major [T, D] + }, y_td); + assert(ok && "prompt_kernel graph failed"); + (void)ok; + + // Transpose row-major [T, D] back to channels-first [d_model, T] so the + // result drops in for the raw encoder output the decoder consumes. + out.resize((size_t)d_model * T); + for (int t = 0; t < T; ++t) + for (int c = 0; c < d_model; ++c) + out[(size_t)c * T + t] = y_td[(size_t)t * d_model + c]; +} + +} // namespace pk diff --git a/src/prompt_kernel.hpp b/src/prompt_kernel.hpp new file mode 100644 index 0000000..14c6bfe --- /dev/null +++ b/src/prompt_kernel.hpp @@ -0,0 +1,37 @@ +#pragma once +#include "model_loader.hpp" +#include + +namespace pk { + +// Post-encoder prompt conditioning for multilingual nemotron models. +// +// Mirrors NeMo EncDecRNNTBPEModelWithPrompt.forward(): +// encoded[T, D] = transpose(encoder_out[D, T]) +// onehot[T, P] = one_hot(prompt_index, num_prompts) broadcast over T +// out[T, D] = prompt_kernel(cat([encoded, onehot])) // Linear->ReLU->Linear +// then transposed back to channels-first [D, T]. +// +// The one-hot is constant over time (one language per utterance), so this is a +// concat + two matmuls + ReLU. Weights: prompt_kernel.0.{weight,bias} (D+P->2D), +// prompt_kernel.2.{weight,bias} (2D->D). present() is false for non-prompt +// models (callers skip apply()). +class PromptKernel { +public: + explicit PromptKernel(const ModelLoader& ml); + bool present() const { return present_; } + + // Apply the prompt projection to a channels-first encoder output + // enc_out[d_model, T] for the given prompt_index, writing channels-first + // out[d_model, T]. If present()==false this is a no-op copy (out = enc_out). + void apply(const std::vector& enc_out, int d_model, int T, + int prompt_index, std::vector& out) const; + +private: + const ModelLoader& ml_; + bool present_ = false; + int num_prompts_ = 0; + int d_model_ = 0; +}; + +} // namespace pk diff --git a/src/streaming.cpp b/src/streaming.cpp index 2006295..5e4b4b0 100644 --- a/src/streaming.cpp +++ b/src/streaming.cpp @@ -6,11 +6,24 @@ namespace pk { -StreamingSession::StreamingSession(const ModelLoader& ml) - : ml_(ml), enc_(ml), pred_(ml), joint_(ml) { +StreamingSession::StreamingSession(const ModelLoader& ml, const std::string& target_lang) + : ml_(ml), enc_(ml), pred_(ml), joint_(ml), prompt_(ml) { const ParakeetConfig& cfg = ml.config(); d_model_ = (int)cfg.d_model; blank_id_ = (int)cfg.blank_id; + + // Resolve the language prompt index for multilingual (nemotron) models. The + // one-hot is constant over time (one language per utterance), so the prompt + // projection is applied per chunk in feed_mel_chunk using this fixed index. + // Non-prompt models leave prompt_index_ = -1 and skip prompt_.apply(). + if (cfg.prompt.present) { + // Empty target_lang -> the model default; an unknown locale THROWS + // std::runtime_error (same message as Model::resolve_prompt_index), so a + // typo (e.g. --lang xx) fails loudly instead of silently mis-transcribing. + // Matches the offline path and the parakeet_capi_stream_begin_lang + // contract (NULL + ctx last_error on an unknown locale). + prompt_index_ = cfg.prompt.resolve_index_or_throw(target_lang); + } // Greedy max symbols per frame, from model metadata (NeMo default 10); // matches the offline pk::transcribe path in model.cpp. max_symbols_ = (int)cfg.max_symbols; @@ -90,6 +103,26 @@ std::vector StreamingSession::feed_mel_chunk(const std::vector& return {}; } + // 1b. Prompt conditioning (nemotron multilingual): project the chunk's + // encoder frames through prompt_kernel for the resolved language before + // the RNN-T decode. The one-hot is constant over time, so applying it + // per chunk is exact (== the offline forward's single application). + // prompt_.apply() wants channels-first [d_model, valid]; enc_.step gives + // time-major [valid, d_model], so transpose in, apply, transpose back. + // No-op for non-prompt models (prompt_.present()==false): enc_frames is + // left byte-identical. + if (prompt_.present()) { + std::vector chunk_cf((size_t)d_model_ * n_valid); // [d_model, valid] + for (int t = 0; t < n_valid; ++t) + for (int c = 0; c < d_model_; ++c) + chunk_cf[(size_t)c * n_valid + t] = enc_frames[(size_t)t * d_model_ + c]; + std::vector projected; + prompt_.apply(chunk_cf, d_model_, n_valid, prompt_index_, projected); // [d_model, valid] + for (int t = 0; t < n_valid; ++t) + for (int c = 0; c < d_model_; ++c) + enc_frames[(size_t)t * d_model_ + c] = projected[(size_t)c * n_valid + t]; + } + // 2. RNN-T greedy over the new encoder frames, carrying the decoder state // across chunks (do NOT reset). Appends to state_.hyp and returns the ids // emitted in this chunk, with their LOCAL frame index in [0, n_valid) and @@ -188,7 +221,12 @@ void run_stream_over_pcm( const std::vector& pcm16k, const std::function&, - const std::vector&)>& on_chunk) { + const std::vector&)>& on_chunk, + const std::string& target_lang) { + // target_lang is intentionally unused here: the session already carries its + // resolved prompt index from construction. The parameter exists so callers + // can route a language through a single entry point (Phase 4 C-API/CLI). + (void)target_lang; // 1. Full-clip mel [n_mels, T] (feat-major inner=T), matching the offline / // NeMo online_normalization=False reference (normalization over the whole // clip). The streaming numerics come from the carried encoder/decoder diff --git a/src/streaming.hpp b/src/streaming.hpp index d29f2b2..083c2db 100644 --- a/src/streaming.hpp +++ b/src/streaming.hpp @@ -4,6 +4,7 @@ #include "prediction.hpp" #include "joint.hpp" #include "rnnt.hpp" +#include "prompt_kernel.hpp" #include "decode_types.hpp" #include "transcription.hpp" #include @@ -52,7 +53,13 @@ struct EouEvent { // already including pre-encode-cache overlap, matching test_streaming_encoder). class StreamingSession { public: - explicit StreamingSession(const ModelLoader& ml); + // `target_lang` selects the language prompt for multilingual (nemotron) + // prompt-conditioned models (e.g. "en", "de", "auto"); empty -> the model's + // default_lang. It is ignored by non-prompt models (prompt_.present()==false). + // For a prompt model an unknown locale THROWS std::runtime_error (matching the + // offline Model::resolve_prompt_index and the C-API stream_begin_lang + // contract), so a typo fails loudly rather than silently mis-transcribing. + explicit StreamingSession(const ModelLoader& ml, const std::string& target_lang = ""); // Reset the encoder caches AND the decoder state to a fresh stream. void reset(); @@ -126,6 +133,8 @@ class StreamingSession { StreamingEncoder enc_; PredictionNet pred_; Joint joint_; + PromptKernel prompt_; // post-encoder language conditioning (nemotron) + int prompt_index_ = -1; // resolved language prompt index (-1 if absent) int d_model_; int blank_id_; int max_symbols_; @@ -183,12 +192,17 @@ class StreamingSession { // it processes the supplied clip in one pass (the encoder/decoder are still // driven chunk-by-chunk with carried state — the streaming numerics, not a // real-time PCM feeder). +// `target_lang` is currently unused by the driver itself (the session already +// owns its resolved prompt index from construction); it is accepted so callers +// (C-API/CLI, Phase 4) can pass a language through one entry point. Defaults to +// "" (the session's configured language). void run_stream_over_pcm( StreamingSession& sess, const ModelLoader& ml, const std::vector& pcm16k, const std::function& chunk_events, const std::vector& chunk_words)>& on_chunk - = nullptr); + = nullptr, + const std::string& target_lang = ""); } // namespace pk diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 181ad91..71a3999 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -33,6 +33,7 @@ pk_add_test(test_prediction_step) pk_add_test(test_prediction_step_batch) pk_add_test(test_joint) pk_add_test(test_joint_step_batch) +pk_add_test(test_prompt_kernel) pk_add_test(test_transducer_core) pk_add_test(test_tdt_greedy) pk_add_test(test_transducer_greedy_batch) @@ -48,7 +49,9 @@ pk_add_test(test_transcribe_0_6b) pk_add_test(test_transcribe_ctc) pk_add_test(test_transcribe_rnnt) pk_add_test(test_transcribe_eou) +pk_add_test(test_transcribe_nemotron) pk_add_test(test_streaming_decode) +pk_add_test(test_streaming_nemotron) pk_add_test(test_streaming_mel) pk_add_test(test_capi) pk_add_test(test_capi_batch) @@ -59,27 +62,27 @@ set_tests_properties(test_model_loader test_mel test_mel_gpu test_subsampling te test_conformer test_conformer_batch test_conv_eou test_encoder test_encoder_batch test_encoder_batch_local test_encoder_eou test_streaming_encoder test_ctc test_prediction test_prediction_step test_prediction_step_batch - test_joint test_joint_step_batch test_transducer_core test_tdt_greedy + test_joint test_joint_step_batch test_prompt_kernel test_transducer_core test_tdt_greedy test_transducer_greedy_batch test_transducer_greedy_batch_rnnt test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b - test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou - test_streaming_decode test_streaming_mel test_capi test_capi_batch test_capi_stream + test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou test_transcribe_nemotron + test_streaming_decode test_streaming_nemotron test_streaming_mel test_capi test_capi_batch test_capi_stream test_capi_timestamps test_capi_batch_json PROPERTIES LABELS "model") # These tests read fixtures/baselines via paths relative to the project root. set_tests_properties(test_mel test_mel_gpu test_subsampling test_subsampling_batch test_relpos_attention test_relpos_attention_batch test_conformer test_conformer_batch test_conv_eou test_encoder test_encoder_batch test_encoder_batch_local test_encoder_eou test_streaming_encoder test_ctc test_prediction test_prediction_step test_prediction_step_batch - test_joint test_joint_step_batch + test_joint test_joint_step_batch test_prompt_kernel test_transducer_core test_tdt_greedy test_transducer_greedy_batch test_transducer_greedy_batch_rnnt test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b - test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou - test_streaming_decode test_streaming_mel test_capi test_capi_batch test_capi_stream + test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou test_transcribe_nemotron + test_streaming_decode test_streaming_nemotron test_streaming_mel test_capi test_capi_batch test_capi_stream test_capi_timestamps test_capi_batch_json PROPERTIES WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}) diff --git a/tests/parity.hpp b/tests/parity.hpp index 1664843..c42a474 100644 --- a/tests/parity.hpp +++ b/tests/parity.hpp @@ -94,6 +94,17 @@ inline bool load_baseline_i32(const std::string& path, const std::string& name, return true; } +// Read a uint32 KV entry from a baseline gguf (0 if absent / unopenable). +inline uint32_t pktest_read_u32(const std::string& path, const std::string& key) { + gguf_init_params p{ /*no_alloc=*/true, /*ctx=*/nullptr }; + gguf_context* g = gguf_init_from_file(path.c_str(), p); + if (!g) return 0; + int64_t id = gguf_find_key(g, key.c_str()); + uint32_t v = (id < 0) ? 0u : gguf_get_val_u32(g, id); + gguf_free(g); + return v; +} + // Load a string KV entry from a baseline gguf. inline bool load_kv_str(const std::string& path, const std::string& key, std::string& out) { diff --git a/tests/test_capi.cpp b/tests/test_capi.cpp index e00a148..43b81e4 100644 --- a/tests/test_capi.cpp +++ b/tests/test_capi.cpp @@ -15,7 +15,9 @@ // WORKING_DIRECTORY (tests run from the project root; wav path is relative) // // Env: -// PARAKEET_TEST_GGUF model weights (skip 77 if unset) +// PARAKEET_TEST_GGUF model weights (skip 77 if unset) +// PARAKEET_TEST_GGUF_NEMOTRON prompt (multilingual) model; if set, also +// exercises the target_lang C-API variants static const char* kExpected = "Well, I don't wish to see it any more, observed Phoebe, turning away her " @@ -36,39 +38,128 @@ int main() { return 1; } + // The 110m anchor (PARAKEET_TEST_GGUF) and the prompt/multilingual model + // (PARAKEET_TEST_GGUF_NEMOTRON) are independent: each block runs only when + // its env var is set. If NEITHER is set the test skips (77). + bool ran_any = false; + const char* gguf = std::getenv("PARAKEET_TEST_GGUF"); - if (!gguf) { - std::fprintf(stderr, "test_capi: PARAKEET_TEST_GGUF not set; skip\n"); - return 77; - } + if (gguf) { + ran_any = true; + parakeet_ctx* ctx = parakeet_capi_load(gguf); + if (!ctx) { + std::fprintf(stderr, "test_capi: parakeet_capi_load failed for %s\n", gguf); + return 1; + } - parakeet_ctx* ctx = parakeet_capi_load(gguf); - if (!ctx) { - std::fprintf(stderr, "test_capi: parakeet_capi_load failed for %s\n", gguf); - return 1; - } + // decoder == 2 -> TDT/transducer head. + char* text = parakeet_capi_transcribe_path(ctx, "tests/fixtures/speech.wav", 2); + if (!text) { + std::fprintf(stderr, "test_capi: transcribe_path returned NULL: %s\n", + parakeet_capi_last_error(ctx)); + parakeet_capi_free(ctx); + return 1; + } + + std::fprintf(stderr, "test_capi: got = %s\n", text); + std::fprintf(stderr, "test_capi: expected = %s\n", kExpected); - // decoder == 2 -> TDT/transducer head. - char* text = parakeet_capi_transcribe_path(ctx, "tests/fixtures/speech.wav", 2); - if (!text) { - std::fprintf(stderr, "test_capi: transcribe_path returned NULL: %s\n", - parakeet_capi_last_error(ctx)); + const bool match = std::strcmp(text, kExpected) == 0; + parakeet_capi_free_string(text); parakeet_capi_free(ctx); - return 1; + + if (!match) { + std::fprintf(stderr, "test_capi: MISMATCH vs NeMo TDT reference\n"); + return 1; + } + std::fprintf(stderr, "test_capi: PASS (word-for-word match with NeMo TDT)\n"); + } else { + std::fprintf(stderr, "test_capi: PARAKEET_TEST_GGUF not set; skip anchor block\n"); } - std::fprintf(stderr, "test_capi: got = %s\n", text); - std::fprintf(stderr, "test_capi: expected = %s\n", kExpected); + // Prompt (multilingual) model: exercise the target_lang variants. Skipped + // cleanly when PARAKEET_TEST_GGUF_NEMOTRON is unset. + const char* nemotron = std::getenv("PARAKEET_TEST_GGUF_NEMOTRON"); + if (nemotron) { + ran_any = true; + parakeet_ctx* nctx = parakeet_capi_load(nemotron); + if (!nctx) { + std::fprintf(stderr, "test_capi: load failed for nemotron %s\n", nemotron); + return 1; + } - const bool match = std::strcmp(text, kExpected) == 0; - parakeet_capi_free_string(text); - parakeet_capi_free(ctx); + // A known language prompt must transcribe (non-NULL). + char* de = parakeet_capi_transcribe_path_lang( + nctx, "tests/fixtures/speech.wav", 0, "de"); + if (!de) { + std::fprintf(stderr, "test_capi: transcribe_path_lang(de) returned NULL: %s\n", + parakeet_capi_last_error(nctx)); + parakeet_capi_free(nctx); + return 1; + } + std::fprintf(stderr, "test_capi: nemotron de = %s\n", de); + parakeet_capi_free_string(de); - if (!match) { - std::fprintf(stderr, "test_capi: MISMATCH vs NeMo TDT reference\n"); - return 1; + // An unknown locale must fail cleanly: NULL + non-empty last_error. + char* bad_lang = parakeet_capi_transcribe_path_lang( + nctx, "tests/fixtures/speech.wav", 0, "zzz"); + if (bad_lang != nullptr) { + std::fprintf(stderr, "test_capi: transcribe_path_lang(zzz) returned non-NULL\n"); + parakeet_capi_free_string(bad_lang); + parakeet_capi_free(nctx); + return 1; + } + const char* err = parakeet_capi_last_error(nctx); + if (!err || err[0] == '\0') { + std::fprintf(stderr, "test_capi: unknown-lang did not set last_error\n"); + parakeet_capi_free(nctx); + return 1; + } + std::fprintf(stderr, "test_capi: nemotron unknown-lang error = %s\n", err); + + // Streaming path must reject an unknown locale exactly like the offline + // path: NULL + non-empty last_error (no silent fallback to the default). + parakeet_stream* bad_stream = parakeet_capi_stream_begin_lang(nctx, "zzz"); + if (bad_stream != nullptr) { + std::fprintf(stderr, + "test_capi: stream_begin_lang(zzz) returned non-NULL\n"); + parakeet_capi_stream_free(bad_stream); + parakeet_capi_free(nctx); + return 1; + } + const char* serr = parakeet_capi_last_error(nctx); + if (!serr || serr[0] == '\0') { + std::fprintf(stderr, + "test_capi: stream unknown-lang did not set last_error\n"); + parakeet_capi_free(nctx); + return 1; + } + std::fprintf(stderr, "test_capi: nemotron stream unknown-lang error = %s\n", + serr); + + // A known language prompt must begin a stream (non-NULL); free it. + parakeet_stream* ok_stream = parakeet_capi_stream_begin_lang(nctx, "en"); + if (!ok_stream) { + std::fprintf(stderr, + "test_capi: stream_begin_lang(en) returned NULL: %s\n", + parakeet_capi_last_error(nctx)); + parakeet_capi_free(nctx); + return 1; + } + parakeet_capi_stream_free(ok_stream); + + parakeet_capi_free(nctx); + std::fprintf(stderr, "test_capi: PASS nemotron target_lang variants\n"); + } else { + std::fprintf(stderr, + "test_capi: PARAKEET_TEST_GGUF_NEMOTRON not set; skip prompt block\n"); } - std::fprintf(stderr, "test_capi: PASS (word-for-word match with NeMo TDT)\n"); + if (!ran_any) { + std::fprintf(stderr, + "test_capi: no model env var set (PARAKEET_TEST_GGUF / " + "PARAKEET_TEST_GGUF_NEMOTRON); skip\n"); + return 77; + } return 0; } diff --git a/tests/test_model_loader.cpp b/tests/test_model_loader.cpp index a3912d9..04df63b 100644 --- a/tests/test_model_loader.cpp +++ b/tests/test_model_loader.cpp @@ -5,21 +5,45 @@ int main() { const char* env = std::getenv("PARAKEET_TEST_GGUF"); - if (!env) { std::fprintf(stderr, "PARAKEET_TEST_GGUF not set; skipping\n"); return 77; } - pk::ModelLoader ml; - if (!ml.load(env)) { std::fprintf(stderr, "load failed\n"); return 1; } - const pk::ParakeetConfig& c = ml.config(); - if (c.arch.empty()) { std::fprintf(stderr, "empty arch\n"); return 1; } - if (c.d_model == 0 || c.n_layers == 0 || c.n_heads == 0) { std::fprintf(stderr, "bad encoder dims\n"); return 1; } - if (c.vocab_size == 0) { std::fprintf(stderr, "bad vocab\n"); return 1; } - if (c.blank_id != c.vocab_size) { std::fprintf(stderr, "blank!=vocab\n"); return 1; } - // mel filterbank tensor must be present - if (ml.tensor("preprocessor.featurizer.fb") == nullptr) { std::fprintf(stderr, "no fb\n"); return 1; } - // first conformer layer norm must be present (verbatim name) - if (ml.tensor("encoder.layers.0.norm_feed_forward1.weight") == nullptr) { - std::fprintf(stderr, "no layer0 norm\n"); return 1; + const char* npath = std::getenv("PARAKEET_TEST_GGUF_NEMOTRON"); + if (!env && !npath) { + std::fprintf(stderr, "PARAKEET_TEST_GGUF / PARAKEET_TEST_GGUF_NEMOTRON not set; skipping\n"); + return 77; + } + + // Base model checks (only when PARAKEET_TEST_GGUF points at a fixture). + if (env) { + pk::ModelLoader ml; + if (!ml.load(env)) { std::fprintf(stderr, "load failed\n"); return 1; } + const pk::ParakeetConfig& c = ml.config(); + if (c.arch.empty()) { std::fprintf(stderr, "empty arch\n"); return 1; } + if (c.d_model == 0 || c.n_layers == 0 || c.n_heads == 0) { std::fprintf(stderr, "bad encoder dims\n"); return 1; } + if (c.vocab_size == 0) { std::fprintf(stderr, "bad vocab\n"); return 1; } + if (c.blank_id != c.vocab_size) { std::fprintf(stderr, "blank!=vocab\n"); return 1; } + // mel filterbank tensor must be present + if (ml.tensor("preprocessor.featurizer.fb") == nullptr) { std::fprintf(stderr, "no fb\n"); return 1; } + // first conformer layer norm must be present (verbatim name) + if (ml.tensor("encoder.layers.0.norm_feed_forward1.weight") == nullptr) { + std::fprintf(stderr, "no layer0 norm\n"); return 1; + } + std::printf("loader ok: arch=%s d_model=%u layers=%u heads=%u vocab=%u\n", + c.arch.c_str(), c.d_model, c.n_layers, c.n_heads, c.vocab_size); + } + + // Prompt-conditioning config (nemotron). Runs whenever the fixture is set, + // independently of PARAKEET_TEST_GGUF. + if (npath) { + pk::ModelLoader nl; + if (!nl.load(npath)) { std::fprintf(stderr, "load nemotron failed\n"); return 1; } + const pk::ParakeetConfig& nc = nl.config(); + if (!nc.prompt.present) { std::fprintf(stderr, "prompt.present false\n"); return 1; } + if (nc.prompt.num_prompts != 128) { std::fprintf(stderr, "num_prompts!=128\n"); return 1; } + if (nc.prompt.default_lang != "auto") { std::fprintf(stderr, "default_lang!=auto\n"); return 1; } + if (nc.prompt.lang_to_index("de") != 9) { std::fprintf(stderr, "de!=9\n"); return 1; } + if (nc.prompt.lang_to_index("auto") != 101) { std::fprintf(stderr, "auto!=101\n"); return 1; } + if (nc.prompt.lang_to_index("zzz") != -1) { std::fprintf(stderr, "unknown!=-1\n"); return 1; } + if (nc.use_bias) { std::fprintf(stderr, "use_bias should be false\n"); return 1; } + std::fprintf(stderr, "nemotron prompt config OK\n"); } - std::printf("loader ok: arch=%s d_model=%u layers=%u heads=%u vocab=%u\n", - c.arch.c_str(), c.d_model, c.n_layers, c.n_heads, c.vocab_size); return 0; } diff --git a/tests/test_prompt_kernel.cpp b/tests/test_prompt_kernel.cpp new file mode 100644 index 0000000..70f8744 --- /dev/null +++ b/tests/test_prompt_kernel.cpp @@ -0,0 +1,49 @@ +#include "model_loader.hpp" +#include "backend.hpp" +#include "prompt_kernel.hpp" +#include "parity.hpp" +#include +#include +#include + +// Parity: PromptKernel(encoder_out, prompt_index) must match NeMo's +// prompt_kernel(cat([encoded, onehot])). Skips (77) unless both the converted +// nemotron gguf and the prompt baseline are provided. +// PARAKEET_TEST_GGUF_NEMOTRON converted nemotron gguf +// PARAKEET_TEST_BASELINE_NEMOTRON prompt baseline gguf (encoder_out [D,T], +// prompt_kernel_out [T,D], baseline.prompt_index) +int main() { + const char* gguf = std::getenv("PARAKEET_TEST_GGUF_NEMOTRON"); + const char* base = std::getenv("PARAKEET_TEST_BASELINE_NEMOTRON"); + if (!gguf || !base) { + std::fprintf(stderr, "test_prompt_kernel: fixtures not set; skip\n"); + return 77; + } + pk::ModelLoader ml; + if (!ml.load(gguf)) { std::fprintf(stderr, "load failed\n"); return 1; } + pk::ensure_weights_realized(ml); + + // encoder_out [D,T] (channels-first) + reference prompt_kernel_out [T,D]. + std::vector enc; std::vector enc_shape; + std::vector ref; std::vector ref_shape; + if (!pktest::load_baseline(base, "encoder_out", enc, enc_shape)) return 1; + if (!pktest::load_baseline(base, "prompt_kernel_out", ref, ref_shape)) return 1; + const int D = (int)enc_shape[0], T = (int)enc_shape[1]; + const int prompt_index = (int)pktest::pktest_read_u32(base, "baseline.prompt_index"); + std::fprintf(stderr, "[prompt_kernel] D=%d T=%d prompt_index=%d\n", D, T, prompt_index); + + pk::PromptKernel pkmod(ml); + if (!pkmod.present()) { std::fprintf(stderr, "prompt not present in model\n"); return 1; } + + std::vector got; // channels-first [D, T] + pkmod.apply(enc, D, T, prompt_index, got); + + // Transpose got [D,T] -> [T,D] to compare with ref [T,D]. + std::vector got_td((size_t)T * D); + for (int t = 0; t < T; ++t) + for (int c = 0; c < D; ++c) + got_td[(size_t)t * D + c] = got[(size_t)c * T + t]; + + bool ok = pktest::compare(got_td, ref, "prompt_kernel", 2e-3f, 2e-3f); + return ok ? 0 : 1; +} diff --git a/tests/test_streaming_nemotron.cpp b/tests/test_streaming_nemotron.cpp new file mode 100644 index 0000000..ecfa463 --- /dev/null +++ b/tests/test_streaming_nemotron.cpp @@ -0,0 +1,83 @@ +#include "model.hpp" +#include "streaming.hpp" +#include "audio_io.hpp" +#include "parity.hpp" +#include +#include +#include + +// MODEL: nvidia nemotron-3.5-asr-streaming-0.6b (prompt-conditioned, cache-aware +// streaming FastConformer RNN-T). +// WORKING_DIRECTORY: the repo root (build/tests run from there). +// +// End-to-end STREAMING parity WITH the language prompt (Phase 3, Task 3.2). The +// C++ pk::StreamingSession drives the cache-aware streaming encoder, applies the +// PromptKernel per chunk for the baseline's target_lang, and decodes the RNN-T +// greedy carrying state across chunks. The running transcript (sess.text(), with +// / stripped) must equal NeMo's OWN cache-aware streaming transcript +// for the SAME language (baseline.stream_text), produced by gen_nemo_baseline's +// dump_prompt_baseline: NeMo streams the encoder, applies m.prompt_kernel to the +// concatenated streamed output, and RNN-T-greedy-decodes it. +// +// By the cache-aware equivalence property (test_streaming_decode) the per-chunk +// decode with carried state == whole-streamed-output decode, and the per-frame +// prompt one-hot is constant over time, so per-chunk prompt application == single +// application — hence the C++ streaming transcript must match NeMo's EXACTLY. +// +// Skips (77) unless set: +// PARAKEET_TEST_GGUF_NEMOTRON converted nemotron gguf +// PARAKEET_TEST_BASELINE_NEMOTRON prompt baseline (baseline.stream_text, +// baseline.target_lang) +// PARAKEET_TEST_NEMOTRON_WAV the clip used for the baseline +// (default tests/fixtures/speech.wav) +int main() { + const char* gguf = std::getenv("PARAKEET_TEST_GGUF_NEMOTRON"); + const char* base = std::getenv("PARAKEET_TEST_BASELINE_NEMOTRON"); + if (!gguf || !base) { + std::fprintf(stderr, + "test_streaming_nemotron: PARAKEET_TEST_GGUF_NEMOTRON and/or " + "PARAKEET_TEST_BASELINE_NEMOTRON not set; skip\n"); + return 77; + } + const char* wav = std::getenv("PARAKEET_TEST_NEMOTRON_WAV"); + std::string wav_path = wav ? wav : "tests/fixtures/speech.wav"; + + std::string lang, ref; + if (!pktest::load_kv_str(base, "baseline.target_lang", lang)) return 1; + if (!pktest::load_kv_str(base, "baseline.stream_text", ref)) { + std::fprintf(stderr, + "[stream_nemotron] baseline.stream_text not found in %s " + "(regenerate with gen_nemo_baseline.py --lang)\n", base); + return 1; + } + + auto m = pk::Model::load(gguf); + if (!m) { std::fprintf(stderr, "[stream_nemotron] load failed %s\n", gguf); return 1; } + if (!m->config().prompt.present) { + std::fprintf(stderr, "[stream_nemotron] model is not prompt-conditioned\n"); + return 1; + } + if (!m->config().streaming.present) { + std::fprintf(stderr, "[stream_nemotron] model has no streaming config\n"); + return 1; + } + + pk::Audio a; + if (!pk::load_audio_16k_mono(wav_path, a)) { + std::fprintf(stderr, "[stream_nemotron] audio load failed %s\n", wav_path.c_str()); + return 1; + } + + pk::StreamingSession sess(m->loader(), lang); + pk::run_stream_over_pcm(sess, m->loader(), a.samples); + std::string got = sess.text(); + + std::fprintf(stderr, "lang=%s\n got=%s\n ref=%s\n", + lang.c_str(), got.c_str(), ref.c_str()); + if (got != ref) { + std::fprintf(stderr, "[stream_nemotron] MISMATCH\n"); + return 1; + } + std::fprintf(stderr, "nemotron streaming parity OK\n"); + return 0; +} diff --git a/tests/test_transcribe_nemotron.cpp b/tests/test_transcribe_nemotron.cpp new file mode 100644 index 0000000..1d00bee --- /dev/null +++ b/tests/test_transcribe_nemotron.cpp @@ -0,0 +1,34 @@ +#include "model.hpp" +#include "audio_io.hpp" +#include "parity.hpp" +#include +#include +#include + +// End-to-end: the C++ nemotron transcript must equal NeMo's transcript for the +// SAME target_lang. Skips (77) unless set: +// PARAKEET_TEST_GGUF_NEMOTRON converted nemotron gguf +// PARAKEET_TEST_BASELINE_NEMOTRON prompt baseline (baseline.rnnt_text, baseline.target_lang) +// PARAKEET_TEST_NEMOTRON_WAV the clip used for the baseline (default tests/fixtures/speech.wav) +int main() { + const char* gguf = std::getenv("PARAKEET_TEST_GGUF_NEMOTRON"); + const char* base = std::getenv("PARAKEET_TEST_BASELINE_NEMOTRON"); + if (!gguf || !base) { std::fprintf(stderr, "fixtures not set; skip\n"); return 77; } + const char* wav = std::getenv("PARAKEET_TEST_NEMOTRON_WAV"); + std::string wav_path = wav ? wav : "tests/fixtures/speech.wav"; + + std::string lang, ref; + if (!pktest::load_kv_str(base, "baseline.target_lang", lang)) return 1; + if (!pktest::load_kv_str(base, "baseline.rnnt_text", ref)) return 1; + + auto m = pk::Model::load(gguf); + if (!m) { std::fprintf(stderr, "load failed\n"); return 1; } + pk::Audio a; + if (!pk::load_audio_16k_mono(wav_path, a)) { std::fprintf(stderr, "audio load failed\n"); return 1; } + + std::string got = m->transcribe_16k(a.samples, pk::Decoder::kDefault, lang); + std::fprintf(stderr, "lang=%s\n got=%s\n ref=%s\n", lang.c_str(), got.c_str(), ref.c_str()); + if (got != ref) { std::fprintf(stderr, "MISMATCH\n"); return 1; } + std::fprintf(stderr, "nemotron offline parity OK\n"); + return 0; +}