-
Notifications
You must be signed in to change notification settings - Fork 1
[WIP] Add NeMo Conformer TDT ASR support #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d017602
964bc8f
fa9bc25
63aeee8
f6835ad
2dd36a1
10977df
39d9be4
3d984e5
9f3a220
3bac1dc
c75ebd2
493a588
5b4cdab
7690227
1f065c3
ec09a09
8a90a7c
ce0a3eb
5d91d39
dfc2c13
a5bd2cf
abada62
62d8bc0
03fb8bd
426061e
d7476a6
0989f7a
49a4af8
ee819a1
b44f7f3
bfa97e6
a85dff2
8dfccdd
816f581
f59ba06
00b3d93
07118c3
341df3d
29f2baa
39e5cb1
495bab5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,266 @@ | ||
| import { FeatureExtractor, validate_audio_inputs } from '../../feature_extraction_utils.js'; | ||
| import { Tensor } from '../../utils/tensor.js'; | ||
| import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js'; | ||
| import { logger } from '../../utils/logger.js'; | ||
| import { FeatureLRUCache, createAudioCacheKey } from './transducer_cache.js'; | ||
| import { computeTemporalDeltas } from './transducer_deltas.js'; | ||
|
|
||
// Added to each per-bin standard deviation before inverting it during feature normalization.
const EPSILON = 1e-5;

// Symbol key holding the "cache owns these tensors" flag on a feature-output object.
export const NEMO_FEATURE_OUTPUT_OWNERSHIP = Symbol('NemoConformerTDTFeatureOutputOwnership');
// Symbol key holding the optional release callback on a feature-output object.
export const NEMO_FEATURE_OUTPUT_RELEASE = Symbol('NemoConformerTDTFeatureOutputRelease');

/**
 * Attaches ownership metadata to a feature-output object as non-enumerable,
 * configurable symbol-keyed properties, so the metadata does not show up in
 * `Object.keys`, object spread, or JSON serialization.
 *
 * @param {Object} value Object to tag; it is modified in place.
 * @param {boolean} cacheOwnsTensors Stored under `NEMO_FEATURE_OUTPUT_OWNERSHIP`.
 * @param {Function|null} [release=null] Stored under `NEMO_FEATURE_OUTPUT_RELEASE` when truthy.
 * @returns {Object} The same `value` instance that was passed in.
 */
function tagNemoFeatureOutputOwnership(value, cacheOwnsTensors, release = null) {
    // The ownership flag is always written; the release hook only when one was supplied.
    const hiddenEntries = [[NEMO_FEATURE_OUTPUT_OWNERSHIP, cacheOwnsTensors]];
    if (release) {
        hiddenEntries.push([NEMO_FEATURE_OUTPUT_RELEASE, release]);
    }

    for (const [symbolKey, payload] of hiddenEntries) {
        Object.defineProperty(value, symbolKey, {
            value: payload,
            enumerable: false,
            configurable: true,
        });
    }
    return value;
}
|
|
||
/**
 * Feature extractor for Nemo Conformer TDT models.
 *
 * Mirrors NeMo-style log-mel extraction used by Parakeet with configurable
 * `feature_size` (e.g. 80 or 128 mel bins via `preprocessor_config.json`).
 *
 * Optional config keys read by this class:
 * - `use_feature_cache` (default `false`), `feature_cache_max_entries` (default 128),
 *   `feature_cache_max_size_mb` (default 64): enable and size a `FeatureLRUCache`.
 * - `delta_order` (0, 1 or 2; default 0), `delta_window` (default 2),
 *   `delta_concatenate` (default `true`): temporal delta post-processing.
 * - `preemphasis` (default 0): pre-emphasis coefficient in [0, 1).
 */
export class NemoConformerTDTFeatureExtractor extends FeatureExtractor {
    /**
     * @param {Object} config Preprocessor configuration; it is passed to the base class
     * and then read back through `this.config`.
     * @throws {Error} If `n_fft`, `win_length`, `hop_length`, `delta_order` or `delta_window`
     * is invalid, or if `feature_size` / `sampling_rate` is invalid while the mel filter
     * bank has to be computed here.
     */
    constructor(config) {
        super(config);

        if (!Number.isInteger(this.config.n_fft) || this.config.n_fft <= 0) {
            throw new Error(
                `NemoConformerTDTFeatureExtractor expected \`n_fft\` as a positive integer, got ${this.config.n_fft}.`,
            );
        }
        if (
            !Number.isInteger(this.config.win_length) ||
            this.config.win_length <= 0 ||
            this.config.win_length > this.config.n_fft
        ) {
            throw new Error(
                `NemoConformerTDTFeatureExtractor expected \`win_length\` in [1, n_fft], got win_length=${this.config.win_length}, n_fft=${this.config.n_fft}.`,
            );
        }
        // `hop_length` is the spectrogram stride and the divisor of the frame-count
        // computation in `_extract`, so reject it up front instead of failing later.
        if (!Number.isInteger(this.config.hop_length) || this.config.hop_length <= 0) {
            throw new Error(
                `NemoConformerTDTFeatureExtractor expected \`hop_length\` as a positive integer, got ${this.config.hop_length}.`,
            );
        }

        // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist.
        if (this.config.mel_filters == null) {
            // These two values are only validated when they feed `mel_filter_bank`, so that
            // configs shipping precomputed `mel_filters` keep working unchanged.
            if (!Number.isInteger(this.config.feature_size) || this.config.feature_size <= 0) {
                throw new Error(
                    `NemoConformerTDTFeatureExtractor expected \`feature_size\` as a positive integer, got ${this.config.feature_size}.`,
                );
            }
            if (!Number.isFinite(this.config.sampling_rate) || this.config.sampling_rate <= 0) {
                throw new Error(
                    `NemoConformerTDTFeatureExtractor expected \`sampling_rate\` as a positive number, got ${this.config.sampling_rate}.`,
                );
            }
            this.config.mel_filters = mel_filter_bank(
                Math.floor(1 + this.config.n_fft / 2), // num_frequency_bins
                this.config.feature_size, // num_mel_filters
                0.0, // min_frequency
                this.config.sampling_rate / 2, // max_frequency
                this.config.sampling_rate, // sampling_rate
                'slaney', // norm
                'slaney', // mel_scale
            );
        }

        const window = window_function(this.config.win_length, 'hann', {
            periodic: false,
        });

        // Centre the `win_length`-sample window inside a zero-filled buffer of `n_fft` samples.
        this.window = new Float64Array(this.config.n_fft);
        const offset = Math.floor((this.config.n_fft - this.config.win_length) / 2);
        this.window.set(window, offset);

        // Optional feature-level cache and delta/delta-delta post-processing.
        this.use_feature_cache = this.config.use_feature_cache ?? false;
        this.delta_order = this.config.delta_order ?? 0;
        this.delta_window = this.config.delta_window ?? 2;
        this.delta_concatenate = this.config.delta_concatenate ?? true;

        if (![0, 1, 2].includes(this.delta_order)) {
            throw new Error(
                `NemoConformerTDTFeatureExtractor expected delta_order in {0,1,2}, got ${this.delta_order}.`,
            );
        }
        if (!Number.isInteger(this.delta_window) || this.delta_window < 1) {
            throw new Error(
                `NemoConformerTDTFeatureExtractor expected \`delta_window\` as a positive integer, got ${this.delta_window}.`,
            );
        }
        if (this.delta_order > 0 && !this.delta_concatenate) {
            logger.warn(
                'NemoConformerTDTFeatureExtractor: `delta_concatenate=false` is set. ' +
                    '`input_features` will remain base features and deltas are returned in extra fields.',
            );
        }

        this.feature_cache = this.use_feature_cache
            ? new FeatureLRUCache({
                  max_entries: this.config.feature_cache_max_entries ?? 128,
                  max_size_mb: this.config.feature_cache_max_size_mb ?? 64,
              })
            : null;
    }

    /**
     * Computes the log-Mel spectrogram of the provided audio waveform.
     * @param {Float32Array|Float64Array} waveform The audio waveform to process. It is copied, not modified.
     * @returns {Promise<Tensor>} The tensor produced by `spectrogram` (requested with `transpose: true`).
     * @throws {Error} If `config.preemphasis` is not a finite number in [0, 1).
     */
    async _extract_fbank_features(waveform) {
        // Parakeet uses a custom preemphasis strategy: Apply preemphasis to entire waveform at once
        const preemphasis = this.config.preemphasis ?? 0;
        if (!Number.isFinite(preemphasis) || preemphasis < 0 || preemphasis >= 1) {
            throw new Error(
                `NemoConformerTDTFeatureExtractor expected \`preemphasis\` in [0, 1), got ${this.config.preemphasis}.`,
            );
        }

        const samples = new Float64Array(waveform); // Clone to avoid destructive changes
        if (preemphasis !== 0) {
            // Walk backwards so each sample still sees the unmodified previous sample.
            for (let j = samples.length - 1; j >= 1; --j) {
                samples[j] -= preemphasis * samples[j - 1];
            }
        }

        const features = await spectrogram(
            samples,
            this.window, // window
            this.window.length, // frame_length
            this.config.hop_length, // hop_length
            {
                fft_length: this.config.n_fft,
                power: 2.0,
                mel_filters: this.config.mel_filters,
                log_mel: 'log',
                mel_floor: -Infinity,
                pad_mode: 'constant',
                center: true,

                // Custom
                transpose: true,
                mel_offset: 2 ** -24,
            },
        );

        return features;
    }

    /**
     * Asynchronously extracts features from a given audio using the provided configuration.
     * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
     * @returns {Promise<{
     *   input_features: Tensor;
     *   attention_mask: Tensor;
     *   delta_features?: Tensor;
     *   delta_delta_features?: Tensor;
     * }>} A Promise resolving to an object containing extracted model inputs.
     * The object carries `NEMO_FEATURE_OUTPUT_OWNERSHIP` (and, for cache-owned results,
     * `NEMO_FEATURE_OUTPUT_RELEASE`) as non-enumerable symbol properties.
     * When cache is enabled, tensor instances are shared and owned by the cache.
     * Do not mutate or dispose returned tensors unless cache is disabled/cleared.
     */
    async _call(audio) {
        validate_audio_inputs(audio, 'NemoConformerTDTFeatureExtractor');

        if (!this.feature_cache) {
            return tagNemoFeatureOutputOwnership(await this._extract(audio), false);
        }

        // The delta settings are part of the key because they change the produced features.
        const key = `${createAudioCacheKey(audio, this.config.sampling_rate)}:${this.delta_order}:${this.delta_window}:${this.delta_concatenate}`;
        const cached = this.feature_cache.acquire(key);
        if (cached) {
            return tagNemoFeatureOutputOwnership({ ...cached.value }, true, cached.release);
        }

        const extracted = await this._extract(audio);
        const cacheOwnsTensors = this.feature_cache.set(key, extracted);
        if (!cacheOwnsTensors) {
            // NOTE(review): a reviewer reported that `FeatureLRUCache.set()` may already have
            // evicted and disposed `extracted` before returning false, in which case the
            // tensors handed back here are disposed. This cannot be confirmed from this
            // file - verify against transducer_cache.js.
            return tagNemoFeatureOutputOwnership({ ...extracted }, false);
        }

        const borrowed = this.feature_cache.acquire(key);
        if (!borrowed) {
            return tagNemoFeatureOutputOwnership({ ...extracted }, false);
        }
        return tagNemoFeatureOutputOwnership({ ...borrowed.value }, true, borrowed.release);
    }

    /**
     * Normalizes each feature column in place to zero mean and unit standard deviation,
     * using only the first `features_length` frames. Uses the `n - 1` divisor when more
     * than one frame is available, and adds `EPSILON` to the standard deviation.
     * Does nothing when `features_length` is 0, which avoids NaN from divide-by-zero.
     *
     * @param {Float32Array} features_data Frame-major data (`num_features` values per frame).
     * @param {number} features_length Number of leading frames that hold real (non-padding) data.
     * @param {number} num_features Number of values per frame.
     * @returns {void}
     */
    _normalize_features(features_data, features_length, num_features) {
        if (features_length <= 0) {
            return;
        }

        const sum = new Float64Array(num_features);
        const sum_sq = new Float64Array(num_features);
        for (let i = 0; i < features_length; ++i) {
            const offset = i * num_features;
            for (let j = 0; j < num_features; ++j) {
                const val = features_data[offset + j];
                sum[j] += val;
                sum_sq[j] += val * val;
            }
        }

        const divisor = features_length > 1 ? features_length - 1 : 1;
        for (let j = 0; j < num_features; ++j) {
            const mean = sum[j] / features_length;
            const variance = (sum_sq[j] - features_length * mean * mean) / divisor;
            // Clamp at zero: rounding can make the computed variance slightly negative.
            const std = Math.sqrt(Math.max(variance, 0)) + EPSILON;
            const inv_std = 1 / std;

            for (let i = 0; i < features_length; ++i) {
                const index = i * num_features + j;
                features_data[index] = (features_data[index] - mean) * inv_std;
            }
        }
    }

    /**
     * Runs the full extraction for one clip: log-mel features, zeroing of padding frames,
     * per-feature normalization, attention mask, and optional temporal deltas.
     *
     * @param {Float32Array|Float64Array} audio The audio samples.
     * @returns {Promise<{
     *   input_features: Tensor;
     *   attention_mask: Tensor;
     *   delta_features?: Tensor;
     *   delta_delta_features?: Tensor;
     * }>} `input_features` has a leading dimension of 1 added; `attention_mask` is an int64
     * tensor of dims `[1, num_frames]` holding 1 for valid frames and 0 for padding.
     */
    async _extract(audio) {
        const features = await this._extract_fbank_features(audio);

        const [num_frames, num_features] = features.dims;
        const raw_features_length = Math.floor(
            (audio.length + Math.floor(this.config.n_fft / 2) * 2 - this.config.n_fft) / this.config.hop_length,
        );
        // Clamp to [0, num_frames] to avoid a negative fill offset for very short clips.
        const features_length = Math.max(0, Math.min(num_frames, raw_features_length));

        // Zero every frame past the valid length so padding does not leak into the output.
        const features_data = /** @type {Float32Array} */ (features.data);
        features_data.fill(0, features_length * num_features);

        // normalize mel features, ignoring padding
        this._normalize_features(features_data, features_length, num_features);

        const mask_data = new BigInt64Array(num_frames);
        mask_data.fill(1n, 0, features_length);

        let input_features = features.unsqueeze_(0);
        const attention_mask = new Tensor('int64', mask_data, [1, num_frames]);

        const result = {
            input_features,
            attention_mask,
        };

        if (this.delta_order > 0) {
            const delta_result = computeTemporalDeltas(input_features, {
                order: this.delta_order,
                window: this.delta_window,
                concatenate: this.delta_concatenate,
            });
            if (delta_result instanceof Tensor) {
                // A single tensor replaces the base features, so the base tensor is disposed.
                input_features.dispose();
                input_features = delta_result;
                result.input_features = input_features;
            } else {
                result.delta_features = delta_result.delta;
                if (delta_result.delta_delta) {
                    result.delta_delta_features = delta_result.delta_delta;
                }
            }
        }

        return result;
    }

    /**
     * Clears the feature cache. Does nothing when caching is disabled.
     * @returns {void}
     */
    clear_cache() {
        this.feature_cache?.clear();
    }

    /**
     * Returns the value of the feature cache's `stats()` call.
     * @returns {Object|null} The cache statistics, or `null` when caching is disabled.
     */
    get_cache_stats() {
        return this.feature_cache?.stats() ?? null;
    }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reject invalid `hop_length`, `feature_size`, and `sampling_rate` in the constructor.

The constructor only guards `n_fft` and `win_length`. Invalid values for these other core fields still flow into `mel_filter_bank()`, `spectrogram()`, and frame-length math, so a bad model config fails later and less predictably.

Proposed fix:
```diff
         if (
             !Number.isInteger(this.config.win_length) ||
             this.config.win_length <= 0 ||
             this.config.win_length > this.config.n_fft
         ) {
             throw new Error(
                 `NemoConformerTDTFeatureExtractor expected \`win_length\` in [1, n_fft], got win_length=${this.config.win_length}, n_fft=${this.config.n_fft}.`,
             );
         }
+        if (!Number.isInteger(this.config.hop_length) || this.config.hop_length <= 0) {
+            throw new Error(
+                `NemoConformerTDTFeatureExtractor expected \`hop_length\` as a positive integer, got ${this.config.hop_length}.`,
+            );
+        }
+        if (!Number.isInteger(this.config.feature_size) || this.config.feature_size <= 0) {
+            throw new Error(
+                `NemoConformerTDTFeatureExtractor expected \`feature_size\` as a positive integer, got ${this.config.feature_size}.`,
+            );
+        }
+        if (!Number.isFinite(this.config.sampling_rate) || this.config.sampling_rate <= 0) {
+            throw new Error(
+                `NemoConformerTDTFeatureExtractor expected \`sampling_rate\` as a positive number, got ${this.config.sampling_rate}.`,
+            );
+        }

         // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist.
         this.config.mel_filters ??= mel_filter_bank(
```