diff --git a/packages/transformers/src/models/feature_extractors.js b/packages/transformers/src/models/feature_extractors.js index 589b96cd7..48ecc6ff6 100644 --- a/packages/transformers/src/models/feature_extractors.js +++ b/packages/transformers/src/models/feature_extractors.js @@ -5,6 +5,7 @@ export * from './clap/feature_extraction_clap.js'; export * from './dac/feature_extraction_dac.js'; export * from './gemma3n/feature_extraction_gemma3n.js'; export * from './moonshine/feature_extraction_moonshine.js'; +export * from './nemo_conformer_tdt/feature_extraction_nemo_conformer_tdt.js'; export * from './parakeet/feature_extraction_parakeet.js'; export * from './pyannote/feature_extraction_pyannote.js'; export * from './seamless_m4t/feature_extraction_seamless_m4t.js'; diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index 96c0861b9..475de6be0 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -118,6 +118,7 @@ export const MODEL_TYPES = { ImageAudioTextToText: 13, Supertonic: 14, Chatterbox: 15, + NemoConformerTDT: 16, }; const MODEL_TYPE_CONFIG = { diff --git a/packages/transformers/src/models/models.js b/packages/transformers/src/models/models.js index 2d4994892..8b0c0371c 100644 --- a/packages/transformers/src/models/models.js +++ b/packages/transformers/src/models/models.js @@ -103,6 +103,7 @@ export * from './mt5/modeling_mt5.js'; export * from './multi_modality/modeling_multi_modality.js'; export * from './musicgen/modeling_musicgen.js'; export * from './nanochat/modeling_nanochat.js'; +export * from './nemo_conformer_tdt/modeling_nemo_conformer_tdt.js'; export * from './neobert/modeling_neobert.js'; export * from './nomic_bert/modeling_nomic_bert.js'; export * from './olmo/modeling_olmo.js'; diff --git a/packages/transformers/src/models/nemo_conformer_tdt/feature_extraction_nemo_conformer_tdt.js 
b/packages/transformers/src/models/nemo_conformer_tdt/feature_extraction_nemo_conformer_tdt.js new file mode 100644 index 000000000..bca3bc021 --- /dev/null +++ b/packages/transformers/src/models/nemo_conformer_tdt/feature_extraction_nemo_conformer_tdt.js @@ -0,0 +1,266 @@ +import { FeatureExtractor, validate_audio_inputs } from '../../feature_extraction_utils.js'; +import { Tensor } from '../../utils/tensor.js'; +import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js'; +import { logger } from '../../utils/logger.js'; +import { FeatureLRUCache, createAudioCacheKey } from './transducer_cache.js'; +import { computeTemporalDeltas } from './transducer_deltas.js'; + +const EPSILON = 1e-5; +export const NEMO_FEATURE_OUTPUT_OWNERSHIP = Symbol('NemoConformerTDTFeatureOutputOwnership'); +export const NEMO_FEATURE_OUTPUT_RELEASE = Symbol('NemoConformerTDTFeatureOutputRelease'); + +function tagNemoFeatureOutputOwnership(value, cacheOwnsTensors, release = null) { + Object.defineProperty(value, NEMO_FEATURE_OUTPUT_OWNERSHIP, { + value: cacheOwnsTensors, + enumerable: false, + configurable: true, + }); + if (release) { + Object.defineProperty(value, NEMO_FEATURE_OUTPUT_RELEASE, { + value: release, + enumerable: false, + configurable: true, + }); + } + return value; +} + +/** + * Feature extractor for Nemo Conformer TDT models. + * + * Mirrors NeMo-style log-mel extraction used by Parakeet with configurable + * `feature_size` (e.g. 80 or 128 mel bins via `preprocessor_config.json`). 
+ */ +export class NemoConformerTDTFeatureExtractor extends FeatureExtractor { + constructor(config) { + super(config); + + if (!Number.isInteger(this.config.n_fft) || this.config.n_fft <= 0) { + throw new Error( + `NemoConformerTDTFeatureExtractor expected \`n_fft\` as a positive integer, got ${this.config.n_fft}.`, + ); + } + if ( + !Number.isInteger(this.config.win_length) || + this.config.win_length <= 0 || + this.config.win_length > this.config.n_fft + ) { + throw new Error( + `NemoConformerTDTFeatureExtractor expected \`win_length\` in [1, n_fft], got win_length=${this.config.win_length}, n_fft=${this.config.n_fft}.`, + ); + } + + // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist. + this.config.mel_filters ??= mel_filter_bank( + Math.floor(1 + this.config.n_fft / 2), // num_frequency_bins + this.config.feature_size, // num_mel_filters + 0.0, // min_frequency + this.config.sampling_rate / 2, // max_frequency + this.config.sampling_rate, // sampling_rate + 'slaney', // norm + 'slaney', // mel_scale + ); + + const window = window_function(this.config.win_length, 'hann', { + periodic: false, + }); + + this.window = new Float64Array(this.config.n_fft); + const offset = Math.floor((this.config.n_fft - this.config.win_length) / 2); + this.window.set(window, offset); + + // Optional feature-level cache and delta/delta-delta post-processing. + this.use_feature_cache = this.config.use_feature_cache ?? false; + this.delta_order = this.config.delta_order ?? 0; + this.delta_window = this.config.delta_window ?? 2; + this.delta_concatenate = this.config.delta_concatenate ?? 
true; + + if (![0, 1, 2].includes(this.delta_order)) { + throw new Error( + `NemoConformerTDTFeatureExtractor expected delta_order in {0,1,2}, got ${this.delta_order}.`, + ); + } + if (!Number.isInteger(this.delta_window) || this.delta_window < 1) { + throw new Error( + `NemoConformerTDTFeatureExtractor expected \`delta_window\` as a positive integer, got ${this.delta_window}.`, + ); + } + if (this.delta_order > 0 && !this.delta_concatenate) { + logger.warn( + 'NemoConformerTDTFeatureExtractor: `delta_concatenate=false` is set. ' + + '`input_features` will remain base features and deltas are returned in extra fields.', + ); + } + + this.feature_cache = this.use_feature_cache + ? new FeatureLRUCache({ + max_entries: this.config.feature_cache_max_entries ?? 128, + max_size_mb: this.config.feature_cache_max_size_mb ?? 64, + }) + : null; + } + + /** + * Computes the log-Mel spectrogram of the provided audio waveform. + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. + */ + async _extract_fbank_features(waveform) { + // Parakeet uses a custom preemphasis strategy: Apply preemphasis to entire waveform at once + const preemphasis = this.config.preemphasis ?? 
0; + if (!Number.isFinite(preemphasis) || preemphasis < 0 || preemphasis >= 1) { + throw new Error( + `NemoConformerTDTFeatureExtractor expected \`preemphasis\` in [0, 1), got ${this.config.preemphasis}.`, + ); + } + waveform = new Float64Array(waveform); // Clone to avoid destructive changes + if (preemphasis !== 0) { + for (let j = waveform.length - 1; j >= 1; --j) { + waveform[j] -= preemphasis * waveform[j - 1]; + } + } + + const features = await spectrogram( + waveform, + this.window, // window + this.window.length, // frame_length + this.config.hop_length, // hop_length + { + fft_length: this.config.n_fft, + power: 2.0, + mel_filters: this.config.mel_filters, + log_mel: 'log', + mel_floor: -Infinity, + pad_mode: 'constant', + center: true, + + // Custom + transpose: true, + mel_offset: 2 ** -24, + }, + ); + + return features; + } + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ + * input_features: Tensor; + * attention_mask: Tensor; + * delta_features?: Tensor; + * delta_delta_features?: Tensor; + * }>} A Promise resolving to an object containing extracted model inputs. + * When cache is enabled, tensor instances are shared and owned by the cache. + * Do not mutate or dispose returned tensors unless cache is disabled/cleared. 
+ */ + async _call(audio) { + validate_audio_inputs(audio, 'NemoConformerTDTFeatureExtractor'); + + if (this.feature_cache) { + const key = `${createAudioCacheKey(audio, this.config.sampling_rate)}:${this.delta_order}:${this.delta_window}:${this.delta_concatenate}`; + const cached = this.feature_cache.acquire(key); + if (cached) { + return tagNemoFeatureOutputOwnership({ ...cached.value }, true, cached.release); + } + + const extracted = await this._extract(audio); + const cacheOwnsTensors = this.feature_cache.set(key, extracted); + if (!cacheOwnsTensors) { + return tagNemoFeatureOutputOwnership({ ...extracted }, false); + } + + const borrowed = this.feature_cache.acquire(key); + if (!borrowed) { + return tagNemoFeatureOutputOwnership({ ...extracted }, false); + } + return tagNemoFeatureOutputOwnership({ ...borrowed.value }, true, borrowed.release); + } + + return tagNemoFeatureOutputOwnership(await this._extract(audio), false); + } + + async _extract(audio) { + const features = await this._extract_fbank_features(audio); + + const [num_frames, num_features] = features.dims; + const raw_features_length = Math.floor( + (audio.length + Math.floor(this.config.n_fft / 2) * 2 - this.config.n_fft) / this.config.hop_length, + ); + // Clamp to [0, num_frames] to avoid a negative fill offset for very short clips. + const features_length = Math.max(0, Math.min(num_frames, raw_features_length)); + + const features_data = /** @type {Float32Array} */ (features.data); + features_data.fill(0, features_length * num_features); + + // normalize mel features, ignoring padding + const sum = new Float64Array(num_features); + const sum_sq = new Float64Array(num_features); + + for (let i = 0; i < features_length; ++i) { + const offset = i * num_features; + for (let j = 0; j < num_features; ++j) { + const val = features_data[offset + j]; + sum[j] += val; + sum_sq[j] += val * val; + } + } + + // Skip normalization for empty/very short audio to avoid NaN from divide-by-zero. 
+ if (features_length > 0) { + // Calculate mean and standard deviation, then normalize + const divisor = features_length > 1 ? features_length - 1 : 1; + for (let j = 0; j < num_features; ++j) { + const mean = sum[j] / features_length; + const variance = (sum_sq[j] - features_length * mean * mean) / divisor; + const std = Math.sqrt(Math.max(variance, 0)) + EPSILON; + const inv_std = 1 / std; + + for (let i = 0; i < features_length; ++i) { + const index = i * num_features + j; + features_data[index] = (features_data[index] - mean) * inv_std; + } + } + } + + const mask_data = new BigInt64Array(num_frames); + mask_data.fill(1n, 0, features_length); + + let input_features = features.unsqueeze_(0); + const attention_mask = new Tensor('int64', mask_data, [1, num_frames]); + + const result = { + input_features, + attention_mask, + }; + + if (this.delta_order > 0) { + const delta_result = computeTemporalDeltas(input_features, { + order: this.delta_order, + window: this.delta_window, + concatenate: this.delta_concatenate, + }); + if (delta_result instanceof Tensor) { + input_features.dispose(); + input_features = delta_result; + result.input_features = input_features; + } else { + result.delta_features = delta_result.delta; + if (delta_result.delta_delta) { + result.delta_delta_features = delta_result.delta_delta; + } + } + } + + return result; + } + + clear_cache() { + this.feature_cache?.clear(); + } + + get_cache_stats() { + return this.feature_cache?.stats() ?? 
null; + } +} diff --git a/packages/transformers/src/models/nemo_conformer_tdt/modeling_nemo_conformer_tdt.js b/packages/transformers/src/models/nemo_conformer_tdt/modeling_nemo_conformer_tdt.js new file mode 100644 index 000000000..b32332583 --- /dev/null +++ b/packages/transformers/src/models/nemo_conformer_tdt/modeling_nemo_conformer_tdt.js @@ -0,0 +1,958 @@ +import { AutoConfig } from '../../configs.js'; +import { Tensor } from '../../utils/tensor.js'; +import { + PreTrainedModel, + MODEL_TYPES, + MODEL_TYPE_MAPPING, + MODEL_NAME_TO_CLASS_MAPPING, + MODEL_CLASS_TO_NAME_MAPPING, +} from '../modeling_utils.js'; +import { constructSessions, sessionRun } from '../session.js'; +import { buildTransducerDetailedOutputs, decodeTransducerText } from './transducer_text.js'; + +const NEMO_CONFORMER_TDT_MODEL_TYPE = 'nemo-conformer-tdt'; + +const DEFAULT_TRANSDUCER_IO = Object.freeze({ + encoder_output: 'outputs', + decoder_encoder: 'encoder_outputs', + decoder_token: 'targets', + decoder_token_length: 'target_length', + decoder_state_1: 'input_states_1', + decoder_state_2: 'input_states_2', + decoder_output: 'outputs', + decoder_output_state_1: 'output_states_1', + decoder_output_state_2: 'output_states_2', +}); + +function argmax(values, offset = 0, length = values.length - offset) { + let maxIndex = offset; + let maxValue = Number.NEGATIVE_INFINITY; + const end = offset + length; + for (let i = offset; i < end; ++i) { + const v = values[i]; + if (v > maxValue) { + maxValue = v; + maxIndex = i; + } + } + return maxIndex; +} + +function toInt(value) { + return typeof value === 'bigint' ? Number(value) : value; +} + +function nowMs() { + return typeof performance !== 'undefined' && typeof performance.now === 'function' ? 
performance.now() : Date.now(); +} + +function roundMetric(value, digits = 2) { + if (!Number.isFinite(value)) return 0; + const factor = 10 ** digits; + return Math.round(value * factor) / factor; +} + +function roundTs(value) { + return Math.round(value * 1000) / 1000; +} + +/** + * @param {import('../../utils/tensor.js').Tensor['data']} logits + * @param {number} tokenId + * @param {number} vocabSize + * @returns {{ confidence: number, logProb: number }} + */ +function confidenceFromLogits(logits, tokenId, vocabSize) { + let maxLogit = Number.NEGATIVE_INFINITY; + for (let i = 0; i < vocabSize; ++i) { + if (logits[i] > maxLogit) { + maxLogit = logits[i]; + } + } + + let expSum = 0; + for (let i = 0; i < vocabSize; ++i) { + expSum += Math.exp(logits[i] - maxLogit); + } + const logSumExp = maxLogit + Math.log(expSum); + const logProb = logits[tokenId] - logSumExp; + return { + confidence: Math.exp(logProb), + logProb, + }; +} + +function resolveTransducerConfig(config, sessions) { + const transducerConfig = config['transformers.js_config']?.transducer; + if (!transducerConfig) { + throw new Error( + 'Missing `transformers.js_config.transducer` in config.json for nemo-conformer-tdt. See external model repo contract.', + ); + } + + const decoderConfig = transducerConfig.decoder ?? {}; + const numLayers = decoderConfig.num_layers; + const hiddenSize = decoderConfig.hidden_size; + + if (!Number.isInteger(numLayers) || numLayers <= 0) { + throw new Error('Invalid `transformers.js_config.transducer.decoder.num_layers`: expected a positive integer.'); + } + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error( + 'Invalid `transformers.js_config.transducer.decoder.hidden_size`: expected a positive integer.', + ); + } + + const io = { + ...DEFAULT_TRANSDUCER_IO, + ...(transducerConfig.io ?? 
{}), + }; + const requiredDecoderInputs = [ + io.decoder_encoder, + io.decoder_token, + io.decoder_token_length, + io.decoder_state_1, + io.decoder_state_2, + ]; + if (new Set(requiredDecoderInputs).size !== requiredDecoderInputs.length) { + throw new Error( + 'Invalid `transformers.js_config.transducer.io`: decoder input names must be distinct ' + + '(decoder_encoder, decoder_token, decoder_token_length, decoder_state_1, decoder_state_2).', + ); + } + const requiredDecoderOutputs = [io.decoder_output, io.decoder_output_state_1, io.decoder_output_state_2]; + if (new Set(requiredDecoderOutputs).size !== requiredDecoderOutputs.length) { + throw new Error( + 'Invalid `transformers.js_config.transducer.io`: decoder output names must be distinct ' + + '(decoder_output, decoder_output_state_1, decoder_output_state_2).', + ); + } + + const decoderSession = sessions?.decoder_model_merged; + if (!decoderSession) { + throw new Error('Missing required session `decoder_model_merged` for Nemo Conformer TDT.'); + } + + const decoderInputNames = decoderSession.inputNames ?? []; + const decoderOutputNames = decoderSession.outputNames ?? []; + const missingDecoderInputs = [ + io.decoder_encoder, + io.decoder_token, + io.decoder_token_length, + io.decoder_state_1, + io.decoder_state_2, + ].filter((name) => !decoderInputNames.includes(name)); + + if (missingDecoderInputs.length > 0) { + throw new Error( + `Nemo Conformer TDT decoder session is missing expected inputs: ${missingDecoderInputs.join(', ')}. ` + + 'Override I/O names via `transformers.js_config.transducer.io` if your export uses different names.', + ); + } + const missingDecoderOutputs = [io.decoder_output, io.decoder_output_state_1, io.decoder_output_state_2].filter( + (name) => !decoderOutputNames.includes(name), + ); + if (missingDecoderOutputs.length > 0) { + throw new Error( + `Nemo Conformer TDT decoder session is missing expected outputs: ${missingDecoderOutputs.join(', ')}. 
` + + 'Override I/O names via `transformers.js_config.transducer.io` if your export uses different names.', + ); + } + + const encoderSession = sessions?.encoder_model; + if (!encoderSession) { + throw new Error('Missing required session `encoder_model` for Nemo Conformer TDT.'); + } + if (!(encoderSession.outputNames ?? []).includes(io.encoder_output)) { + throw new Error( + `Nemo Conformer TDT encoder session is missing expected output: ${io.encoder_output}. ` + + 'Override `transformers.js_config.transducer.io.encoder_output` if your export uses a different name.', + ); + } + + const maxSymbolsPerStep = transducerConfig.max_symbols_per_step ?? 10; + const subsamplingFactor = transducerConfig.subsampling_factor ?? 8; + const frameShiftS = transducerConfig.frame_shift_s ?? 0.01; + const blankTokenId = transducerConfig.blank_token_id ?? 0; + const encoderOutputLayout = transducerConfig.encoder_output_layout; + const encoderInputLayout = transducerConfig.encoder_input_layout ?? 'BTF'; + const encoderFrameLayout = transducerConfig.encoder_frame_layout ?? 'BD1'; + const decoderTokenDType = transducerConfig.decoder_token_dtype ?? 'int32'; + const decoderTokenLengthDType = transducerConfig.decoder_token_length_dtype ?? 
'int32'; + + if (!Number.isInteger(blankTokenId) || blankTokenId < 0) { + throw new Error('Invalid `transformers.js_config.transducer.blank_token_id`: expected a non-negative integer.'); + } + if (!Number.isInteger(maxSymbolsPerStep) || maxSymbolsPerStep <= 0) { + throw new Error( + 'Invalid `transformers.js_config.transducer.max_symbols_per_step`: expected a positive integer.', + ); + } + if (!Number.isFinite(subsamplingFactor) || subsamplingFactor <= 0) { + throw new Error('Invalid `transformers.js_config.transducer.subsampling_factor`: expected a positive number.'); + } + if (!Number.isFinite(frameShiftS) || frameShiftS <= 0) { + throw new Error('Invalid `transformers.js_config.transducer.frame_shift_s`: expected a positive number.'); + } + if (encoderOutputLayout !== 'BDT' && encoderOutputLayout !== 'BTD') { + throw new Error('Invalid `transformers.js_config.transducer.encoder_output_layout`: expected "BDT" or "BTD".'); + } + if (encoderInputLayout !== 'BTF' && encoderInputLayout !== 'BFT') { + throw new Error('Invalid `transformers.js_config.transducer.encoder_input_layout`: expected "BTF" or "BFT".'); + } + if (encoderFrameLayout !== 'BD1' && encoderFrameLayout !== 'B1D') { + throw new Error('Invalid `transformers.js_config.transducer.encoder_frame_layout`: expected "BD1" or "B1D".'); + } + if (!['int32', 'int64'].includes(decoderTokenDType)) { + throw new Error( + 'Invalid `transformers.js_config.transducer.decoder_token_dtype`: expected "int32" or "int64".', + ); + } + if (!['int32', 'int64'].includes(decoderTokenLengthDType)) { + throw new Error( + 'Invalid `transformers.js_config.transducer.decoder_token_length_dtype`: expected "int32" or "int64".', + ); + } + + return { + blank_token_id: blankTokenId, + max_symbols_per_step: maxSymbolsPerStep, + subsampling_factor: subsamplingFactor, + frame_shift_s: frameShiftS, + vocab_size: transducerConfig.vocab_size ?? config.vocab_size ?? null, + duration_start_index: transducerConfig.duration_start_index ?? 
null, + encoder_input_layout: encoderInputLayout, + encoder_output_layout: encoderOutputLayout, + encoder_frame_layout: encoderFrameLayout, + decoder_token_dtype: decoderTokenDType, + decoder_token_length_dtype: decoderTokenLengthDType, + decoder: { + num_layers: numLayers, + hidden_size: hiddenSize, + }, + io, + }; +} + +export class NemoConformerTDTPreTrainedModel extends PreTrainedModel { + main_input_name = 'input_features'; + forward_params = ['input_features', 'attention_mask']; + + constructor(config, sessions, configs) { + super(config, sessions, configs); + this.transducer = resolveTransducerConfig(config, sessions); + } + + static supports(model_type) { + return model_type === NEMO_CONFORMER_TDT_MODEL_TYPE; + } + + /** + * Load Nemo Conformer TDT sessions using v4 canonical ONNX filenames. + * @type {typeof PreTrainedModel.from_pretrained} + */ + static async from_pretrained( + pretrained_model_name_or_path, + { + progress_callback = null, + config = null, + cache_dir = null, + local_files_only = false, + revision = 'main', + model_file_name = null, + subfolder = 'onnx', + device = null, + dtype = null, + use_external_data_format = null, + session_options = {}, + } = {}, + ) { + const options = { + progress_callback, + config, + cache_dir, + local_files_only, + revision, + model_file_name, + subfolder, + device, + dtype, + use_external_data_format, + session_options, + }; + + config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); + if (config.model_type !== NEMO_CONFORMER_TDT_MODEL_TYPE) { + throw new Error(`Unsupported model type: ${config.model_type}`); + } + + if (options.model_file_name && options.model_file_name !== 'encoder_model') { + throw new Error( + 'NemoConformerForTDT does not support `model_file_name` override. 
' + + 'Expected canonical files: `encoder_model{suffix}.onnx` and `decoder_model_merged{suffix}.onnx`.', + ); + } + + let sessions; + try { + sessions = await constructSessions( + pretrained_model_name_or_path, + { + encoder_model: 'encoder_model', + decoder_model_merged: 'decoder_model_merged', + }, + options, + 'decoder_model_merged', + ); + } catch (error) { + const reason = error?.message ?? String(error); + throw new Error( + 'Failed to load Nemo Conformer TDT sessions. Expected canonical v4 files under `onnx/`: ' + + '`encoder_model{suffix}.onnx` and `decoder_model_merged{suffix}.onnx`. ' + + `Original error: ${reason}`, + ); + } + + return new this(config, sessions, {}); + } +} + +export class NemoConformerForTDT extends NemoConformerTDTPreTrainedModel { + async _runEncoder(feeds) { + return await sessionRun(this.sessions.encoder_model, feeds); + } + + async _runDecoder(feeds) { + return await sessionRun(this.sessions.decoder_model_merged, feeds); + } + + _disposeDecoderState(state, keepState = null) { + if (!state) return; + if (state.state1 instanceof Tensor && state.state1 !== keepState?.state1) { + state.state1.dispose(); + } + if (state.state2 instanceof Tensor && state.state2 !== keepState?.state2) { + state.state2.dispose(); + } + } + + _getEncoderOutput(outputs) { + const name = this.transducer.io.encoder_output; + const out = outputs?.[name]; + if (!(out instanceof Tensor)) { + const available = outputs && typeof outputs === 'object' ? Object.keys(outputs).join(', ') : '(none)'; + throw new Error( + `Nemo Conformer TDT encoder output "${name}" was not returned by the session. 
` + + `Available outputs: ${available}.`, + ); + } + return out; + } + + _getEncoderFrameCount(encoderOutput) { + if (encoderOutput.dims.length !== 3 || encoderOutput.dims[0] !== 1) { + throw new Error( + `Nemo Conformer TDT expected encoder output dims [1, D, T] or [1, T, D], got [${encoderOutput.dims.join(', ')}].`, + ); + } + const layout = this.transducer.encoder_output_layout; + if (layout === 'BDT') { + return encoderOutput.dims[2]; + } + if (layout === 'BTD') { + return encoderOutput.dims[1]; + } + throw new Error( + `Unsupported encoder output layout "${layout}". Use 'BDT' or 'BTD' in transformers.js_config.transducer.`, + ); + } + + _getFrameData(encoderOutput, frameIndex, reusableFrame) { + const layout = this.transducer.encoder_output_layout; + if (encoderOutput.type !== 'float32') { + throw new Error(`Nemo Conformer TDT expected encoder output type "float32", got "${encoderOutput.type}".`); + } + const data = /** @type {Float32Array} */ (encoderOutput.data); + + if (layout === 'BDT') { + const D = encoderOutput.dims[1]; + const T = encoderOutput.dims[2]; + const frame = reusableFrame && reusableFrame.length === D ? reusableFrame : new Float32Array(D); + for (let d = 0; d < D; ++d) { + frame[d] = data[d * T + frameIndex]; + } + return frame; + } + + if (layout === 'BTD') { + const D = encoderOutput.dims[2]; + const offset = frameIndex * D; + return data.subarray(offset, offset + D); + } + + throw new Error( + `Unsupported encoder output layout "${layout}". Use 'BDT' or 'BTD' in transformers.js_config.transducer.`, + ); + } + + _createFrameTensor(frameData) { + const layout = this.transducer.encoder_frame_layout; + if (layout === 'BD1') { + return new Tensor('float32', frameData, [1, frameData.length, 1]); + } else if (layout === 'B1D') { + return new Tensor('float32', frameData, [1, 1, frameData.length]); + } + throw new Error( + `Unsupported encoder frame layout "${layout}". 
Use 'BD1' or 'B1D' in transformers.js_config.transducer.`, + ); + } + + _buildEncoderFeeds(model_inputs) { + const encoderSession = this.sessions.encoder_model; + const feeds = {}; + const disposables = []; + const inputFeatures = model_inputs.input_features; + + if (!(inputFeatures instanceof Tensor)) { + throw new Error( + 'NemoConformerForTDT.transcribe expected `model_inputs.input_features` as a Tensor from the processor.', + ); + } + + const missingInputs = []; + let preparedEncoderInput = null; + const getPreparedEncoderInput = () => { + if (preparedEncoderInput) { + return preparedEncoderInput; + } + + const layout = this.transducer.encoder_input_layout; + if (layout === 'BTF') { + preparedEncoderInput = inputFeatures; + } else if (layout === 'BFT') { + preparedEncoderInput = inputFeatures.transpose(0, 2, 1); + disposables.push(preparedEncoderInput); + } else { + throw new Error( + `Unsupported encoder input layout "${layout}". Use 'BTF' or 'BFT' in transformers.js_config.transducer.`, + ); + } + return preparedEncoderInput; + }; + for (const name of encoderSession.inputNames) { + if (name === 'input_features' || name === 'audio_signal') { + feeds[name] = getPreparedEncoderInput(); + continue; + } + + if (model_inputs[name] instanceof Tensor) { + feeds[name] = model_inputs[name]; + continue; + } + + if (name === 'length') { + let length = null; + const attentionMask = model_inputs.attention_mask; + if (attentionMask instanceof Tensor) { + const maskData = attentionMask.data; + let sum = 0; + for (let i = 0; i < maskData.length; ++i) { + sum += toInt(maskData[i]); + } + length = sum; + } else { + length = inputFeatures.dims[1]; + } + if (!Number.isInteger(length) || length < 0) { + throw new Error( + `Nemo Conformer TDT expected a non-negative integer encoder length, got: ${length}.`, + ); + } + const lengthTensor = new Tensor('int64', BigInt64Array.from([BigInt(length)]), [1]); + disposables.push(lengthTensor); + feeds[name] = lengthTensor; + continue; + } + 
+ missingInputs.push(name); + } + + if (missingInputs.length > 0) { + for (const tensor of disposables) { + tensor.dispose(); + } + throw new Error( + `Nemo Conformer TDT encoder session expects additional inputs that are not available: ${missingInputs.join(', ')}.`, + ); + } + + return { feeds, disposables }; + } + + _resolveVocabSize(tokenizer) { + if (Number.isInteger(this.transducer.vocab_size)) { + return this.transducer.vocab_size; + } + + if (tokenizer?.get_vocab) { + const vocab = tokenizer.get_vocab(); + if (vocab instanceof Map) { + let maxId = -1; + for (const id of vocab.values()) { + const numericId = Number(id); + if (Number.isInteger(numericId) && numericId >= 0) { + maxId = Math.max(maxId, numericId); + } + } + if (maxId >= 0) { + return maxId + 1; + } + } else if (Array.isArray(vocab)) { + if (vocab.length > 0) { + return vocab.length; + } + } else if (vocab && typeof vocab === 'object') { + let maxId = -1; + for (const id of Object.values(vocab)) { + const numericId = Number(id); + if (Number.isInteger(numericId) && numericId >= 0) { + maxId = Math.max(maxId, numericId); + } + } + if (maxId >= 0) { + return maxId + 1; + } + } + } + + throw new Error( + 'Unable to resolve vocabulary size for Nemo Conformer TDT. Set `vocab_size` in config.json or provide tokenizer with a vocab.', + ); + } + + _validateRuntimeConfig(vocabSize) { + if (!Number.isInteger(vocabSize) || vocabSize <= 0) { + throw new Error(`Invalid Nemo Conformer TDT config: vocab_size=${vocabSize} must be a positive integer.`); + } + if (this.transducer.blank_token_id >= vocabSize) { + throw new Error( + `Invalid Nemo Conformer TDT config: blank_token_id=${this.transducer.blank_token_id} must be < vocab_size=${vocabSize}.`, + ); + } + const durationStart = this.transducer.duration_start_index ?? 
    // NOTE(review): tail of a config-validation method whose opening lines precede this chunk —
    // it rejects configs where the TDT duration block would overlap the token logits.
    vocabSize;
        if (!Number.isInteger(durationStart) || durationStart < vocabSize) {
            throw new Error(
                `Invalid Nemo Conformer TDT config: duration_start_index=${durationStart} must be an integer >= vocab_size=${vocabSize}.`,
            );
        }
    }

    /**
     * Transcribe model-ready features using TDT decoding.
     *
     * - `returnTimestamps: false` → `{ text, isFinal }` (+ metrics if `returnMetrics`)
     * - `returnTimestamps: true` → adds `utteranceTimestamp` and grouped `confidence`
     * - `returnWords: true` (requires `returnTimestamps`) → adds `words` list
     * - `returnTokens: true` (requires `returnTimestamps`) → adds `tokens` list
     * - `returnMetrics` is independent and can be combined with either level.
     * - Debug flags (`returnFrameConfidences`, `returnFrameIndices`, `returnLogProbs`, `returnTdtSteps`) are independent.
     * - Legacy snake_case aliases (`return_timestamps`, `return_words`, `return_tokens`, `return_metrics`) are accepted.
     *
     * @param {Object} model_inputs Processor outputs (must include `input_features`).
     * @param {Object} [decode_options]
     * @param {any} [decode_options.tokenizer] Tokenizer for text reconstruction and word boundaries.
     * @param {boolean} [decode_options.returnTimestamps=true] Include utterance-level timestamps and confidence aggregates.
     * @param {boolean} [decode_options.return_timestamps] Legacy alias for `returnTimestamps`.
     * @param {boolean} [decode_options.returnWords=false] Include word-level list (requires `returnTimestamps`).
     * @param {boolean} [decode_options.return_words] Legacy alias for `returnWords`.
     * @param {boolean} [decode_options.returnTokens=false] Include token-level list (requires `returnTimestamps`).
     * @param {boolean} [decode_options.return_tokens] Legacy alias for `returnTokens`.
     * @param {boolean} [decode_options.returnMetrics=false] Include encoding/decoding timing metrics.
     * @param {boolean} [decode_options.return_metrics] Legacy alias for `returnMetrics`.
     * @param {boolean} [decode_options.returnFrameConfidences=false] Include per-frame confidence scores in `confidence`.
     * @param {boolean} [decode_options.returnFrameIndices=false] Include per-token encoder frame indices.
     * @param {boolean} [decode_options.returnLogProbs=false] Include per-token log probabilities.
     * @param {boolean} [decode_options.returnTdtSteps=false] Include raw TDT duration steps.
     * @param {number} [decode_options.timeOffset=0] Offset added to all timestamps (seconds).
     * @returns {Promise<{
     *   text: string,
     *   isFinal: boolean,
     *   utteranceTimestamp?: [number, number],
     *   words?: Array<{ text: string, startTime: number, endTime: number, confidence?: number }>,
     *   tokens?: Array<{ id: number, token: string, rawToken: string, isWordStart: boolean, startTime: number, endTime: number, confidence?: number }>,
     *   confidence?: { utterance?: number|null, wordAverage?: number|null, frames?: number[]|null, frameAverage?: number|null, averageLogProb?: number|null },
     *   metrics?: { preprocessMs: number, encodeMs: number, decodeMs: number, tokenizeMs: number, totalMs: number, rtf: number, rtfX: number },
     *   debug?: { frameIndices?: number[] | null, logProbs?: number[] | null, tdtSteps?: number[] | null },
     * }>}
     */
    async transcribe(model_inputs, decode_options = {}) {
        const {
            tokenizer = null,
            returnTimestamps: returnTimestampsOption,
            return_timestamps: legacyReturnTimestamps,
            returnWords: returnWordsOption,
            return_words: legacyReturnWords,
            returnTokens: returnTokensOption,
            return_tokens: legacyReturnTokens,
            returnMetrics: returnMetricsOption,
            return_metrics: legacyReturnMetrics,
            returnFrameConfidences = false,
            returnFrameIndices = false,
            returnLogProbs = false,
            returnTdtSteps = false,
            timeOffset = 0,
        } = decode_options;
        // camelCase options win over the legacy snake_case aliases; timestamps default to on.
        const returnTimestamps = returnTimestampsOption ?? legacyReturnTimestamps ?? true;
        const returnWords = returnWordsOption ?? legacyReturnWords ?? false;
        const returnTokens = returnTokensOption ?? legacyReturnTokens ?? false;
        const returnMetrics = returnMetricsOption ?? legacyReturnMetrics ?? false;

        if (!Number.isFinite(timeOffset)) {
            throw new Error('NemoConformerForTDT.transcribe expected `timeOffset` to be a finite number.');
        }
        const totalStart = nowMs();
        const io = this.transducer.io;
        const vocabSize = this._resolveVocabSize(tokenizer);
        this._validateRuntimeConfig(vocabSize);

        const { feeds: encoderFeeds, disposables } = this._buildEncoderFeeds(model_inputs);
        let encoderOutputs;
        const encodeStart = nowMs();
        try {
            encoderOutputs = await this._runEncoder(encoderFeeds);
        } finally {
            // Free tensors created solely to feed the encoder, even if the run failed.
            for (const tensor of disposables) {
                tensor.dispose();
            }
        }
        const encodeMs = nowMs() - encodeStart;

        let frameCount = 0;
        let encoderOutput = null;
        // Seconds of audio covered by one encoder output frame.
        const frameTime = this.transducer.subsampling_factor * this.transducer.frame_shift_s;

        const numLayers = this.transducer.decoder.num_layers;
        const hiddenSize = this.transducer.decoder.hidden_size;
        const blankId = this.transducer.blank_token_id;
        const maxSymbolsPerStep = this.transducer.max_symbols_per_step;

        const needConfidences = !!returnTimestamps;

        /** @type {number[]} */
        const tokenIds = [];
        /** @type {[number, number][]} */
        const tokenTimestamps = [];
        /** @type {number[] | null} */
        const tokenConfidences = needConfidences ? [] : null;
        /** @type {Map<number, { sum: number, count: number }> | null} */
        const frameConfidenceStats = returnFrameConfidences ? new Map() : null;
        /** @type {number[] | null} */
        const frameIndices = returnFrameIndices ? [] : null;
        /** @type {number[] | null} */
        const logProbs = returnLogProbs || needConfidences ? [] : null;
        /** @type {number[] | null} */
        const tdtSteps = returnTdtSteps ? [] : null;

        let decoderState;
        let targetLengthTensor;
        let reusableFrame = null;

        let emittedOnFrame = 0;
        const decodeStart = nowMs();

        try {
            encoderOutput = this._getEncoderOutput(encoderOutputs);
            frameCount = this._getEncoderFrameCount(encoderOutput);
            // Zero-initialized LSTM-style decoder state, threaded through the greedy loop.
            decoderState = {
                state1: new Tensor('float32', new Float32Array(numLayers * hiddenSize), [numLayers, 1, hiddenSize]),
                state2: new Tensor('float32', new Float32Array(numLayers * hiddenSize), [numLayers, 1, hiddenSize]),
            };

            targetLengthTensor =
                this.transducer.decoder_token_length_dtype === 'int64'
                    ? new Tensor('int64', BigInt64Array.from([1n]), [1])
                    : new Tensor('int32', new Int32Array([1]), [1]);

            // Greedy TDT loop: frameIndex advances by the predicted duration step (see bottom of loop).
            for (let frameIndex = 0; frameIndex < frameCount; ) {
                const frameData = this._getFrameData(encoderOutput, frameIndex, reusableFrame);
                if (this.transducer.encoder_output_layout === 'BDT') {
                    reusableFrame = frameData;
                }
                const frameTensor = this._createFrameTensor(frameData);
                const prevTokenId = tokenIds.length > 0 ? tokenIds[tokenIds.length - 1] : blankId;
                const tokenTensor =
                    this.transducer.decoder_token_dtype === 'int64'
                        ? new Tensor('int64', BigInt64Array.from([BigInt(prevTokenId)]), [1, 1])
                        : new Tensor('int32', new Int32Array([prevTokenId]), [1, 1]);

                const decoderFeeds = {
                    [io.decoder_encoder]: frameTensor,
                    [io.decoder_token]: tokenTensor,
                    [io.decoder_token_length]: targetLengthTensor,
                    [io.decoder_state_1]: decoderState.state1,
                    [io.decoder_state_2]: decoderState.state2,
                };

                let decoderOutput;
                try {
                    decoderOutput = await this._runDecoder(decoderFeeds);
                } finally {
                    tokenTensor.dispose();
                    frameTensor.dispose();
                }

                const logits = decoderOutput[io.decoder_output];
                const outputState1 = decoderOutput[io.decoder_output_state_1];
                const outputState2 = decoderOutput[io.decoder_output_state_2];
                // Dispose any extra session outputs we do not keep (dedup guards aliased tensors).
                const seenDecoderTensors = new Set();
                for (const value of Object.values(decoderOutput)) {
                    if (!(value instanceof Tensor) || seenDecoderTensors.has(value)) continue;
                    seenDecoderTensors.add(value);
                    if (value === logits || value === outputState1 || value === outputState2) {
                        continue;
                    }
                    value.dispose();
                }
                if (!(logits instanceof Tensor)) {
                    this._disposeDecoderState(
                        {
                            state1: outputState1,
                            state2: outputState2,
                        },
                        decoderState,
                    );
                    throw new Error(
                        `Nemo Conformer TDT decoder output "${io.decoder_output}" was not returned by the session.`,
                    );
                }
                if (!(outputState1 instanceof Tensor) || !(outputState2 instanceof Tensor)) {
                    logits.dispose();
                    this._disposeDecoderState(
                        {
                            state1: outputState1,
                            state2: outputState2,
                        },
                        decoderState,
                    );
                    throw new Error(
                        `Nemo Conformer TDT decoder state outputs "${io.decoder_output_state_1}" and "${io.decoder_output_state_2}" were not returned by the session.`,
                    );
                }
                const logitsData = logits.data;
                if (logitsData.length < vocabSize) {
                    logits.dispose();
                    this._disposeDecoderState(
                        {
                            state1: outputState1,
                            state2: outputState2,
                        },
                        decoderState,
                    );
                    throw new Error(
                        `Nemo Conformer TDT decoder output is too small (${logitsData.length}) for vocab_size=${vocabSize}.`,
                    );
                }
                // Token head occupies [0, vocabSize); duration head (if present) starts at durationStart.
                const tokenId = argmax(logitsData, 0, vocabSize);
                const durationStart = this.transducer.duration_start_index ?? vocabSize;
                const hasDurationLogits = logitsData.length > durationStart;
                if (this.transducer.duration_start_index != null && !hasDurationLogits) {
                    logits.dispose();
                    this._disposeDecoderState(
                        {
                            state1: outputState1,
                            state2: outputState2,
                        },
                        decoderState,
                    );
                    throw new Error(
                        `Nemo Conformer TDT decoder output is missing duration logits: expected values beyond index ${durationStart - 1}, got length=${logitsData.length}.`,
                    );
                }
                const step = hasDurationLogits
                    ? argmax(logitsData, durationStart, logitsData.length - durationStart) - durationStart
                    : 0;
                if (tdtSteps) {
                    tdtSteps.push(step);
                }

                const maybeConfidence =
                    needConfidences || returnLogProbs || returnFrameConfidences
                        ? confidenceFromLogits(logitsData, tokenId, vocabSize)
                        : null;
                if (frameConfidenceStats && maybeConfidence) {
                    const stats = frameConfidenceStats.get(frameIndex);
                    if (stats) {
                        stats.sum += maybeConfidence.confidence;
                        stats.count += 1;
                    } else {
                        frameConfidenceStats.set(frameIndex, { sum: maybeConfidence.confidence, count: 1 });
                    }
                }

                const newState = {
                    state1: outputState1,
                    state2: outputState2,
                };

                if (tokenId !== blankId) {
                    // Emitted a real token: adopt the new decoder state, drop the old one.
                    this._disposeDecoderState(decoderState, newState);
                    decoderState = newState;

                    tokenIds.push(tokenId);
                    // TDT duration convention: step=0 means "stay on current frame" (duration index 0 = no advance).
                    // We still associate the token with this frame, so durationFrames is at least 1.
                    const durationFrames = Math.max(1, step);
                    const endFrame = Math.min(frameCount, frameIndex + durationFrames);
                    tokenTimestamps.push([
                        roundTs(frameIndex * frameTime + timeOffset),
                        roundTs(endFrame * frameTime + timeOffset),
                    ]);
                    if (tokenConfidences && maybeConfidence) {
                        tokenConfidences.push(maybeConfidence.confidence);
                    }
                    if (frameIndices) {
                        frameIndices.push(frameIndex);
                    }
                    if (logProbs && maybeConfidence) {
                        logProbs.push(maybeConfidence.logProb);
                    }
                    emittedOnFrame += 1;
                } else {
                    // Blank: keep the previous decoder state, discard the new one.
                    this._disposeDecoderState(newState, decoderState);
                }

                logits.dispose();

                // Advance by the duration prediction; otherwise move one frame on blank or
                // when the per-frame emission cap is hit (prevents infinite loops).
                if (step > 0) {
                    frameIndex += step;
                    emittedOnFrame = 0;
                } else if (tokenId === blankId || emittedOnFrame >= maxSymbolsPerStep) {
                    frameIndex += 1;
                    emittedOnFrame = 0;
                }
            }
        } finally {
            if (targetLengthTensor) targetLengthTensor.dispose();
            if (decoderState) this._disposeDecoderState(decoderState);
            if (encoderOutputs && typeof encoderOutputs === 'object') {
                const seen = new Set();
                for (const value of Object.values(encoderOutputs)) {
                    if (value instanceof Tensor && !seen.has(value)) {
                        value.dispose();
                        seen.add(value);
                    }
                }
            }
        }
        const decodeMs = nowMs() - decodeStart;

        const tokenizeStart = nowMs();
        const text = decodeTransducerText(tokenizer, tokenIds);
        const needDetailed = returnTimestamps && (returnWords || returnTokens);
        const detailed = needDetailed
            ? buildTransducerDetailedOutputs(tokenizer, tokenIds, tokenTimestamps, tokenConfidences)
            : null;
        const tokenizeMs = nowMs() - tokenizeStart;

        /** @type {any} */
        const result = { text, isFinal: true };
        const utteranceConfidence =
            tokenConfidences && tokenConfidences.length > 0
                ? roundMetric(tokenConfidences.reduce((a, b) => a + b, 0) / tokenConfidences.length, 6)
                : null;
        // With no emitted tokens, fall back to the full decoded span for the utterance range.
        const utteranceTimestamp =
            tokenTimestamps.length > 0
                ? /** @type {[number, number]} */ ([
                      tokenTimestamps[0][0],
                      tokenTimestamps[tokenTimestamps.length - 1][1],
                  ])
                : /** @type {[number, number]} */ ([roundTs(timeOffset), roundTs(frameCount * frameTime + timeOffset)]);
        const averageLogProb =
            logProbs && logProbs.length > 0
                ? roundMetric(logProbs.reduce((a, b) => a + b, 0) / logProbs.length, 6)
                : null;

        if (returnTimestamps) {
            result.utteranceTimestamp = utteranceTimestamp;

            if (detailed) {
                if (returnWords) result.words = detailed.words;
                if (returnTokens) result.tokens = detailed.tokens;
            }

            result.confidence = {
                utterance: utteranceConfidence,
                wordAverage: detailed?.wordAverage != null ? roundMetric(detailed.wordAverage, 6) : null,
                averageLogProb,
            };
        }

        // Frame confidences are independent of return_timestamps — emit whenever requested.
        if (returnFrameConfidences && frameConfidenceStats && frameConfidenceStats.size > 0) {
            const frameConfidences = [];
            for (const { sum, count } of frameConfidenceStats.values()) {
                frameConfidences.push(sum / count);
            }
            result.confidence = {
                ...(result.confidence ?? {}),
                frames: frameConfidences,
                frameAverage: roundMetric(frameConfidences.reduce((a, b) => a + b, 0) / frameConfidences.length, 6),
            };
        }

        if (!returnTimestamps && averageLogProb != null) {
            result.confidence = {
                ...(result.confidence ?? {}),
                averageLogProb,
            };
        }

        const debug = {};
        if (returnFrameIndices) {
            debug.frameIndices = frameIndices;
        }
        if (returnLogProbs) {
            debug.logProbs = logProbs;
        }
        if (returnTdtSteps) {
            debug.tdtSteps = tdtSteps;
        }
        if (Object.keys(debug).length > 0) {
            result.debug = debug;
        }

        if (returnMetrics) {
            const totalMs = nowMs() - totalStart;
            // Epsilon floors guard the real-time-factor divisions against zero-length utterances.
            const utteranceDuration = result.utteranceTimestamp
                ? Math.max(result.utteranceTimestamp[1] - result.utteranceTimestamp[0], 1e-8)
                : Math.max(frameCount * frameTime, 1e-8);
            const rtf = totalMs / 1000 / utteranceDuration;
            result.metrics = {
                preprocessMs: 0.0,
                encodeMs: roundMetric(encodeMs, 2),
                decodeMs: roundMetric(decodeMs, 2),
                tokenizeMs: roundMetric(tokenizeMs, 2),
                totalMs: roundMetric(totalMs, 2),
                rtf: roundMetric(rtf, 4),
                rtfX: roundMetric(1 / Math.max(rtf, 1e-8), 2),
            };
        }

        return result;
    }

    /**
     * Runs TDT transcription when called directly.
     * @param {Object} model_inputs
     */
    async _call(model_inputs) {
        return await this.transcribe(model_inputs);
    }
}

MODEL_TYPE_MAPPING.set('nemo-conformer-tdt', MODEL_TYPES.NemoConformerTDT);
MODEL_TYPE_MAPPING.set('NemoConformerForTDT', MODEL_TYPES.NemoConformerTDT);
MODEL_NAME_TO_CLASS_MAPPING.set('NemoConformerTDTPreTrainedModel', NemoConformerTDTPreTrainedModel);
MODEL_NAME_TO_CLASS_MAPPING.set('NemoConformerForTDT', NemoConformerForTDT);
MODEL_CLASS_TO_NAME_MAPPING.set(NemoConformerTDTPreTrainedModel, 'NemoConformerTDTPreTrainedModel');
MODEL_CLASS_TO_NAME_MAPPING.set(NemoConformerForTDT, 'NemoConformerForTDT');
diff --git a/packages/transformers/src/models/nemo_conformer_tdt/pipeline_nemo_conformer_tdt.js b/packages/transformers/src/models/nemo_conformer_tdt/pipeline_nemo_conformer_tdt.js
new file mode 100644
index 000000000..f28fd5611
--- /dev/null
+++ b/packages/transformers/src/models/nemo_conformer_tdt/pipeline_nemo_conformer_tdt.js
@@ -0,0 +1,345 @@
import { Tensor } from '../../utils/tensor.js';
import { NEMO_FEATURE_OUTPUT_OWNERSHIP, NEMO_FEATURE_OUTPUT_RELEASE } from './feature_extraction_nemo_conformer_tdt.js';
import {
    buildWordChunks,
    buildNemoSegmentChunks,
    joinTimedWords,
    partitionNemoWordsIntoSegments,
} from './transducer_segment_offsets.js';
import { dedupeMergedWords } from './transducer_window_merge.js';

const NEMO_AUTO_WINDOW_THRESHOLD_S = 180;
const NEMO_MIN_CHUNK_LENGTH_S = 20;
const NEMO_MAX_CHUNK_LENGTH_S = 180;
const NEMO_AUTO_CHUNK_LENGTH_S = 90;
const NEMO_AUTO_WINDOW_FALLBACK_OVERLAP_S = 10;
const NEMO_AUTO_WINDOW_EPSILON_S = 1e-6;
const NEMO_SEGMENT_DEDUP_TOLERANCE_S = 0.15;
const NEMO_CURSOR_MIN_ADVANCE_S = 1.0;
const NEMO_CURSOR_GAP_THRESHOLD_S = 0.2;
const NEMO_CURSOR_SNAP_WINDOW_S = 0.5;

/**
 * Assert that one pipeline input is a non-empty, all-finite float audio buffer.
 * @param {Float32Array|Float64Array} audio
 * @param {number} index Position of this input in the batch (for error messages).
 */
function validateNemoAudio(audio, index) {
    const isFloatBuffer = audio instanceof Float32Array || audio instanceof Float64Array;
    if (!isFloatBuffer) {
        throw new TypeError(
            `Nemo Conformer TDT pipeline expected audio at index ${index} to be Float32Array or Float64Array.`,
        );
    }
    if (audio.length === 0) {
        throw new Error(`Nemo Conformer TDT pipeline expected non-empty audio at index ${index}.`);
    }
    for (let sampleIndex = 0; sampleIndex < audio.length; ++sampleIndex) {
        const sample = audio[sampleIndex];
        if (Number.isFinite(sample)) continue;
        throw new Error(
            `Nemo Conformer TDT pipeline expected finite audio samples; found ${sample} at index ${index}:${sampleIndex}.`,
        );
    }
}

/** Dispose every distinct Tensor found among the processor outputs. */
function disposeNemoPipelineInputs(inputs) {
    const disposed = new Set();
    for (const candidate of Object.values(inputs ?? {})) {
        if (!(candidate instanceof Tensor) || disposed.has(candidate)) continue;
        candidate.dispose();
        disposed.add(candidate);
    }
}

/** Invoke the cache-release callback attached by the feature extractor, if any. */
function releaseNemoPipelineInputs(inputs) {
    const releaseFn = inputs?.[NEMO_FEATURE_OUTPUT_RELEASE];
    if (typeof releaseFn === 'function') {
        releaseFn();
    }
}

/**
 * Coerce a user-supplied chunk length into a usable value in seconds.
 * Non-finite or non-positive input means "no chunking" (0); everything else is
 * clamped into [NEMO_MIN_CHUNK_LENGTH_S, NEMO_MAX_CHUNK_LENGTH_S].
 */
function normalizeNemoChunkLengthS(value) {
    const seconds = Number(value);
    if (!Number.isFinite(seconds) || seconds <= 0) {
        return 0;
    }
    return Math.min(NEMO_MAX_CHUNK_LENGTH_S, Math.max(NEMO_MIN_CHUNK_LENGTH_S, seconds));
}

/** Concatenate the word lists of all segments, in order. */
function flattenNemoSegmentWords(segments) {
    const flattened = [];
    for (const segment of segments) {
        flattened.push(...segment.words);
    }
    return flattened;
}

/**
 * Merge carried-over (pending) words with the words decoded in the current
 * window, deduplicating overlaps. When the current window starts at (or before)
 * the pending words, the fresher decode wins outright.
 */
function mergePendingAndCurrentNemoWords(pendingWords, currentWords) {
    const pending = Array.isArray(pendingWords) ? pendingWords : [];
    const current = Array.isArray(currentWords) ? currentWords : [];

    if (pending.length === 0) {
        return dedupeMergedWords(current);
    }
    if (current.length === 0) {
        return dedupeMergedWords(pending);
    }

    const pendingStart = pending[0].startTime;
    if (current[0].startTime <= pendingStart + NEMO_AUTO_WINDOW_EPSILON_S) {
        return dedupeMergedWords(current);
    }
    return dedupeMergedWords([...pending, ...current]);
}

/** Canonicalize segment text for duplicate detection (NFKC, straight quotes, squashed whitespace, lowercase). */
function normalizeNemoSegmentText(text) {
    let normalized = String(text ?? '').normalize('NFKC');
    normalized = normalized.replace(/[“”]/g, '"').replace(/[‘’]/g, "'");
    return normalized.replace(/\s+/g, ' ').trim().toLowerCase();
}

/** True when `segment` textually matches an already-finalized segment that ends at (almost) the same time. */
function isDuplicateFinalizedNemoSegment(finalizedSegments, segment) {
    const target = normalizeNemoSegmentText(segment.text);
    if (!target) {
        return false;
    }
    for (const candidate of finalizedSegments) {
        if (normalizeNemoSegmentText(candidate.text) !== target) continue;
        if (Math.abs(candidate.timestamp[1] - segment.timestamp[1]) < NEMO_SEGMENT_DEDUP_TOLERANCE_S) {
            return true;
        }
    }
    return false;
}

/** Append a segment unless an equivalent one was already finalized. */
function appendFinalizedNemoSegment(finalizedSegments, segment) {
    if (!isDuplicateFinalizedNemoSegment(finalizedSegments, segment)) {
        finalizedSegments.push(segment);
    }
}

/**
 * Snap a window-start cursor forward onto the nearest inter-word silence.
 * Only gaps at least NEMO_CURSOR_GAP_THRESHOLD_S wide are considered, and the
 * cursor may move at most NEMO_CURSOR_SNAP_WINDOW_S forward (never backward).
 */
function relocateNemoCursorToNearbyGap(target_s, words) {
    let bestCandidate = target_s;
    let bestDistance = NEMO_CURSOR_SNAP_WINDOW_S + 1;

    for (let i = 0; i + 1 < words.length; ++i) {
        const gapStart = words[i].endTime;
        const gapEnd = words[i + 1].startTime;
        if (gapEnd - gapStart < NEMO_CURSOR_GAP_THRESHOLD_S) continue;

        for (const edge of [gapStart, gapEnd]) {
            if (edge + NEMO_AUTO_WINDOW_EPSILON_S < target_s) continue;
            const distance = edge - target_s;
            if (distance <= NEMO_CURSOR_SNAP_WINDOW_S && distance < bestDistance) {
                bestCandidate = edge;
                bestDistance = distance;
            }
        }
    }

    return bestCandidate;
}
/**
 * Sentence-aware sliding-window transcription for long audio.
 *
 * Decodes overlapping windows; in each window all segments except the last
 * (possibly unfinished) one are finalized, and the next window restarts at or
 * near the start of that unfinished segment so it is re-decoded with full
 * context. Falls back to a fixed-stride advance when no usable boundary exists.
 * NOTE(review): segment boundaries come from the opaque
 * `partitionNemoWordsIntoSegments` helper — sentence semantics assumed from its name.
 */
async function runNemoAutoSentenceWindowing({ audio, sampling_rate, chunk_length_s, tokenizer, runNemoTranscribe }) {
    const audio_duration_s = audio.length / sampling_rate;
    // Fallback stride (used when no sentence boundary allows advancing): window minus a fixed overlap.
    const fallback_overlap_s = Math.min(NEMO_AUTO_WINDOW_FALLBACK_OVERLAP_S, Math.max(0, chunk_length_s - 1));
    const fallback_advance_s = Math.max(1, chunk_length_s - fallback_overlap_s);
    // Hard iteration cap so a cursor that stops advancing can never loop forever.
    const maxWindows = Math.max(
        4,
        Math.ceil(Math.max(0, audio_duration_s - chunk_length_s) / NEMO_CURSOR_MIN_ADVANCE_S) + 2,
    );

    /** @type {Array<{ words: Array<{ text: string, startTime: number, endTime: number, confidence?: number }>, text: string, timestamp: [number, number] }>} */
    const finalizedSegments = [];
    /** @type {Array<{ text: string, startTime: number, endTime: number, confidence?: number }>} */
    let pendingWords = [];
    let lastTextFallback = '';
    let start_s = 0;
    let shouldMergePending = false;

    for (
        let windowIndex = 0;
        windowIndex < maxWindows && start_s < audio_duration_s - NEMO_AUTO_WINDOW_EPSILON_S;
        ++windowIndex
    ) {
        const end_s = Math.min(audio_duration_s, start_s + chunk_length_s);
        const start_sample = Math.max(0, Math.min(audio.length - 1, Math.floor(start_s * sampling_rate)));
        const end_sample = Math.max(start_sample + 1, Math.min(audio.length, Math.ceil(end_s * sampling_rate)));
        const windowAudio = audio.subarray(start_sample, end_sample);
        const is_last_window = end_s >= audio_duration_s - NEMO_AUTO_WINDOW_EPSILON_S;

        // timeOffset keeps per-window timestamps in absolute (whole-audio) time.
        const output = await runNemoTranscribe(windowAudio, {
            tokenizer,
            returnTimestamps: true,
            returnWords: true,
            returnMetrics: false,
            timeOffset: start_s,
        });
        lastTextFallback = output.text ?? lastTextFallback;

        const currentWords = Array.isArray(output.words) ? output.words : [];
        const windowWords = shouldMergePending
            ? mergePendingAndCurrentNemoWords(pendingWords, currentWords)
            : dedupeMergedWords(currentWords);
        const segments = partitionNemoWordsIntoSegments(windowWords);

        if (is_last_window) {
            // Final window: everything decoded so far is definitive.
            for (const segment of segments) {
                appendFinalizedNemoSegment(finalizedSegments, segment);
            }
            pendingWords = [];
            break;
        }

        if (segments.length > 1) {
            const pendingSegment = segments[segments.length - 1];
            const pendingStart_s = pendingSegment.timestamp[0];
            // Only advance via a sentence boundary when it moves the cursor meaningfully forward.
            if (pendingStart_s >= start_s + NEMO_CURSOR_MIN_ADVANCE_S - NEMO_AUTO_WINDOW_EPSILON_S) {
                const readySegments = segments.slice(0, -1);
                for (const segment of readySegments) {
                    appendFinalizedNemoSegment(finalizedSegments, segment);
                }

                pendingWords = dedupeMergedWords(pendingSegment.words);
                const next_start_s = Math.min(
                    audio_duration_s,
                    relocateNemoCursorToNearbyGap(pendingStart_s, windowWords),
                );
                // If the cursor snapped past the pending segment start, its words must be merged next window.
                shouldMergePending = next_start_s > pendingStart_s + NEMO_AUTO_WINDOW_EPSILON_S;
                if (next_start_s > start_s + NEMO_AUTO_WINDOW_EPSILON_S) {
                    start_s = next_start_s;
                    continue;
                }
            }
        }

        // No usable sentence boundary: carry all words and advance by the fixed fallback stride.
        pendingWords = windowWords;
        shouldMergePending = true;

        const fallback_start_s = Math.min(audio_duration_s, start_s + fallback_advance_s);
        if (fallback_start_s <= start_s + NEMO_AUTO_WINDOW_EPSILON_S) {
            break;
        }
        start_s = fallback_start_s;
    }

    const words = dedupeMergedWords([...flattenNemoSegmentWords(finalizedSegments), ...pendingWords]);
    const text = words.length > 0 ? joinTimedWords(words) : String(lastTextFallback ?? '').trim();
    const utteranceTimestamp =
        words.length > 0
            ? /** @type {[number, number]} */ ([words[0].startTime, words[words.length - 1].endTime])
            : null;

    return {
        text,
        words,
        utteranceTimestamp,
        chunks: buildNemoSegmentChunks(words, utteranceTimestamp, text),
    };
}

/**
 * Run the ASR pipeline adapter for Nemo Conformer TDT models.
 * Keeps the public contract task-shaped while delegating rich outputs to `model.transcribe()`.
 *
 * @param {{
 *   model: any,
 *   processor: any,
 *   tokenizer: any,
 *   audio: Float32Array|Float64Array|Array,
 *   kwargs: Record<string, any>,
 *   prepareAudios: (audio: any[], sampling_rate: number) => Promise<(Float32Array|Float64Array)[]>,
 * }} options
 */
export async function runNemoConformerTDTPipeline({ model, processor, tokenizer, audio, kwargs, prepareAudios }) {
    if (typeof model?.transcribe !== 'function') {
        throw new Error('Nemo Conformer TDT model does not expose a `transcribe` method.');
    }
    if (!processor) {
        throw new Error('Nemo Conformer TDT pipeline requires a processor.');
    }
    if (!tokenizer) {
        throw new Error('Nemo Conformer TDT pipeline requires a tokenizer.');
    }
    if (!processor.feature_extractor?.config?.sampling_rate) {
        throw new Error(
            'Nemo Conformer TDT pipeline requires `processor.feature_extractor.config.sampling_rate` to prepare audio.',
        );
    }

    // return_timestamps: false | true (segment chunks) | 'word' (word chunks).
    const return_timestamps = kwargs.return_timestamps ?? false;
    const wantWordTimestamps = return_timestamps === 'word';
    const wantTimestampChunks = return_timestamps === true || wantWordTimestamps;
    const requested_chunk_length_s = normalizeNemoChunkLengthS(kwargs.chunk_length_s ?? 0);

    const single = !Array.isArray(audio);
    const batchedAudio = single ? [audio] : audio;
    const sampling_rate = processor.feature_extractor.config.sampling_rate;
    const preparedAudios = await prepareAudios(batchedAudio, sampling_rate);
    for (let i = 0; i < preparedAudios.length; ++i) {
        validateNemoAudio(preparedAudios[i], i);
    }

    // Extract features, transcribe, then either hand tensor ownership back to the
    // feature cache (when it owns them) or dispose the tensors outright.
    const runNemoTranscribe = async (windowAudio, decodeOptions) => {
        const inputs = await processor(windowAudio);
        const cacheOwnsTensors = inputs?.[NEMO_FEATURE_OUTPUT_OWNERSHIP] === true;
        try {
            return await model.transcribe(inputs, decodeOptions);
        } finally {
            if (cacheOwnsTensors) {
                releaseNemoPipelineInputs(inputs);
            } else {
                disposeNemoPipelineInputs(inputs);
            }
        }
    };

    const toReturn = [];
    for (const aud of preparedAudios) {
        const audio_duration_s = aud.length / sampling_rate;
        // Long audio with no explicit chunk length opts into automatic windowing.
        const autoWindowing = requested_chunk_length_s <= 0 && audio_duration_s > NEMO_AUTO_WINDOW_THRESHOLD_S;
        const chunk_length_s =
            requested_chunk_length_s > 0 ? requested_chunk_length_s : autoWindowing ? NEMO_AUTO_CHUNK_LENGTH_S : 0;
        const useSentenceWindowing = chunk_length_s > 0;

        if (useSentenceWindowing) {
            const merged = await runNemoAutoSentenceWindowing({
                audio: aud,
                sampling_rate,
                chunk_length_s,
                tokenizer,
                runNemoTranscribe,
            });
            const result = { text: merged.text };
            if (wantWordTimestamps) {
                result.chunks = buildWordChunks(merged.words);
            } else if (wantTimestampChunks) {
                result.chunks = merged.chunks;
            }
            toReturn.push(result);
            continue;
        }

        // Short audio: a single transcribe call over the full signal.
        const output = await runNemoTranscribe(aud, {
            tokenizer,
            returnTimestamps: wantTimestampChunks,
            returnWords: wantTimestampChunks,
            returnMetrics: false,
        });

        const result = { text: output.text ?? '' };
        if (wantWordTimestamps) {
            result.chunks = buildWordChunks(output.words ?? []);
        } else if (wantTimestampChunks) {
            result.chunks = buildNemoSegmentChunks(output.words ?? [], output.utteranceTimestamp ?? null, result.text);
        }
        toReturn.push(result);
    }

    return single ?
toReturn[0] : toReturn; +} diff --git a/packages/transformers/src/models/nemo_conformer_tdt/processing_nemo_conformer_tdt.js b/packages/transformers/src/models/nemo_conformer_tdt/processing_nemo_conformer_tdt.js new file mode 100644 index 000000000..4c2d0a7eb --- /dev/null +++ b/packages/transformers/src/models/nemo_conformer_tdt/processing_nemo_conformer_tdt.js @@ -0,0 +1,19 @@ +import { AutoFeatureExtractor } from '../auto/feature_extraction_auto.js'; +import { AutoTokenizer } from '../auto/tokenization_auto.js'; +import { Processor } from '../../processing_utils.js'; + +/** + * Processor for Nemo Conformer TDT models. + */ +export class NemoConformerTDTProcessor extends Processor { + static tokenizer_class = AutoTokenizer; + static feature_extractor_class = AutoFeatureExtractor; + + /** + * Preprocess raw audio for Nemo Conformer TDT models. + * @param {Float32Array|Float64Array} audio + */ + async _call(audio) { + return await this.feature_extractor(audio); + } +} diff --git a/packages/transformers/src/models/nemo_conformer_tdt/transducer_cache.js b/packages/transformers/src/models/nemo_conformer_tdt/transducer_cache.js new file mode 100644 index 000000000..6c0cfab1f --- /dev/null +++ b/packages/transformers/src/models/nemo_conformer_tdt/transducer_cache.js @@ -0,0 +1,259 @@ +import { Tensor } from '../../utils/tensor.js'; + +/** + * Create a stable hash key for audio samples, used by feature caches. + * @param {Float32Array|Float64Array} audio + * @param {number} [sampling_rate=16000] + * @returns {string} + */ +export function createAudioCacheKey(audio, sampling_rate = 16000) { + // FNV-1a 32-bit over quantized values for deterministic cross-runtime keys. + let hash = 2166136261; + hash ^= audio.length; + hash = Math.imul(hash, 16777619); + hash ^= sampling_rate; + hash = Math.imul(hash, 16777619); + + // Hash all quantized samples to minimize false cache hits across waveforms. 
/**
 * Create a stable hash key for audio samples, used by feature caches.
 * FNV-1a (32-bit) over 16-bit-quantized samples yields deterministic,
 * cross-runtime keys; the sampling rate and sample count are mixed into the
 * hash and also embedded in the key prefix for cheap inspection.
 * @param {Float32Array|Float64Array} audio
 * @param {number} [sampling_rate=16000]
 * @returns {string}
 */
export function createAudioCacheKey(audio, sampling_rate = 16000) {
    const FNV_PRIME = 16777619;
    let digest = 2166136261;
    digest = Math.imul(digest ^ audio.length, FNV_PRIME);
    digest = Math.imul(digest ^ sampling_rate, FNV_PRIME);

    // Hash all quantized samples to minimize false cache hits across waveforms.
    for (let i = 0; i < audio.length; ++i) {
        const raw = audio[i];
        const safe = Number.isFinite(raw) ? raw : 0;
        const quantized = Math.max(-32768, Math.min(32767, Math.round(safe * 32768)));
        digest = Math.imul(digest ^ quantized, FNV_PRIME);
    }
    return `${sampling_rate}:${audio.length}:${(digest >>> 0).toString(16)}`;
}

/**
 * Lightweight LRU cache for extracted features.
 * Stores values as-is, owns cached tensor lifetimes, and tracks approximate
 * memory usage. Borrowed entries (`acquire`) are reference-counted so eviction
 * never frees a value a caller still holds; disposal is deferred to release.
 */
export class FeatureLRUCache {
    /**
     * @param {{max_entries?: number, max_size_mb?: number}} [options]
     */
    constructor({ max_entries = 128, max_size_mb = 64 } = {}) {
        const entriesValid = Number.isInteger(max_entries) && max_entries >= 0;
        if (!entriesValid) {
            throw new Error('FeatureLRUCache expected `max_entries` to be a non-negative integer.');
        }
        const sizeValid = Number.isFinite(max_size_mb) && max_size_mb >= 0;
        if (!sizeValid) {
            throw new Error('FeatureLRUCache expected `max_size_mb` to be a non-negative number.');
        }
        this.max_entries = max_entries;
        this.max_size_mb = max_size_mb;
        this.cache = new Map();
        this.current_size_bytes = 0;
    }

    /**
     * Look up a cached value and refresh its recency.
     * @param {string} key
     * @returns {any|null}
     */
    get(key) {
        const hit = this._touch(key);
        return hit ? hit.value : null;
    }

    /**
     * Borrow a cached value. While borrowed, the entry cannot be disposed;
     * calling `release()` more than once is a safe no-op.
     * @param {string} key
     * @returns {{ value: any, release: () => void } | null}
     */
    acquire(key) {
        const entry = this._touch(key);
        if (!entry) return null;

        entry.borrowers += 1;
        let alreadyReleased = false;
        const release = () => {
            if (alreadyReleased) return;
            alreadyReleased = true;
            this._releaseEntry(entry);
        };
        return { value: entry.value, release };
    }

    /**
     * Insert (or refresh) a value. On success the cache takes ownership of the
     * contained tensors; on `false` the caller keeps ownership.
     * @param {string} key
     * @param {any} value
     * @returns {boolean} Whether the cache retained ownership of the supplied value.
     */
    set(key, value) {
        // Explicit no-cache mode: keep caller ownership of current values.
        if (this.max_entries === 0 || this.max_size_mb === 0) {
            if (this.cache.size > 0) this.clear();
            return false;
        }

        const max_bytes = this.max_size_mb * 1024 * 1024;
        const existing = this.cache.get(key);
        if (existing?.value === value) {
            // Refresh recency for unchanged value without invalidating caller-owned references.
            if (existing.size_bytes > max_bytes) {
                this._deleteEntry(key, existing);
                return false;
            }
            this.cache.delete(key);
            this.cache.set(key, existing);
            return true;
        }

        const size_bytes = estimateSizeBytes(value);
        if (size_bytes > max_bytes) {
            // Cannot fit in cache: keep caller ownership and skip caching.
            if (existing) this._deleteEntry(key, existing);
            return false;
        }

        if (existing) this._deleteEntry(key, existing);

        this.cache.set(key, { value, size_bytes, borrowers: 0, pendingDispose: false });
        this.current_size_bytes += size_bytes;
        this._evict();
        // Eviction may have removed the fresh entry itself; report actual ownership.
        return this.cache.get(key)?.value === value;
    }

    /** Drop every entry; values still borrowed are disposed on release instead. */
    clear() {
        for (const [key, entry] of Array.from(this.cache.entries())) {
            this._deleteEntry(key, entry);
        }
    }

    /** @returns {{entries: number, size_mb: number, max_entries: number, max_size_mb: number}} */
    stats() {
        return {
            entries: this.cache.size,
            size_mb: this.current_size_bytes / (1024 * 1024),
            max_entries: this.max_entries,
            max_size_mb: this.max_size_mb,
        };
    }

    /** Evict least-recently-used entries until both entry and byte budgets hold. */
    _evict() {
        const max_bytes = this.max_size_mb * 1024 * 1024;
        while (this.cache.size > this.max_entries || this.current_size_bytes > max_bytes) {
            const lruKey = this.cache.keys().next().value;
            if (lruKey === undefined) break;
            const lruEntry = this.cache.get(lruKey);
            if (!lruEntry) break;
            this._deleteEntry(lruKey, lruEntry);
        }
    }

    /** Move `key` to the most-recently-used position; returns its entry or null. */
    _touch(key) {
        const entry = this.cache.get(key);
        if (!entry) return null;
        this.cache.delete(key);
        this.cache.set(key, entry);
        return entry;
    }

    /** Remove an entry; disposal is deferred while the entry is borrowed. */
    _deleteEntry(key, entry) {
        if (this.cache.get(key) !== entry) return;

        this.cache.delete(key);
        if (entry.borrowers > 0) {
            entry.pendingDispose = true;
        } else {
            this.current_size_bytes -= entry.size_bytes;
            disposeCachedValue(entry.value);
        }
    }

    /** Return a borrowed entry; performs any disposal deferred during the loan. */
    _releaseEntry(entry) {
        if (entry.borrowers > 0) {
            entry.borrowers -= 1;
        }
        if (entry.borrowers === 0 && entry.pendingDispose) {
            entry.pendingDispose = false;
            this.current_size_bytes -= entry.size_bytes;
            disposeCachedValue(entry.value);
        }
    }
}

/** Best-effort byte size of a tensor: real buffer length when available, else size × element width. */
function tensorByteSize(tensor) {
    let byteLength = null;
    try {
        byteLength = /** @type {any} */ (tensor.data)?.byteLength ?? null;
    } catch {
        byteLength = null;
    }
    if (typeof byteLength === 'number' && byteLength >= 0) {
        return byteLength;
    }

    const bytesPerElement = {
        bool: 1,
        int8: 1,
        uint8: 1,
        int16: 2,
        uint16: 2,
        int32: 4,
        uint32: 4,
        int64: 8,
        uint64: 8,
        float16: 2,
        float32: 4,
        float64: 8,
    };
    return tensor.size * (bytesPerElement[tensor.type] ?? 4);
}

/** Collect the distinct tensors a cached value owns (bare tensor or known feature fields). */
function collectCachedTensors(value, out = new Set()) {
    if (value instanceof Tensor) {
        out.add(value);
        return out;
    }
    for (const field of ['input_features', 'attention_mask', 'delta_features', 'delta_delta_features']) {
        if (value?.[field] instanceof Tensor) out.add(value[field]);
    }
    return out;
}

/** Dispose every tensor owned by a cached value. */
function disposeCachedValue(value) {
    for (const tensor of collectCachedTensors(value)) {
        tensor.dispose();
    }
}

/** Approximate the cached footprint of a value in bytes (tensor, feature bundle, or raw buffer). */
function estimateSizeBytes(value) {
    if (value instanceof Tensor) {
        return tensorByteSize(value);
    }
    if (value?.input_features instanceof Tensor) {
        let bytes = tensorByteSize(value.input_features);
        for (const field of ['attention_mask', 'delta_features', 'delta_delta_features']) {
            if (value[field] instanceof Tensor) bytes += tensorByteSize(value[field]);
        }
        return bytes;
    }
    const byteLength = value?.byteLength;
    if (typeof byteLength === 'number' && Number.isFinite(byteLength) && byteLength >= 0) {
        return byteLength;
    }
    return 0;
}
diff --git a/packages/transformers/src/models/nemo_conformer_tdt/transducer_deltas.js b/packages/transformers/src/models/nemo_conformer_tdt/transducer_deltas.js
new file mode 100644
index 000000000..957fa0776
--- /dev/null
+++ b/packages/transformers/src/models/nemo_conformer_tdt/transducer_deltas.js
@@ -0,0 +1,91 @@
import { Tensor } from '../../utils/tensor.js';

/**
 * Compute temporal deltas (and optionally delta-deltas) for [1, T, F] features.
 * Uses the standard regression-over-window formula with edge replication; when
 * `concatenate` is set, the base (+delta, +delta-delta) blocks are interleaved
 * per frame into a single [1, T, F*(order+1)] tensor.
 * @param {Tensor} input_features
 * @param {{order?: 1|2, window?: number, concatenate?: boolean}} [options]
 * @returns {Tensor|{delta: Tensor, delta_delta?: Tensor}}
 */
export function computeTemporalDeltas(input_features, { order = 1, window = 2, concatenate = false } = {}) {
    if (!(input_features instanceof Tensor)) {
        throw new Error('computeTemporalDeltas expects `input_features` as a Tensor.');
    }
    if (input_features.dims.length !== 3 || input_features.dims[0] !== 1) {
        throw new Error(`computeTemporalDeltas expects dims [1, T, F], got [${input_features.dims.join(', ')}].`);
    }
    if (!Number.isInteger(window) || window <= 0) {
        throw new Error('computeTemporalDeltas expects `window` to be a positive integer.');
    }
    if (order !== 1 && order !== 2) {
        throw new Error('computeTemporalDeltas expects `order` to be 1 or 2.');
    }
    if (input_features.type !== 'float32') {
        throw new Error(`computeTemporalDeltas expects input tensor type "float32", got "${input_features.type}".`);
    }

    const [batch, T, F] = input_features.dims;
    const base = /** @type {Float32Array} */ (input_features.data);
    const delta = new Float32Array(base.length);
    // Regression denominator: 2 * sum(n^2) for n in [1, window].
    const denom = 2 * Array.from({ length: window }, (_, i) => (i + 1) ** 2).reduce((a, b) => a + b, 0);

    const at = (t, f) => base[t * F + f];
    for (let t = 0; t < T; ++t) {
        for (let f = 0; f < F; ++f) {
            let num = 0;
            for (let n = 1; n <= window; ++n) {
                // Clamp indices so the window replicates the first/last frame at the edges.
                const tp = Math.min(T - 1, t + n);
                const tm = Math.max(0, t - n);
                num += n * (at(tp, f) - at(tm, f));
            }
            delta[t * F + f] = num / denom;
        }
    }

    const delta_tensor = new Tensor('float32', delta, [batch, T, F]);
    if (order === 1) {
        if (!concatenate) {
            return { delta: delta_tensor };
        }
        const result = new Tensor('float32', interleaveByFrame([base, delta], T, F), [batch, T, F * 2]);
        delta_tensor.dispose();
        return result;
    }

    // order === 2: delta-delta is simply the delta of the delta.
    const recursive_result = /** @type {{delta: Tensor}} */ (
        computeTemporalDeltas(delta_tensor, { order: 1, window, concatenate: false })
    );
    const delta_delta_tensor = recursive_result.delta;
    if (!concatenate) {
        return {
            delta: delta_tensor,
            delta_delta: delta_delta_tensor,
        };
    }

    const delta_delta = /** @type {Float32Array} */ (delta_delta_tensor.data);
    const result = new Tensor('float32', interleaveByFrame([base, delta, delta_delta], T, F), [batch, T, F * 3]);
    delta_delta_tensor.dispose();
    delta_tensor.dispose();
    return result;
}

/** Interleave equal-length per-frame blocks: output frame t = [items[0][t], items[1][t], ...]. */
function interleaveByFrame(items, T, F) {
    const chunkSize = T * F;
    for (const arr of items) {
        if (arr.length !== chunkSize) {
            throw new Error(
                `computeTemporalDeltas expected concatenation arrays with length ${chunkSize}, got ${arr.length}.`,
            );
        }
    }

    const output = new Float32Array(chunkSize * items.length);
    for (let t = 0; t < T; ++t) {
        const srcOffset = t * F;
        const dstOffset = t * F * items.length;
        for (let i = 0; i < items.length; ++i) {
            output.set(items[i].subarray(srcOffset, srcOffset + F), dstOffset + i * F);
        }
    }
    return output;
}
diff --git a/packages/transformers/src/models/nemo_conformer_tdt/transducer_segment_offsets.js b/packages/transformers/src/models/nemo_conformer_tdt/transducer_segment_offsets.js
new file mode 100644
index 000000000..632b35d77
--- /dev/null
+++
b/packages/transformers/src/models/nemo_conformer_tdt/transducer_segment_offsets.js @@ -0,0 +1,178 @@ +const NEMO_STRONG_SENTENCE_END_REGEX = /[!?…](?:["')\]]+)?$/u; +const NEMO_PERIOD_SENTENCE_END_REGEX = /\.(?:["')\]]+)?$/u; +const NEMO_TRAILING_CLOSERS_REGEX = /["')\]]+$/gu; +const NEMO_LEADING_OPENERS_REGEX = /^[("'“‘\[{]+/u; +const NEMO_DOTTED_ACRONYM_REGEX = /^(?:[A-Z]\.){2,}$/; +const NEMO_SINGLE_LETTER_ENUM_REGEX = /^[A-Z]\.$/; +const NEMO_ROMAN_ENUM_REGEX = /^(?:[IVXLCDM]+)\.$/i; +const NEMO_NUMERIC_ENUM_REGEX = /^\d+\.$/; +const NEMO_FALLBACK_SEGMENT_GAP_S = 3.0; +const NEMO_NON_BREAKING_PERIOD_WORDS = new Set([ + 'mr.', + 'mrs.', + 'ms.', + 'dr.', + 'prof.', + 'sr.', + 'jr.', + 'vs.', + 'etc.', + 'e.g.', + 'i.e.', +]); + +/** + * @param {Array<{ text: string, startTime: number, endTime: number }>} words + * @returns {string} + */ +export function joinTimedWords(words) { + let text = ''; + for (const word of words) { + const part = word.text ?? ''; + if (!part) continue; + if (!text) { + text = part; + } else if (/^[,.;:!?)}\]]+$/.test(part)) { + text += part; + } else { + text += ` ${part}`; + } + } + return text; +} + +/** + * @param {Array<{ text: string, startTime: number, endTime: number }>} words + * @returns {Array<{ text: string, timestamp: [number, number] }>} + */ +export function buildWordChunks(words) { + return words.map((word) => ({ + text: word.text, + timestamp: [word.startTime, word.endTime], + })); +} + +/** + * @param {Array<{ text: string, startTime: number, endTime: number }>} words + * @returns {string} + */ +export function buildSegmentText(words) { + return joinTimedWords(words); +} + +function stripTrailingClosers(text) { + return String(text ?? '').replace(NEMO_TRAILING_CLOSERS_REGEX, ''); +} + +function looksLikeSentenceStart(text) { + const cleaned = String(text ?? '').replace(NEMO_LEADING_OPENERS_REGEX, ''); + return /^[A-Z]/.test(cleaned); +} + +/** + * Conservative sentence-boundary heuristic for ASR word timestamps. 
+ * Favors under-segmentation over mid-sentence false positives. + * + * @param {{ text: string }} currentWord + * @param {{ text: string } | null} nextWord + * @param {number} gap_s + * @returns {boolean} + */ +export function shouldEndSentenceAfterWord(currentWord, nextWord, gap_s = 0) { + if (!nextWord) { + return false; + } + + if (gap_s >= NEMO_FALLBACK_SEGMENT_GAP_S) { + return true; + } + + const currentText = String(currentWord?.text ?? ''); + if (!currentText) { + return false; + } + + if (NEMO_STRONG_SENTENCE_END_REGEX.test(currentText)) { + return true; + } + + if (!NEMO_PERIOD_SENTENCE_END_REGEX.test(currentText)) { + return false; + } + + const stripped = stripTrailingClosers(currentText); + const lowered = stripped.toLowerCase(); + if ( + NEMO_NON_BREAKING_PERIOD_WORDS.has(lowered) || + NEMO_DOTTED_ACRONYM_REGEX.test(stripped) || + NEMO_SINGLE_LETTER_ENUM_REGEX.test(stripped) || + NEMO_ROMAN_ENUM_REGEX.test(stripped) || + NEMO_NUMERIC_ENUM_REGEX.test(stripped) + ) { + return false; + } + + return looksLikeSentenceStart(nextWord.text); +} + +/** + * Partition timed words into conservative sentence-like segments. + * + * @param {Array<{ text: string, startTime: number, endTime: number }>} words + * @returns {Array<{ words: Array<{ text: string, startTime: number, endTime: number }>, text: string, timestamp: [number, number] }>} + */ +export function partitionNemoWordsIntoSegments(words) { + if (!Array.isArray(words) || words.length === 0) { + return []; + } + + /** @type {Array<{ words: Array<{ text: string, startTime: number, endTime: number }>, text: string, timestamp: [number, number] }>} */ + const segments = []; + /** @type {typeof words} */ + let current = []; + for (let i = 0; i < words.length; ++i) { + const word = words[i]; + current.push(word); + + const nextWord = words[i + 1] ?? null; + const gap_s = nextWord ? 
Math.max(0, nextWord.startTime - word.endTime) : 0; + if (shouldEndSentenceAfterWord(word, nextWord, gap_s)) { + segments.push({ + words: current, + text: buildSegmentText(current), + timestamp: [current[0].startTime, current[current.length - 1].endTime], + }); + current = []; + } + } + + if (current.length > 0) { + segments.push({ + words: current, + text: buildSegmentText(current), + timestamp: [current[0].startTime, current[current.length - 1].endTime], + }); + } + + return segments; +} + +/** + * @param {Array<{ text: string, startTime: number, endTime: number }>} words + * @param {[number, number] | null} utteranceTimestamp + * @param {string} text + * @returns {Array<{ text: string, timestamp: [number, number] }>} + */ +export function buildNemoSegmentChunks(words, utteranceTimestamp = null, text = '') { + if (!Array.isArray(words) || words.length === 0) { + if (utteranceTimestamp) { + return [{ text, timestamp: utteranceTimestamp }]; + } + return []; + } + + return partitionNemoWordsIntoSegments(words).map((segment) => ({ + text: segment.text, + timestamp: segment.timestamp, + })); +} diff --git a/packages/transformers/src/models/nemo_conformer_tdt/transducer_text.js b/packages/transformers/src/models/nemo_conformer_tdt/transducer_text.js new file mode 100644 index 000000000..f155b9962 --- /dev/null +++ b/packages/transformers/src/models/nemo_conformer_tdt/transducer_text.js @@ -0,0 +1,30 @@ +import { buildTransducerWordOffsets } from './transducer_word_offsets.js'; + +/** + * Decode token ids into final transcription text. + * @param {any} tokenizer + * @param {number[]} token_ids + * @returns {string} + */ +export function decodeTransducerText(tokenizer, token_ids) { + if (!Array.isArray(token_ids) || token_ids.length === 0) return ''; + if (!tokenizer) return token_ids.join(' '); + return tokenizer.decode(token_ids, { skip_special_tokens: true }).trim(); +} + +/** + * Build detailed word/token outputs with optional confidence aggregation. 
+ * @param {any} tokenizer + * @param {number[]} token_ids + * @param {[number, number][]} token_timestamps + * @param {number[] | null} token_confidences + * @returns {{ + * words: Array<{ text: string, startTime: number, endTime: number, confidence?: number }>, + * tokens: Array<{ id: number, token: string, rawToken: string, isWordStart: boolean, startTime: number, endTime: number, confidence?: number }>, + * wordAverage: number | null, + * }} + */ +export function buildTransducerDetailedOutputs(tokenizer, token_ids, token_timestamps, token_confidences = null) { + const fullText = decodeTransducerText(tokenizer, token_ids); + return buildTransducerWordOffsets(tokenizer, token_ids, token_timestamps, token_confidences, fullText); +} diff --git a/packages/transformers/src/models/nemo_conformer_tdt/transducer_window_merge.js b/packages/transformers/src/models/nemo_conformer_tdt/transducer_window_merge.js new file mode 100644 index 000000000..a1cc3a15c --- /dev/null +++ b/packages/transformers/src/models/nemo_conformer_tdt/transducer_window_merge.js @@ -0,0 +1,43 @@ +function normalizeMergedWordText(text) { + return String(text ?? '') + .normalize('NFKC') + .toLowerCase() + .replace(/^[("'“‘\[{]+/g, '') + .replace(/[.,!?;:)"'”’\]}]+$/g, '') + .trim(); +} + +function normalizeRawMergedWordText(text) { + return String(text ?? 
'') + .normalize('NFKC') + .toLowerCase() + .trim(); +} + +export function dedupeMergedWords(words) { + /** @type {typeof words} */ + const merged = []; + for (const word of words) { + const prev = merged.at(-1); + const prevText = normalizeMergedWordText(prev?.text); + const wordText = normalizeMergedWordText(word.text); + if ( + prev && + prevText === wordText && + ( + prevText.length > 0 || + normalizeRawMergedWordText(prev.text) === normalizeRawMergedWordText(word.text) + ) && + word.startTime < prev.endTime + ) { + const prevDuration = prev.endTime - prev.startTime; + const nextDuration = word.endTime - word.startTime; + if (nextDuration > prevDuration) { + merged[merged.length - 1] = word; + } + continue; + } + merged.push(word); + } + return merged; +} diff --git a/packages/transformers/src/models/nemo_conformer_tdt/transducer_word_offsets.js b/packages/transformers/src/models/nemo_conformer_tdt/transducer_word_offsets.js new file mode 100644 index 000000000..24a11423b --- /dev/null +++ b/packages/transformers/src/models/nemo_conformer_tdt/transducer_word_offsets.js @@ -0,0 +1,231 @@ +/** + * Cache tokenizer id->token maps for stable and fast boundary detection. 
+ * @type {WeakMap<object, Map<number, string>>} + */ +const TOKEN_ID_TO_TEXT_CACHE = new WeakMap(); + +/** + * @param {any} tokenizer + * @returns {Map<number, string>} + */ +function getIdToTokenMap(tokenizer) { + let cached = TOKEN_ID_TO_TEXT_CACHE.get(tokenizer); + if (cached) return cached; + + cached = new Map(); + if (tokenizer?.get_vocab) { + const vocab = tokenizer.get_vocab(); + if (Array.isArray(vocab)) { + for (let id = 0; id < vocab.length; ++id) { + if (typeof vocab[id] === 'string') { + cached.set(id, vocab[id]); + } + } + } else if (vocab instanceof Map) { + for (const [token, id] of vocab.entries()) { + if (Number.isInteger(id)) { + cached.set(id, token); + } + } + } else if (vocab && typeof vocab === 'object') { + for (const [token, id] of Object.entries(vocab)) { + if (Number.isInteger(id)) { + cached.set(id, token); + } + } + } + } + TOKEN_ID_TO_TEXT_CACHE.set(tokenizer, cached); + return cached; +} + +/** + * Resolve per-token text and word boundary metadata in a tokenizer-agnostic way. + * @param {any} tokenizer + * @param {number} id + * @returns {{ raw: string, clean: string, startsNewWord: boolean }} + */ +function resolveTokenPiece(tokenizer, id) { + const rawToken = getIdToTokenMap(tokenizer).get(id) ?? 
''; + const decoded = tokenizer.decode([id], { + skip_special_tokens: true, + clean_up_tokenization_spaces: false, + }); + + const startsWithBoundaryMarker = /^(?:▁|Ġ)+/.test(rawToken); + const startsWithWhitespace = /^\s+/.test(decoded); + const startsNewWord = startsWithBoundaryMarker || startsWithWhitespace; + + let clean = decoded.replace(/^\s+/, ''); + if (!clean) { + clean = rawToken.replace(/^(?:▁|Ġ|Ċ)+/, '').replace(/^ +/, ''); + } + + return { raw: rawToken || decoded, clean, startsNewWord }; +} + +/** + * @param {string} fullText + * @param {number} cursor + * @param {string} tokenText + * @returns {{ cursor: number, text: string, skippedWhitespace: boolean }} + */ +function consumeAlignedTokenText(fullText, cursor, tokenText) { + let skippedWhitespace = false; + while (cursor < fullText.length && /\s/.test(fullText[cursor])) { + skippedWhitespace = true; + cursor += 1; + } + + if (!tokenText) { + return { cursor, text: '', skippedWhitespace }; + } + + if (fullText.startsWith(tokenText, cursor)) { + return { + cursor: cursor + tokenText.length, + text: fullText.slice(cursor, cursor + tokenText.length), + skippedWhitespace, + }; + } + + const next = fullText.indexOf(tokenText, cursor); + if (next !== -1 && /^\s*$/.test(fullText.slice(cursor, next))) { + return { + cursor: next + tokenText.length, + text: fullText.slice(next, next + tokenText.length), + skippedWhitespace: skippedWhitespace || next > cursor, + }; + } + + return { + cursor: cursor + tokenText.length, + text: tokenText, + skippedWhitespace, + }; +} + +/** + * @param {Array<{ text: string, startTime: number, endTime: number, confidence?: number }>} words + * @param {{ text: string, start: number, end: number, confs: number[] } | null} current + */ +function finalizeAndPushWord(words, current) { + if (!current) return; + + const text = current.text.trim(); + if (!text) return; + + /** @type {{ text: string, startTime: number, endTime: number, confidence?: number }} */ + const word = { + text, + 
startTime: current.start, + endTime: current.end, + }; + if (current.confs.length > 0) { + word.confidence = Math.round((current.confs.reduce((a, b) => a + b, 0) / current.confs.length) * 1e6) / 1e6; + } + words.push(word); +} + +/** + * @param {any} tokenizer + * @param {number[]} token_ids + * @param {[number, number][]} token_timestamps + * @param {number[] | null} token_confidences + * @param {string} fullText + * @returns {{ + * words: Array<{ text: string, startTime: number, endTime: number, confidence?: number }>, + * tokens: Array<{ id: number, token: string, rawToken: string, isWordStart: boolean, startTime: number, endTime: number, confidence?: number }>, + * wordAverage: number | null, + * }} + */ +export function buildTransducerWordOffsets( + tokenizer, + token_ids, + token_timestamps, + token_confidences = null, + fullText = '', +) { + if (token_ids.length !== token_timestamps.length) { + throw new Error( + `buildTransducerWordOffsets expects equal lengths for token_ids (${token_ids.length}) and token_timestamps (${token_timestamps.length}).`, + ); + } + if (token_confidences && token_confidences.length !== token_ids.length) { + throw new Error( + `buildTransducerWordOffsets expects token_confidences length (${token_confidences.length}) to match token_ids length (${token_ids.length}).`, + ); + } + if (token_ids.length === 0) { + return { words: [], tokens: [], wordAverage: null }; + } + if (!tokenizer) { + throw new Error('buildTransducerWordOffsets requires a tokenizer for non-empty token_ids.'); + } + + /** @type {Array<{ id: number, token: string, rawToken: string, isWordStart: boolean, startTime: number, endTime: number, confidence?: number }>} */ + const tokens = []; + /** @type {Array<{ text: string, startTime: number, endTime: number, confidence?: number }>} */ + const words = []; + let textCursor = 0; + + /** @type {{ text: string, start: number, end: number, confs: number[] } | null} */ + let current = null; + + for (let i = 0; i < 
token_ids.length; ++i) { + const id = token_ids[i]; + const ts = token_timestamps[i]; + const piece = resolveTokenPiece(tokenizer, id); + const raw = piece.raw; + const clean = piece.clean; + if (!clean) continue; + + const aligned = consumeAlignedTokenText(fullText, textCursor, clean); + textCursor = aligned.cursor; + const tokenText = aligned.text || clean; + const startsNewWord = !current || aligned.skippedWhitespace || piece.startsNewWord; + + const tok = { + id, + token: tokenText, + rawToken: raw, + isWordStart: startsNewWord, + startTime: ts[0], + endTime: ts[1], + }; + const conf = token_confidences?.[i]; + if (conf != null && Number.isFinite(conf)) { + tok.confidence = Math.round(conf * 1e6) / 1e6; + } + tokens.push(tok); + + if (!current || startsNewWord) { + finalizeAndPushWord(words, current); + current = { + text: tokenText, + start: ts[0], + end: ts[1], + confs: conf != null && Number.isFinite(conf) ? [conf] : [], + }; + } else { + current.text += tokenText; + current.end = ts[1]; + if (conf != null && Number.isFinite(conf)) { + current.confs.push(conf); + } + } + } + + finalizeAndPushWord(words, current); + + let wordAverage = null; + if (words.some((x) => x.confidence != null)) { + const validConfidences = words.map((x) => x.confidence).filter((x) => x != null); + if (validConfidences.length > 0) { + wordAverage = + Math.round((validConfidences.reduce((a, b) => a + b, 0) / validConfidences.length) * 1e6) / 1e6; + } + } + + return { words, tokens, wordAverage }; +} diff --git a/packages/transformers/src/models/processors.js b/packages/transformers/src/models/processors.js index 4e26d8a78..9efa43abd 100644 --- a/packages/transformers/src/models/processors.js +++ b/packages/transformers/src/models/processors.js @@ -8,6 +8,7 @@ export * from './jina_clip/processing_jina_clip.js'; export * from './llava/processing_llava.js'; export * from './mgp_str/processing_mgp_str.js'; export * from './moonshine/processing_moonshine.js'; +export * from 
'./nemo_conformer_tdt/processing_nemo_conformer_tdt.js'; export * from './owlvit/processing_owlvit.js'; export * from './paligemma/processing_paligemma.js'; export * from './phi3_v/processing_phi3_v.js'; diff --git a/packages/transformers/src/models/registry.js b/packages/transformers/src/models/registry.js index 3b6aa5d8d..a5a7d4582 100644 --- a/packages/transformers/src/models/registry.js +++ b/packages/transformers/src/models/registry.js @@ -41,6 +41,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['unispeech-sat', 'UniSpeechSatModel'], ['hubert', 'HubertModel'], ['wavlm', 'WavLMModel'], + ['nemo-conformer-tdt', 'NemoConformerForTDT'], ['audio-spectrogram-transformer', 'ASTModel'], ['vits', 'VitsModel'], ['pyannote', 'PyAnnoteModel'], @@ -580,6 +581,7 @@ const CUSTOM_MAPPING = [ ], ['SupertonicForConditionalGeneration', ALL_MODEL_FILES.SupertonicForConditionalGeneration, MODEL_TYPES.Supertonic], ['ChatterboxModel', ALL_MODEL_FILES.ChatterboxModel, MODEL_TYPES.Chatterbox], + ['NemoConformerForTDT', ALL_MODEL_FILES.NemoConformerForTDT, MODEL_TYPES.NemoConformerTDT], ]; for (const [name, model, type] of CUSTOM_MAPPING) { MODEL_TYPE_MAPPING.set(name, type); diff --git a/packages/transformers/src/pipelines/automatic-speech-recognition.js b/packages/transformers/src/pipelines/automatic-speech-recognition.js index d4ab074a2..1548bfc87 100644 --- a/packages/transformers/src/pipelines/automatic-speech-recognition.js +++ b/packages/transformers/src/pipelines/automatic-speech-recognition.js @@ -3,6 +3,9 @@ import { Pipeline, prepareAudios } from './_base.js'; import { Tensor } from '../utils/tensor.js'; import { max, round } from '../utils/maths.js'; import { logger } from '../utils/logger.js'; +import { + runNemoConformerTDTPipeline, +} from '../models/nemo_conformer_tdt/pipeline_nemo_conformer_tdt.js'; /** * @typedef {import('./_base.js').TextAudioPipelineConstructorArgs} TextAudioPipelineConstructorArgs @@ -152,6 +155,8 @@ export class 
AutomaticSpeechRecognitionPipeline case 'hubert': case 'parakeet_ctc': return this._call_wav2vec2(audio, kwargs); + case 'nemo-conformer-tdt': + return this._call_nemo_conformer_tdt(audio, kwargs); case 'moonshine': return this._call_moonshine(audio, kwargs); default: @@ -300,6 +305,17 @@ export class AutomaticSpeechRecognitionPipeline return single ? toReturn[0] : toReturn; } + async _call_nemo_conformer_tdt(audio, kwargs) { + return runNemoConformerTDTPipeline({ + model: this.model, + processor: this.processor, + tokenizer: this.tokenizer, + audio, + kwargs, + prepareAudios, + }); + } + async _call_moonshine(audio, kwargs) { const single = !Array.isArray(audio); const batchedAudio = single ? [audio] : audio; diff --git a/packages/transformers/src/pipelines/index.js b/packages/transformers/src/pipelines/index.js index 918fe5154..7330b328c 100644 --- a/packages/transformers/src/pipelines/index.js +++ b/packages/transformers/src/pipelines/index.js @@ -30,6 +30,7 @@ import { AutoModelForDepthEstimation, AutoModelForImageFeatureExtraction, } from '../models/auto/modeling_auto.js'; +import { NemoConformerForTDT } from '../models/nemo_conformer_tdt/modeling_nemo_conformer_tdt.js'; import { TextClassificationPipeline } from './text-classification.js'; import { TokenClassificationPipeline } from './token-classification.js'; @@ -150,7 +151,7 @@ export const SUPPORTED_TASKS = Object.freeze({ }, 'automatic-speech-recognition': { pipeline: AutomaticSpeechRecognitionPipeline, - model: [AutoModelForSpeechSeq2Seq, AutoModelForCTC], + model: [AutoModelForSpeechSeq2Seq, AutoModelForCTC, NemoConformerForTDT], default: { model: 'Xenova/whisper-tiny.en', }, diff --git a/packages/transformers/src/utils/model_registry/get_model_files.js b/packages/transformers/src/utils/model_registry/get_model_files.js index 1e25a4c57..32e3ab82d 100644 --- a/packages/transformers/src/utils/model_registry/get_model_files.js +++ b/packages/transformers/src/utils/model_registry/get_model_files.js @@ 
-82,8 +82,8 @@ export async function get_model_files( const archList = architectures.length > 0 ? architectures.join(', ') : '(none)'; logger.warn( `[get_model_files] Architecture(s) not found in MODEL_TYPE_MAPPING: [${archList}] ` + - `for model type '${config.model_type}'. Falling back to EncoderOnly (single model.onnx file). ` + - `If you encounter issues, please report at: ${GITHUB_ISSUE_URL}`, + `for model type '${config.model_type}'. Falling back to EncoderOnly (single model.onnx file). ` + + `If you encounter issues, please report at: ${GITHUB_ISSUE_URL}`, ); // Always fallback to EncoderOnly (single model.onnx file) @@ -177,6 +177,9 @@ export async function get_model_files( add_model_file('model', 'language_model'); add_model_file('conditional_decoder'); files.push('generation_config.json'); + } else if (modelType === MODEL_TYPES.NemoConformerTDT) { + add_model_file('encoder_model'); + add_model_file('decoder_model_merged'); } else if (modelType === MODEL_TYPES.AutoEncoder) { add_model_file('encoder_model'); add_model_file('decoder_model'); diff --git a/packages/transformers/tests/models/nemo_conformer_tdt/test_feature_extraction_nemo_conformer_tdt.js b/packages/transformers/tests/models/nemo_conformer_tdt/test_feature_extraction_nemo_conformer_tdt.js new file mode 100644 index 000000000..77489d0e1 --- /dev/null +++ b/packages/transformers/tests/models/nemo_conformer_tdt/test_feature_extraction_nemo_conformer_tdt.js @@ -0,0 +1,311 @@ +import { NemoConformerTDTFeatureExtractor, Tensor } from "../../../src/transformers.js"; +import { NEMO_FEATURE_OUTPUT_OWNERSHIP } from "../../../src/models/nemo_conformer_tdt/feature_extraction_nemo_conformer_tdt.js"; + +import { MAX_TEST_EXECUTION_TIME } from "../../init.js"; + +export default () => { + describe("NemoConformerTDTFeatureExtractor", () => { + const base = { + sampling_rate: 16000, + n_fft: 512, + win_length: 400, + hop_length: 160, + preemphasis: 0.97, + }; + + const audio = Float32Array.from({ length: 16000 
}, (_, i) => Math.sin((2 * Math.PI * 220 * i) / 16000)); + + it( + "supports 80 mel bins", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80 }); + const { input_features, attention_mask } = await extractor(audio); + try { + expect(input_features.dims[0]).toBe(1); + expect(input_features.dims[2]).toBe(80); + expect(attention_mask.dims).toEqual([1, input_features.dims[1]]); + } finally { + input_features.dispose(); + attention_mask.dispose(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "supports 128 mel bins", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 128 }); + const { input_features, attention_mask } = await extractor(audio); + try { + expect(input_features.dims[0]).toBe(1); + expect(input_features.dims[2]).toBe(128); + expect(attention_mask.dims).toEqual([1, input_features.dims[1]]); + } finally { + input_features.dispose(); + attention_mask.dispose(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "supports concatenated delta and delta-delta features", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 128, + delta_order: 2, + delta_window: 2, + delta_concatenate: true, + }); + const { input_features } = await extractor(audio); + try { + expect(input_features.dims[2]).toBe(128 * 3); + } finally { + input_features.dispose(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "supports non-concatenated delta and delta-delta features", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + delta_order: 2, + delta_window: 2, + delta_concatenate: false, + }); + const { input_features, delta_features, delta_delta_features, attention_mask } = await extractor(audio); + try { + expect(input_features.dims[0]).toBe(1); + expect(input_features.dims[2]).toBe(80); + expect(delta_features).toBeDefined(); + expect(delta_delta_features).toBeDefined(); + 
expect(delta_features.dims).toEqual(input_features.dims); + expect(delta_delta_features.dims).toEqual(input_features.dims); + expect(attention_mask.dims).toEqual([1, input_features.dims[1]]); + } finally { + input_features.dispose(); + delta_features?.dispose(); + delta_delta_features?.dispose(); + attention_mask.dispose(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "disposes replaced base features when concatenated delta output is used", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + delta_order: 1, + delta_window: 2, + delta_concatenate: true, + }); + + const originalDispose = Tensor.prototype.dispose; + let disposeCalls = 0; + Tensor.prototype.dispose = function () { + disposeCalls += 1; + return originalDispose.call(this); + }; + + let input_features; + try { + ({ input_features } = await extractor(audio)); + expect(input_features.dims[2]).toBe(80 * 2); + } finally { + Tensor.prototype.dispose = originalDispose; + input_features?.dispose(); + } + + // One dispose from computeTemporalDeltas intermediate tensor, one from replacing base features tensor. 
+ expect(disposeCalls).toBe(2); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "uses feature cache when enabled", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + use_feature_cache: true, + feature_cache_max_entries: 8, + feature_cache_max_size_mb: 8, + }); + try { + const first = await extractor(audio); + const second = await extractor(audio); + + expect(first).not.toBe(second); + expect(first.input_features).toBe(second.input_features); + expect(first.attention_mask).toBe(second.attention_mask); + expect(extractor.get_cache_stats().entries).toBe(1); + } finally { + extractor.clear_cache(); + } + expect(extractor.get_cache_stats().entries).toBe(0); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "marks uncached outputs as caller-owned", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80 }); + const outputs = await extractor(audio); + try { + expect(outputs[NEMO_FEATURE_OUTPUT_OWNERSHIP]).toBe(false); + } finally { + outputs.input_features.dispose(); + outputs.attention_mask.dispose(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "marks cached outputs as cache-owned", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + use_feature_cache: true, + feature_cache_max_entries: 8, + feature_cache_max_size_mb: 8, + }); + try { + const first = await extractor(audio); + const second = await extractor(audio); + + expect(first[NEMO_FEATURE_OUTPUT_OWNERSHIP]).toBe(true); + expect(second[NEMO_FEATURE_OUTPUT_OWNERSHIP]).toBe(true); + expect(first.input_features).toBe(second.input_features); + } finally { + extractor.clear_cache(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "marks skipped-cache outputs as caller-owned", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + use_feature_cache: true, + feature_cache_max_entries: 0, + 
feature_cache_max_size_mb: 8, + }); + const outputs = await extractor(audio); + try { + expect(outputs[NEMO_FEATURE_OUTPUT_OWNERSHIP]).toBe(false); + expect(extractor.get_cache_stats().entries).toBe(0); + } finally { + outputs.input_features.dispose(); + outputs.attention_mask.dispose(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "marks oversized-cache outputs as caller-owned", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + use_feature_cache: true, + feature_cache_max_entries: 8, + feature_cache_max_size_mb: 0.000001, + }); + const outputs = await extractor(audio); + try { + expect(outputs[NEMO_FEATURE_OUTPUT_OWNERSHIP]).toBe(false); + expect(extractor.get_cache_stats().entries).toBe(0); + } finally { + outputs.input_features.dispose(); + outputs.attention_mask.dispose(); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "uses feature cache when enabled for non-concatenated delta outputs", + async () => { + const extractor = new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + delta_order: 2, + delta_window: 2, + delta_concatenate: false, + use_feature_cache: true, + feature_cache_max_entries: 8, + feature_cache_max_size_mb: 8, + }); + try { + const first = await extractor(audio); + const second = await extractor(audio); + + expect(first).not.toBe(second); + expect(first.input_features).toBe(second.input_features); + expect(first.attention_mask).toBe(second.attention_mask); + expect(first.delta_features).toBe(second.delta_features); + expect(first.delta_delta_features).toBe(second.delta_delta_features); + expect(extractor.get_cache_stats().entries).toBe(1); + } finally { + extractor.clear_cache(); + } + expect(extractor.get_cache_stats().entries).toBe(0); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "validates preemphasis range", + async () => { + const invalidHigh = new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80, preemphasis: 1 }); + await 
expect(invalidHigh(audio)).rejects.toThrow("preemphasis"); + + const invalidLow = new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80, preemphasis: -0.1 }); + await expect(invalidLow(audio)).rejects.toThrow("preemphasis"); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it("validates delta_window at construction time", () => { + expect( + () => new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80, delta_order: 1, delta_window: 0 }), + ).toThrow("delta_window"); + expect( + () => + new NemoConformerTDTFeatureExtractor({ + ...base, + feature_size: 80, + delta_order: 1, + delta_window: 1.5, + }), + ).toThrow("delta_window"); + }); + + it("validates n_fft and win_length at construction time", () => { + expect(() => new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80, n_fft: 0 })).toThrow("n_fft"); + expect(() => new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80, win_length: 0 })).toThrow( + "win_length", + ); + expect(() => new NemoConformerTDTFeatureExtractor({ ...base, feature_size: 80, win_length: 1024 })).toThrow( + "win_length", + ); + }); + }); +}; diff --git a/packages/transformers/tests/models/nemo_conformer_tdt/test_modeling_nemo_conformer_tdt.js b/packages/transformers/tests/models/nemo_conformer_tdt/test_modeling_nemo_conformer_tdt.js new file mode 100644 index 000000000..121ad5b79 --- /dev/null +++ b/packages/transformers/tests/models/nemo_conformer_tdt/test_modeling_nemo_conformer_tdt.js @@ -0,0 +1,1215 @@ +import { NemoConformerForTDT, Tensor } from "../../../src/transformers.js"; +import { createAudioCacheKey, FeatureLRUCache } from "../../../src/models/nemo_conformer_tdt/transducer_cache.js"; +import { computeTemporalDeltas } from "../../../src/models/nemo_conformer_tdt/transducer_deltas.js"; +import { buildNemoSegmentChunks, partitionNemoWordsIntoSegments, shouldEndSentenceAfterWord } from "../../../src/models/nemo_conformer_tdt/transducer_segment_offsets.js"; +import { buildTransducerDetailedOutputs } from 
"../../../src/models/nemo_conformer_tdt/transducer_text.js"; +import { buildTransducerWordOffsets } from "../../../src/models/nemo_conformer_tdt/transducer_word_offsets.js"; +import { dedupeMergedWords } from "../../../src/models/nemo_conformer_tdt/transducer_window_merge.js"; +import { MODEL_TYPE_MAPPING, MODEL_TYPES } from "../../../src/models/modeling_utils.js"; +import { get_model_files } from "../../../src/utils/model_registry/get_model_files.js"; + +import { MAX_TEST_EXECUTION_TIME } from "../../init.js"; + +class MockNemoConformerForTDT extends NemoConformerForTDT { + constructor(config, sessions, decoderScript) { + super(config, sessions, {}); + this.decoderScript = decoderScript; + this.decoderCalls = 0; + } + + async _runEncoder() { + return { + outputs: new Tensor( + "float32", + new Float32Array([ + // D=2, T=3 (BDT) + 0.1, + 0.2, + 0.3, // d0 over t + 0.4, + 0.5, + 0.6, // d1 over t + ]), + [1, 2, 3], + ), + }; + } + + async _runDecoder() { + const step = this.decoderScript[this.decoderCalls++]; + const stateShape = [1, 1, 2]; + return { + outputs: new Tensor("float32", new Float32Array(step.logits), [1, 1, step.logits.length]), + output_states_1: new Tensor("float32", new Float32Array([this.decoderCalls, 0]), stateShape), + output_states_2: new Tensor("float32", new Float32Array([0, this.decoderCalls]), stateShape), + }; + } +} + +const BASE_SESSIONS = { + encoder_model: { + inputNames: ["input_features"], + outputNames: ["outputs"], + }, + decoder_model_merged: { + inputNames: ["encoder_outputs", "targets", "target_length", "input_states_1", "input_states_2"], + outputNames: ["outputs", "output_states_1", "output_states_2"], + }, +}; + +const BASE_CONFIG = { + model_type: "nemo-conformer-tdt", + "transformers.js_config": { + transducer: { + blank_token_id: 0, + max_symbols_per_step: 2, + subsampling_factor: 4, + frame_shift_s: 0.01, + vocab_size: 3, + duration_start_index: 3, + encoder_output_layout: "BDT", + encoder_frame_layout: "BD1", + decoder: { 
+ num_layers: 1, + hidden_size: 2, + }, + }, + }, +}; + +export default () => { + describe("NemoConformerForTDT", () => { + it("maps NemoConformerForTDT to MODEL_TYPES.NemoConformerTDT", () => { + expect(MODEL_TYPE_MAPPING.get("NemoConformerForTDT")).toBe(MODEL_TYPES.NemoConformerTDT); + expect(MODEL_TYPE_MAPPING.get("nemo-conformer-tdt")).toBe(MODEL_TYPES.NemoConformerTDT); + }); + + it( + "throws on invalid runtime config: vocab_size must be > 0", + async () => { + const invalidConfig = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + vocab_size: 0, + }, + }, + }; + const model = new MockNemoConformerForTDT(invalidConfig, BASE_SESSIONS, []); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]), + }; + + await expect( + model.transcribe(inputs, { + tokenizer: { + decode: () => "", + get_vocab: () => new Map([["a", 0]]), + }, + }), + ).rejects.toThrow("vocab_size"); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "throws on invalid runtime config: blank_token_id must be < vocab_size", + async () => { + const invalidConfig = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + blank_token_id: 3, + }, + }, + }; + const model = new MockNemoConformerForTDT(invalidConfig, BASE_SESSIONS, []); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]), + }; + + await expect( + model.transcribe(inputs, { + tokenizer: { + decode: () => "", + get_vocab: () => + new Map([ + ["a", 0], + ["b", 1], + ["c", 2], + ]), + }, + }), + ).rejects.toThrow("blank_token_id"); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "throws on invalid runtime config: duration_start_index must be >= vocab_size", + async () => { + const invalidConfig = { + ...BASE_CONFIG, + 
"transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + duration_start_index: 2, + }, + }, + }; + const model = new MockNemoConformerForTDT(invalidConfig, BASE_SESSIONS, []); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]), + }; + + await expect( + model.transcribe(inputs, { + tokenizer: { + decode: () => "", + get_vocab: () => + new Map([ + ["a", 0], + ["b", 1], + ["c", 2], + ]), + }, + }), + ).rejects.toThrow("duration_start_index"); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "throws explicit vocab resolution error when tokenizer.get_vocab returns a non-object", + async () => { + const configWithoutVocab = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + vocab_size: undefined, + }, + }, + }; + const model = new MockNemoConformerForTDT(configWithoutVocab, BASE_SESSIONS, []); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]), + }; + + await expect( + model.transcribe(inputs, { + tokenizer: { + decode: () => "", + get_vocab: () => null, + }, + }), + ).rejects.toThrow("Unable to resolve vocabulary size"); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it("resolves vocab size from array tokenizers when config vocab_size is not set", () => { + const configWithoutVocab = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + vocab_size: undefined, + }, + }, + }; + const model = new MockNemoConformerForTDT(configWithoutVocab, BASE_SESSIONS, []); + expect( + model._resolveVocabSize({ + get_vocab: () => ["", "hello", "world"], + }), + ).toBe(3); + }); + + it("resolves vocab size from the maximum sparse tokenizer id when config vocab_size is not set", () => { + const 
configWithoutVocab = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + vocab_size: undefined, + }, + }, + }; + const model = new MockNemoConformerForTDT(configWithoutVocab, BASE_SESSIONS, []); + expect( + model._resolveVocabSize({ + get_vocab: () => ({ + "": 0, + hello: 2, + world: 7, + }), + }), + ).toBe(8); + }); + + it( + "greedily decodes scripted token and duration logits", + async () => { + const tokenizer = { + decode(ids) { + const idArray = Array.isArray(ids) ? ids : [ids]; + return idArray + .map((id) => { + if (id === 1 || id === 1n) return " hello"; + if (id === 2 || id === 2n) return " world"; + return ""; + }) + .join(""); + }, + }; + + const model = new MockNemoConformerForTDT(BASE_CONFIG, BASE_SESSIONS, [ + // step 1: emit token=1, duration=0 + { logits: [0.1, 10.0, 0.0, 8.0, 1.0, 0.5] }, + // step 2: emit blank, duration=1 -> move to next frame + { logits: [9.0, 0.0, 0.0, 0.0, 8.0, 0.0] }, + // step 3: emit token=2, duration=2 -> jump to end + { logits: [0.0, 0.0, 10.0, 0.0, 0.0, 9.0] }, + ]); + + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + const output = await model.transcribe(inputs, { + tokenizer, + returnTimestamps: true, + returnWords: true, + returnTokens: true, + }); + + expect(output.text).toBe("hello world"); + expect(output.isFinal).toBe(true); + expect(output.utteranceTimestamp).toEqual([0, 0.12]); + expect(output.words).toEqual([expect.objectContaining({ text: "hello", startTime: 0, endTime: 0.04 }), expect.objectContaining({ text: "world", startTime: 0.04, endTime: 0.12 })]); + expect(output.tokens).toEqual([expect.objectContaining({ id: 1, startTime: 0, endTime: 0.04 }), expect.objectContaining({ id: 2, startTime: 0.04, endTime: 0.12 })]); + expect(output.confidence).toEqual(expect.objectContaining({ utterance: expect.any(Number), wordAverage: 
expect.any(Number), averageLogProb: expect.any(Number) })); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "clamps token timestamps when step jumps beyond remaining frames", + async () => { + const tokenizer = { + decode(ids) { + const idArray = Array.isArray(ids) ? ids : [ids]; + return idArray.map((id) => (id === 1 || id === 1n ? " token" : "")).join(""); + }, + }; + + const model = new MockNemoConformerForTDT(BASE_CONFIG, BASE_SESSIONS, [ + // Emit token=1 with duration index choosing a large step (argmax at tail). + { logits: [0.1, 10.0, 0.0, 0.0, 0.0, 0.0, 12.0] }, + ]); + + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + const output = await model.transcribe(inputs, { + tokenizer, + returnTimestamps: true, + returnTokens: true, + }); + + expect(output.tokens).toHaveLength(1); + expect(output.tokens[0]).toEqual(expect.objectContaining({ startTime: 0, endTime: 0.12 })); + expect(output.utteranceTimestamp).toEqual([0, 0.12]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "aggregates frame confidences per encoder frame (not per decode step)", + async () => { + const model = new MockNemoConformerForTDT(BASE_CONFIG, BASE_SESSIONS, [ + // Frame 0: emit token=1, step=0 + { logits: [0.0, 4.0, -2.0, 9.0, 1.0, 0.0] }, + // Frame 0: emit token=2, step=0 (hits max_symbols_per_step and advances frame) + { logits: [0.0, -1.0, 3.0, 9.0, 1.0, 0.0] }, + // Frame 1: emit blank, step=2 -> exits decode loop + { logits: [5.0, 0.0, 0.0, 0.0, 1.0, 9.0] }, + ]); + + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + const output = await model.transcribe(inputs, { + returnTimestamps: false, + returnFrameConfidences: true, + }); + + expect(output.confidence.frames).toHaveLength(2); + expect(output.confidence.frames[0]).toBeCloseTo(0.9579343795, 6); + expect(output.confidence.frameAverage).toBeCloseTo((output.confidence.frames[0] + 
output.confidence.frames[1]) / 2, 6); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "rejects non-finite timeOffset", + async () => { + const model = new MockNemoConformerForTDT(BASE_CONFIG, BASE_SESSIONS, [{ logits: [9.0, 0.0, 0.0, 1.0] }]); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + await expect( + model.transcribe(inputs, { + tokenizer: { decode: () => "" }, + returnTimestamps: true, + timeOffset: Number.NaN, + }), + ).rejects.toThrow("timeOffset"); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "fails fast when duration logits are required but missing", + async () => { + const model = new MockNemoConformerForTDT(BASE_CONFIG, BASE_SESSIONS, [ + // Only vocab logits are returned; duration head is missing. + { logits: [0.1, 10.0, 0.0] }, + ]); + + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + await expect( + model.transcribe(inputs, { + tokenizer: { decode: () => "" }, + returnTimestamps: false, + }), + ).rejects.toThrow("missing duration logits"); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it("fails fast when transducer config is missing", () => { + const invalidConfig = { model_type: "nemo-conformer-tdt" }; + expect(() => new NemoConformerForTDT(invalidConfig, BASE_SESSIONS, {})).toThrow("Missing `transformers.js_config.transducer`"); + }); + + it("requires explicit encoder_output_layout in transducer config", () => { + const invalidConfig = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + encoder_output_layout: undefined, + }, + }, + }; + expect(() => new NemoConformerForTDT(invalidConfig, BASE_SESSIONS, {})).toThrow("encoder_output_layout"); + }); + + it("rejects invalid encoder_input_layout at construction time", () => { + const invalidConfig = { + ...BASE_CONFIG, + "transformers.js_config": { + 
...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + encoder_input_layout: "BAD", + }, + }, + }; + expect(() => new NemoConformerForTDT(invalidConfig, BASE_SESSIONS, {})).toThrow("encoder_input_layout"); + }); + + it("applies encoder_input_layout to canonical input_features feeds", () => { + const config = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + encoder_input_layout: "BFT", + }, + }, + }; + const model = new NemoConformerForTDT(config, BASE_SESSIONS, {}); + const input_features = new Tensor("float32", new Float32Array([1, 2, 3, 4, 5, 6]), [1, 3, 2]); + + const { feeds, disposables } = model._buildEncoderFeeds({ input_features }); + + try { + expect(disposables).toHaveLength(1); + expect(feeds.input_features).not.toBe(input_features); + expect(feeds.input_features.dims).toEqual([1, 2, 3]); + expect(Array.from(feeds.input_features.data)).toEqual([1, 3, 5, 2, 4, 6]); + } finally { + for (const tensor of disposables) { + tensor.dispose(); + } + input_features.dispose(); + } + }); + + it("rejects invalid encoder_frame_layout at construction time", () => { + const invalidConfig = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + encoder_frame_layout: "BAD", + }, + }, + }; + expect(() => new NemoConformerForTDT(invalidConfig, BASE_SESSIONS, {})).toThrow("encoder_frame_layout"); + }); + + it( + "fails fast when named encoder output is missing at runtime", + async () => { + class MissingEncoderOutputModel extends NemoConformerForTDT { + async _runEncoder() { + return { + outputs: new Tensor("float32", new Float32Array([0.1, 0.2]), [1, 2, 1]), + }; + } + + async _runDecoder() { + const stateShape = [1, 1, 2]; + return { + outputs: new Tensor("float32", new Float32Array([9.0, 
0.0, 0.0, 8.0]), [1, 1, 4]), + output_states_1: new Tensor("float32", new Float32Array([0, 0]), stateShape), + output_states_2: new Tensor("float32", new Float32Array([0, 0]), stateShape), + }; + } + } + + const config = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + io: { encoder_output: "encoder_out" }, + }, + }, + }; + const sessions = { + ...BASE_SESSIONS, + encoder_model: { + ...BASE_SESSIONS.encoder_model, + outputNames: ["encoder_out"], + }, + }; + const model = new MissingEncoderOutputModel(config, sessions, {}); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + await expect(model.transcribe(inputs, { tokenizer: { decode: () => "" } })).rejects.toThrow('encoder output "encoder_out" was not returned'); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "fails fast when named decoder logits output is missing at runtime", + async () => { + class MissingDecoderOutputModel extends NemoConformerForTDT { + async _runEncoder() { + return { + outputs: new Tensor("float32", new Float32Array([0.1, 0.2]), [1, 2, 1]), + }; + } + + async _runDecoder() { + const stateShape = [1, 1, 2]; + return { + unexpected_logits: new Tensor("float32", new Float32Array([9.0, 0.0, 0.0, 8.0]), [1, 1, 4]), + output_states_1: new Tensor("float32", new Float32Array([0, 0]), stateShape), + output_states_2: new Tensor("float32", new Float32Array([0, 0]), stateShape), + }; + } + } + + const model = new MissingDecoderOutputModel(BASE_CONFIG, BASE_SESSIONS, {}); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + await expect(model.transcribe(inputs, { tokenizer: { decode: () => "" } })).rejects.toThrow('decoder output "outputs" was not returned'); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "fails fast when named decoder state outputs are missing 
at runtime", + async () => { + class MissingDecoderStateOutputsModel extends NemoConformerForTDT { + async _runEncoder() { + return { + outputs: new Tensor("float32", new Float32Array([0.1, 0.2]), [1, 2, 1]), + }; + } + + async _runDecoder() { + return { + outputs: new Tensor("float32", new Float32Array([9.0, 0.0, 0.0, 8.0]), [1, 1, 4]), + }; + } + } + + const model = new MissingDecoderStateOutputsModel(BASE_CONFIG, BASE_SESSIONS, {}); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0, 0, 0, 0, 0]), [1, 3, 2]), + }; + + await expect(model.transcribe(inputs, { tokenizer: { decode: () => "" } })).rejects.toThrow('decoder state outputs "output_states_1" and "output_states_2" were not returned'); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it("rejects duplicate decoder output aliases in transducer io config", () => { + const invalidConfig = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + io: { + decoder_output: "outputs", + decoder_output_state_1: "outputs", + decoder_output_state_2: "output_states_2", + }, + }, + }, + }; + expect(() => new NemoConformerForTDT(invalidConfig, BASE_SESSIONS, {})).toThrow("must be distinct"); + }); + + it("rejects duplicate decoder input aliases in transducer io config", () => { + const invalidConfig = { + ...BASE_CONFIG, + "transformers.js_config": { + ...BASE_CONFIG["transformers.js_config"], + transducer: { + ...BASE_CONFIG["transformers.js_config"].transducer, + io: { + decoder_encoder: "encoder_outputs", + decoder_token: "targets", + decoder_token_length: "target_length", + decoder_state_1: "input_states_1", + decoder_state_2: "input_states_1", + }, + }, + }, + }; + expect(() => new NemoConformerForTDT(invalidConfig, BASE_SESSIONS, {})).toThrow("must be distinct"); + }); + + it( + "disposes encoder outputs when frame-count validation fails before decode", + async () => { + class 
BadEncoderOutputModel extends NemoConformerForTDT { + constructor(config, sessions, encoderOutput) { + super(config, sessions, {}); + this.encoderOutput = encoderOutput; + } + + async _runEncoder() { + return { outputs: this.encoderOutput }; + } + } + + const badEncoderOutput = new Tensor("float32", new Float32Array([0, 1, 2, 3]), [2, 2]); + let disposed = 0; + const originalDispose = badEncoderOutput.dispose.bind(badEncoderOutput); + badEncoderOutput.dispose = () => { + disposed += 1; + originalDispose(); + }; + + const model = new BadEncoderOutputModel(BASE_CONFIG, BASE_SESSIONS, badEncoderOutput); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]), + }; + + await expect( + model.transcribe(inputs, { + tokenizer: { decode: () => "" }, + }), + ).rejects.toThrow("expected encoder output dims"); + expect(disposed).toBe(1); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "disposes auxiliary decoder tensor outputs per decode step", + async () => { + class AuxDecoderOutputModel extends NemoConformerForTDT { + constructor(config, sessions) { + super(config, sessions, {}); + this.auxDisposals = 0; + } + + async _runEncoder() { + return { + outputs: new Tensor("float32", new Float32Array([0.1, 0.2]), [1, 2, 1]), + }; + } + + async _runDecoder() { + const stateShape = [1, 1, 2]; + const aux = new Tensor("float32", new Float32Array([1, 2, 3]), [1, 1, 3]); + const originalDispose = aux.dispose.bind(aux); + aux.dispose = () => { + this.auxDisposals += 1; + originalDispose(); + }; + return { + outputs: new Tensor("float32", new Float32Array([10.0, 0.0, 0.0, 8.0, 0.0]), [1, 1, 5]), + output_states_1: new Tensor("float32", new Float32Array([0, 0]), stateShape), + output_states_2: new Tensor("float32", new Float32Array([0, 0]), stateShape), + auxiliary_scores: aux, + }; + } + } + + const model = new AuxDecoderOutputModel(BASE_CONFIG, BASE_SESSIONS); + const inputs = { + input_features: new Tensor("float32", new Float32Array([0, 0]), [1, 
1, 2]), + }; + + const output = await model.transcribe(inputs, { returnTimestamps: false }); + expect(output).toEqual(expect.objectContaining({ text: "" })); + expect(model.auxDisposals).toBe(1); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Nemo Conformer TDT utilities", () => { + it("uses conservative sentence boundaries for punctuation, abbreviations, and long silences", () => { + expect(shouldEndSentenceAfterWord({ text: "Hello." }, { text: "World" }, 0)).toBe(true); + expect(shouldEndSentenceAfterWord({ text: "U.S." }, { text: "Report" }, 0)).toBe(false); + expect(shouldEndSentenceAfterWord({ text: "3." }, { text: "Title" }, 0)).toBe(false); + expect(shouldEndSentenceAfterWord({ text: "I." }, { text: "Overview" }, 0)).toBe(false); + expect(shouldEndSentenceAfterWord({ text: "Dr." }, { text: "Brown" }, 0)).toBe(false); + expect(shouldEndSentenceAfterWord({ text: "wait" }, { text: "Next" }, 3.2)).toBe(true); + }); + + it("partitions timed words into conservative sentence-like chunks", () => { + const words = [ + { text: "Hello.", startTime: 0, endTime: 0.4 }, + { text: "World", startTime: 0.5, endTime: 0.8 }, + { text: "again.", startTime: 0.8, endTime: 1.1 }, + { text: "U.S.", startTime: 1.2, endTime: 1.5 }, + { text: "Report", startTime: 1.6, endTime: 2.0 }, + { text: "update.", startTime: 2.0, endTime: 2.4 }, + { text: "pause", startTime: 6.0, endTime: 6.3 }, + { text: "Next", startTime: 9.5, endTime: 9.8 }, + { text: "sentence.", startTime: 9.8, endTime: 10.2 }, + ]; + + const segments = partitionNemoWordsIntoSegments(words); + expect(segments.map((x) => x.text)).toEqual(["Hello.", "World again.", "U.S. Report update.", "pause", "Next sentence."]); + expect(segments.map((x) => x.timestamp)).toEqual([ + [0, 0.4], + [0.5, 1.1], + [1.2, 2.4], + [6.0, 6.3], + [9.5, 10.2], + ]); + expect(buildNemoSegmentChunks(words)).toEqual([ + { text: "Hello.", timestamp: [0, 0.4] }, + { text: "World again.", timestamp: [0.5, 1.1] }, + { text: "U.S. 
Report update.", timestamp: [1.2, 2.4] }, + { text: "pause", timestamp: [6.0, 6.3] }, + { text: "Next sentence.", timestamp: [9.5, 10.2] }, + ]); + }); + + it("keeps word boundaries from the final decoded text for numeric and punctuation tokens", () => { + const rawById = { + 1: "▁score", + 2: ".", + 3: "48", + 4: "-", + 5: "year", + 6: "-", + 7: "old", + 8: "▁with", + 9: "0", + 10: ".", + 11: "5", + }; + const tokenizer = { + get_vocab() { + return rawById; + }, + decode(ids) { + if (ids.length === 1) { + return rawById[ids[0]].replace(/^▁/, ""); + } + return "score. 48-year-old with 0.5"; + }, + }; + + const output = buildTransducerDetailedOutputs( + tokenizer, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + [ + [0.0, 0.3], + [0.3, 0.4], + [0.5, 0.8], + [0.8, 0.85], + [0.85, 1.05], + [1.05, 1.1], + [1.1, 1.3], + [1.4, 1.7], + [1.8, 1.9], + [1.9, 1.95], + [1.95, 2.05], + ], + ); + + expect(output.words.map((x) => x.text)).toEqual(["score.", "48-year-old", "with", "0.5"]); + expect(output.tokens.map((x) => x.token)).toEqual(["score", ".", "48", "-", "year", "-", "old", "with", "0", ".", "5"]); + }); + + it("does not collapse distinct overlapping punctuation-only tokens during merge dedupe", () => { + expect( + dedupeMergedWords([ + { text: ".", startTime: 1.0, endTime: 1.3 }, + { text: "?", startTime: 1.2, endTime: 1.5 }, + { text: "?", startTime: 1.2, endTime: 1.6 }, + ]), + ).toEqual([ + { text: ".", startTime: 1.0, endTime: 1.3 }, + { text: "?", startTime: 1.2, endTime: 1.6 }, + ]); + }); + + it("builds word offsets from array-backed tokenizer vocabularies", () => { + const vocab = ["", "▁hello", "▁world"]; + const tokenizer = { + get_vocab() { + return vocab; + }, + decode(ids) { + const pieces = ids.map((id) => vocab[id] ?? 
"").join(""); + return pieces.replace(/▁/g, "").trim(); + }, + }; + + const output = buildTransducerWordOffsets( + tokenizer, + [1, 2], + [ + [0.0, 0.3], + [0.3, 0.6], + ], + null, + "hello world", + ); + + expect(output.words.map((x) => x.text)).toEqual(["hello", "world"]); + expect(output.tokens.map((x) => x.rawToken)).toEqual(["▁hello", "▁world"]); + expect(output.tokens.map((x) => x.isWordStart)).toEqual([true, true]); + }); + + it("falls back to decoded token text when tokenizer vocab metadata is unavailable", () => { + const token_ids = [1, 2]; + const timestamps = [ + [0.0, 0.3], + [0.3, 0.6], + ]; + + const fromNull = buildTransducerWordOffsets( + { + get_vocab: () => null, + decode(ids) { + return ids[0] === 1 ? " hello" : "world"; + }, + }, + token_ids, + timestamps, + null, + "hello world", + ); + const fromPrimitive = buildTransducerWordOffsets( + { + get_vocab: () => 42, + decode(ids) { + return ids[0] === 1 ? " hello" : "world"; + }, + }, + token_ids, + timestamps, + null, + "hello world", + ); + + expect(fromNull.words.map((x) => x.text)).toEqual(["hello", "world"]); + expect(fromPrimitive.words.map((x) => x.text)).toEqual(["hello", "world"]); + }); + + it("rejects mismatched empty timestamp inputs for word offsets", () => { + expect(() => + buildTransducerWordOffsets( + { + decode: () => "hello", + }, + [1], + [], + ), + ).toThrow("equal lengths"); + }); + + it("requires a tokenizer for non-empty word offsets", () => { + expect(() => buildTransducerWordOffsets(null, [1], [[0.0, 0.3]], null, "hello")).toThrow("requires a tokenizer"); + }); + + it( + "computes delta and delta-delta features", + async () => { + const input = new Tensor( + "float32", + Float32Array.from([ + // T=4, F=2 + 1, 2, 2, 4, 3, 6, 4, 8, + ]), + [1, 4, 2], + ); + + const split = computeTemporalDeltas(input, { order: 2, window: 1, concatenate: false }); + expect(split.delta.dims).toEqual([1, 4, 2]); + expect(split.delta_delta.dims).toEqual([1, 4, 2]); + + const concatOrder1 = 
computeTemporalDeltas(input, { order: 1, window: 1, concatenate: true }); + expect(concatOrder1.dims).toEqual([1, 4, 4]); + expect(Array.from(concatOrder1.data.slice(0, 8))).toEqual([ + 1, + 2, + 0.5, + 1, // t0: base + delta + 2, + 4, + 1, + 2, // t1: base + delta + ]); + + const concat = computeTemporalDeltas(input, { order: 2, window: 1, concatenate: true }); + expect(concat.dims).toEqual([1, 4, 6]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it("rejects non-float32 tensors for temporal deltas", () => { + const input = new Tensor("float64", Float64Array.from([1, 2, 2, 4]), [1, 2, 2]); + expect(() => computeTemporalDeltas(input, { order: 1, window: 1, concatenate: true })).toThrow('type "float32"'); + }); + + it("disposes intermediate delta tensors in concatenate paths", () => { + const input = new Tensor("float32", Float32Array.from([1, 2, 2, 4, 3, 6, 4, 8]), [1, 4, 2]); + const originalDispose = Tensor.prototype.dispose; + let disposeCalls = 0; + Tensor.prototype.dispose = function () { + disposeCalls += 1; + return originalDispose.call(this); + }; + + try { + const order1 = computeTemporalDeltas(input, { order: 1, window: 1, concatenate: true }); + const order2 = computeTemporalDeltas(input, { order: 2, window: 1, concatenate: true }); + expect(order1.dims).toEqual([1, 4, 4]); + expect(order2.dims).toEqual([1, 4, 6]); + } finally { + Tensor.prototype.dispose = originalDispose; + } + + // order=1 concat disposes one intermediate tensor, order=2 concat disposes two. 
+ expect(disposeCalls).toBe(3); + }); + + it( + "creates stable audio cache keys", + async () => { + const a = Float32Array.from([0, 0.1, 0.2, 0.3]); + const b = Float32Array.from([0, 0.1, 0.2, 0.4]); + const ka1 = createAudioCacheKey(a, 16000); + const ka2 = createAudioCacheKey(a, 16000); + const kb = createAudioCacheKey(b, 16000); + + expect(ka1).toEqual(ka2); + expect(ka1).not.toEqual(kb); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it("uses Nemo encoder selector key when resolving model files", async () => { + const files = await get_model_files("dummy/nemo", { + local_files_only: true, + config: { + architectures: ["UnknownArch"], + model_type: "nemo-conformer-tdt", + "transformers.js_config": {}, + }, + dtype: { + model: "int8", + encoder_model: "fp16", + decoder_model_merged: "q4", + }, + }); + expect(files).toEqual(["config.json", "onnx/encoder_model_fp16.onnx", "onnx/decoder_model_merged_q4.onnx"]); + }); + + it( + "distinguishes long waveforms that differ at unsampled indices", + async () => { + const a = new Float32Array(10000); + const b = new Float32Array(10000); + b[1] = 0.12345; // Index 1 was skipped by the prior stride-based hash for this length. + + const ka = createAudioCacheKey(a, 16000); + const kb = createAudioCacheKey(b, 16000); + expect(ka).not.toEqual(kb); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "evicts least-recently-used entries when full", + async () => { + const cache = new FeatureLRUCache({ max_entries: 2, max_size_mb: 4 }); + expect(cache.set("a", new Tensor("float32", new Float32Array([1, 2, 3]), [1, 3]))).toBe(true); + expect(cache.set("b", new Tensor("float32", new Float32Array([4, 5, 6]), [1, 3]))).toBe(true); + expect(cache.get("a")).not.toBeNull(); + + expect(cache.set("c", new Tensor("float32", new Float32Array([7, 8, 9]), [1, 3]))).toBe(true); + // `b` should be evicted because `a` was recently accessed. 
+ expect(cache.get("b")).toBeNull(); + expect(cache.get("a")).not.toBeNull(); + expect(cache.get("c")).not.toBeNull(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it("disposes replaced cache entries", () => { + const cache = new FeatureLRUCache({ max_entries: 4, max_size_mb: 4 }); + const originalDispose = Tensor.prototype.dispose; + let disposeCalls = 0; + Tensor.prototype.dispose = function () { + disposeCalls += 1; + return originalDispose.call(this); + }; + + try { + cache.set("x", new Tensor("float32", new Float32Array([1, 2, 3]), [1, 3])); + cache.set("x", new Tensor("float32", new Float32Array([4, 5, 6]), [1, 3])); + expect(disposeCalls).toBe(1); + } finally { + Tensor.prototype.dispose = originalDispose; + cache.clear(); + } + }); + + it("does not dispose when re-setting the same value object for an existing key", () => { + const cache = new FeatureLRUCache({ max_entries: 4, max_size_mb: 4 }); + const tensor = new Tensor("float32", new Float32Array([1, 2, 3]), [1, 3]); + let disposeCalls = 0; + const originalDispose = tensor.dispose.bind(tensor); + tensor.dispose = () => { + disposeCalls += 1; + originalDispose(); + }; + + cache.set("x", tensor); + cache.set("x", tensor); + expect(cache.get("x")).toBe(tensor); + expect(disposeCalls).toBe(0); + + cache.clear(); + expect(disposeCalls).toBe(1); + }); + + it("disposes tensors on eviction and clear without double-disposing shared refs", () => { + const cache = new FeatureLRUCache({ max_entries: 1, max_size_mb: 4 }); + const originalDispose = Tensor.prototype.dispose; + let disposeCalls = 0; + Tensor.prototype.dispose = function () { + disposeCalls += 1; + return originalDispose.call(this); + }; + + try { + const sharedA = new Tensor("float32", new Float32Array([1, 2, 3]), [1, 3]); + cache.set("a", { + input_features: sharedA, + attention_mask: sharedA, + }); + const sharedB = new Tensor("float32", new Float32Array([4, 5, 6]), [1, 3]); + cache.set("b", { + input_features: sharedB, + attention_mask: sharedB, + }); + 
// Eviction of "a" should dispose sharedA once, despite duplicate field references. + expect(disposeCalls).toBe(1); + + cache.clear(); + // Clear should dispose sharedB once. + expect(disposeCalls).toBe(2); + } finally { + Tensor.prototype.dispose = originalDispose; + } + }); + + it("defers disposal for borrowed cache entries until they are released", () => { + const cache = new FeatureLRUCache({ max_entries: 1, max_size_mb: 4 }); + const tensorA = new Tensor("float32", new Float32Array([1, 2, 3]), [1, 3]); + const tensorB = new Tensor("float32", new Float32Array([4, 5, 6]), [1, 3]); + let disposeCalls = 0; + const track = (tensor) => { + const originalDispose = tensor.dispose.bind(tensor); + tensor.dispose = () => { + disposeCalls += 1; + originalDispose(); + }; + }; + track(tensorA); + track(tensorB); + + cache.set("a", tensorA); + const borrowedA = cache.acquire("a"); + expect(borrowedA?.value).toBe(tensorA); + + cache.set("b", tensorB); + expect(disposeCalls).toBe(0); + borrowedA?.release(); + expect(disposeCalls).toBe(1); + + const borrowedB = cache.acquire("b"); + expect(borrowedB?.value).toBe(tensorB); + cache.clear(); + expect(disposeCalls).toBe(1); + borrowedB?.release(); + expect(disposeCalls).toBe(2); + }); + + it("keeps borrowed entry bytes counted until release", () => { + const cache = new FeatureLRUCache({ max_entries: 4, max_size_mb: 0.00002 }); + const tensorA = new Tensor("float32", new Float32Array([1, 2, 3]), [1, 3]); + const tensorB = new Tensor("float32", new Float32Array([4, 5, 6]), [1, 3]); + + let tensorADisposals = 0; + const disposeA = tensorA.dispose.bind(tensorA); + tensorA.dispose = () => { + tensorADisposals += 1; + disposeA(); + }; + + let tensorBDisposals = 0; + const disposeB = tensorB.dispose.bind(tensorB); + tensorB.dispose = () => { + tensorBDisposals += 1; + disposeB(); + }; + + expect(cache.set("a", tensorA)).toBe(true); + const borrowedA = cache.acquire("a"); + expect(borrowedA?.value).toBe(tensorA); + + expect(cache.set("b", 
tensorB)).toBe(false); + expect(cache.get("a")).toBeNull(); + expect(cache.get("b")).toBeNull(); + expect(cache.stats().entries).toBe(0); + expect(cache.stats().size_mb).toBeGreaterThan(0); + expect(tensorADisposals).toBe(0); + expect(tensorBDisposals).toBe(1); + + borrowedA?.release(); + expect(cache.stats().size_mb).toBe(0); + expect(tensorADisposals).toBe(1); + }); + + it("treats zero cache limits as explicit no-cache mode without disposing inserted values", () => { + const byEntries = new FeatureLRUCache({ max_entries: 0, max_size_mb: 4 }); + const bySize = new FeatureLRUCache({ max_entries: 4, max_size_mb: 0 }); + const t1 = new Tensor("float32", new Float32Array([1, 2, 3]), [1, 3]); + const t2 = new Tensor("float32", new Float32Array([4, 5, 6]), [1, 3]); + + let t1Disposals = 0; + const t1Dispose = t1.dispose.bind(t1); + t1.dispose = () => { + t1Disposals += 1; + t1Dispose(); + }; + let t2Disposals = 0; + const t2Dispose = t2.dispose.bind(t2); + t2.dispose = () => { + t2Disposals += 1; + t2Dispose(); + }; + + expect(byEntries.set("x", t1)).toBe(false); + expect(bySize.set("y", t2)).toBe(false); + expect(byEntries.get("x")).toBeNull(); + expect(bySize.get("y")).toBeNull(); + expect(t1Disposals).toBe(0); + expect(t2Disposals).toBe(0); + + t1.dispose(); + t2.dispose(); + expect(t1Disposals).toBe(1); + expect(t2Disposals).toBe(1); + }); + + it("skips caching oversized values without disposing caller-owned tensors", () => { + const cache = new FeatureLRUCache({ max_entries: 4, max_size_mb: 0.000001 }); + const tensor = new Tensor("float32", new Float32Array([1, 2]), [1, 2]); + let disposeCalls = 0; + const originalDispose = tensor.dispose.bind(tensor); + tensor.dispose = () => { + disposeCalls += 1; + originalDispose(); + }; + + expect(cache.set("big", tensor)).toBe(false); + expect(cache.get("big")).toBeNull(); + expect(disposeCalls).toBe(0); + + tensor.dispose(); + expect(disposeCalls).toBe(1); + }); + + it("ignores non-numeric byteLength values in size 
estimation", () => { + const cache = new FeatureLRUCache({ max_entries: 4, max_size_mb: 4 }); + cache.set("x", { byteLength: "invalid" }); + expect(cache.stats().entries).toBe(1); + expect(cache.stats().size_mb).toBe(0); + cache.clear(); + }); + + it("rejects invalid cache limits", () => { + expect(() => new FeatureLRUCache({ max_entries: -1 })).toThrow("max_entries"); + expect(() => new FeatureLRUCache({ max_entries: 1.25 })).toThrow("max_entries"); + expect(() => new FeatureLRUCache({ max_size_mb: -1 })).toThrow("max_size_mb"); + expect(() => new FeatureLRUCache({ max_size_mb: Number.POSITIVE_INFINITY })).toThrow("max_size_mb"); + }); + }); +}; diff --git a/packages/transformers/tests/pipelines/test_pipelines_automatic_speech_recognition.js b/packages/transformers/tests/pipelines/test_pipelines_automatic_speech_recognition.js index 963306f7f..cbaa96636 100644 --- a/packages/transformers/tests/pipelines/test_pipelines_automatic_speech_recognition.js +++ b/packages/transformers/tests/pipelines/test_pipelines_automatic_speech_recognition.js @@ -96,7 +96,6 @@ export default () => { const model_id = "Xenova/tiny-random-Wav2Vec2ForCTC-ONNX"; const SAMPLING_RATE = 16000; const audios = [new Float32Array(SAMPLING_RATE).fill(0), Float32Array.from({ length: SAMPLING_RATE }, (_, i) => i / 16000)]; - const long_audios = [new Float32Array(SAMPLING_RATE * 60).fill(0), Float32Array.from({ length: SAMPLING_RATE * 60 }, (_, i) => (i % 1000) / 1000)]; const max_new_tokens = 5; /** @type {AutomaticSpeechRecognitionPipeline} */ @@ -125,5 +124,82 @@ export default () => { await pipe?.dispose(); }, MAX_MODEL_DISPOSE_TIME); }); + + describe("nemo-conformer-tdt", () => { + const makeUnitPipe = () => { + const calls = []; + const model = { + config: { model_type: "nemo-conformer-tdt" }, + async transcribe(_inputs, options) { + calls.push(options); + const result = { text: "hello world" }; + if (options.returnTimestamps) { + result.utteranceTimestamp = [0, 0.08]; + result.words = [ + { 
text: "hello", startTime: 0, endTime: 0.04 }, + { text: "world", startTime: 0.04, endTime: 0.08 }, + ]; + } + return result; + }, + async dispose() {}, + }; + const processor = Object.assign(async () => ({ input_features: {} }), { + feature_extractor: { config: { sampling_rate: 16000 } }, + }); + + return { + pipe: new AutomaticSpeechRecognitionPipeline({ + task: PIPELINE_ID, + model, + tokenizer: {}, + processor, + }), + calls, + }; + }; + + it("returns text when timestamps are disabled", async () => { + const { pipe, calls } = makeUnitPipe(); + await expect(pipe(new Float32Array(16000), { return_timestamps: false })).resolves.toEqual({ + text: "hello world", + }); + expect(calls).toHaveLength(1); + expect(calls[0]).toMatchObject({ + returnTimestamps: false, + returnWords: false, + returnMetrics: false, + }); + }); + + it("returns sentence chunks when return_timestamps is true", async () => { + const { pipe, calls } = makeUnitPipe(); + await expect(pipe(new Float32Array(16000), { return_timestamps: true })).resolves.toEqual({ + text: "hello world", + chunks: [{ text: "hello world", timestamp: [0, 0.08] }], + }); + expect(calls[0]).toMatchObject({ + returnTimestamps: true, + returnWords: true, + returnMetrics: false, + }); + }); + + it("returns word chunks when return_timestamps is 'word'", async () => { + const { pipe, calls } = makeUnitPipe(); + await expect(pipe(new Float32Array(16000), { return_timestamps: "word" })).resolves.toEqual({ + text: "hello world", + chunks: [ + { text: "hello", timestamp: [0, 0.04] }, + { text: "world", timestamp: [0.04, 0.08] }, + ], + }); + expect(calls[0]).toMatchObject({ + returnTimestamps: true, + returnWords: true, + returnMetrics: false, + }); + }); + }); }); }; diff --git a/packages/transformers/tests/pipelines/test_pipelines_nemo_conformer_tdt.js b/packages/transformers/tests/pipelines/test_pipelines_nemo_conformer_tdt.js new file mode 100644 index 000000000..2ef8f3986 --- /dev/null +++ 
b/packages/transformers/tests/pipelines/test_pipelines_nemo_conformer_tdt.js @@ -0,0 +1,793 @@ +import { Tensor } from "../../src/transformers.js"; +import { NEMO_FEATURE_OUTPUT_OWNERSHIP, NEMO_FEATURE_OUTPUT_RELEASE } from "../../src/models/nemo_conformer_tdt/feature_extraction_nemo_conformer_tdt.js"; +import { runNemoConformerTDTPipeline } from "../../src/models/nemo_conformer_tdt/pipeline_nemo_conformer_tdt.js"; + +const SAMPLING_RATE = 16000; + +const makeProcessor = (impl = async () => ({ input_features: {} })) => + Object.assign(impl, { + feature_extractor: { config: { sampling_rate: SAMPLING_RATE } }, + }); + +const makeTokenizer = () => ({ + decode(ids) { + const pieces = { + 1: "hello", + 2: "world", + 3: "again", + 4: "today", + }; + return ids + .map((id) => pieces[id] ?? "") + .filter(Boolean) + .join(" "); + }, +}); + +const prepareAudios = async (audios) => audios; + +const runPipeline = ({ model, audio = new Float32Array(SAMPLING_RATE), kwargs = {}, tokenizer = makeTokenizer(), processor = makeProcessor() }) => + runNemoConformerTDTPipeline({ + model, + processor, + tokenizer, + audio, + kwargs, + prepareAudios, + }); + +const withNemoTensorOwnership = (value, cacheOwnsTensors, release = null) => { + Object.defineProperty(value, NEMO_FEATURE_OUTPUT_OWNERSHIP, { + value: cacheOwnsTensors, + enumerable: false, + configurable: true, + }); + if (release) { + Object.defineProperty(value, NEMO_FEATURE_OUTPUT_RELEASE, { + value: release, + enumerable: false, + configurable: true, + }); + } + return value; +}; + +export default () => { + describe("Nemo Conformer TDT pipeline adapter", () => { + it("builds conservative sentence chunks from Nemo word timestamps", async () => { + const model = { + async transcribe() { + return { + text: "Hello. World again. U.S. 
Report update.", + utteranceTimestamp: [0, 2.4], + words: [ + { text: "Hello.", startTime: 0, endTime: 0.4 }, + { text: "World", startTime: 0.5, endTime: 0.8 }, + { text: "again.", startTime: 0.8, endTime: 1.1 }, + { text: "U.S.", startTime: 1.2, endTime: 1.5 }, + { text: "Report", startTime: 1.6, endTime: 2.0 }, + { text: "update.", startTime: 2.0, endTime: 2.4 }, + ], + }; + }, + }; + + await expect(runPipeline({ model, kwargs: { return_timestamps: true } })).resolves.toEqual({ + text: "Hello. World again. U.S. Report update.", + chunks: [ + { text: "Hello.", timestamp: [0, 0.4] }, + { text: "World again.", timestamp: [0.5, 1.1] }, + { text: "U.S. Report update.", timestamp: [1.2, 2.4] }, + ], + }); + }); + + it("uses explicit chunk_length_s as a bounded sentence window size override", async () => { + const calls = []; + const outputsByOffset = new Map([ + [ + 0, + { + text: "Alpha. Beta. Carry", + words: [ + { text: "Alpha.", startTime: 0, endTime: 1 }, + { text: "Beta.", startTime: 17, endTime: 18 }, + { text: "Carry", startTime: 19.95, endTime: 20 }, + ], + }, + ], + [ + 19.95, + { + text: "Carry on. Gamma", + words: [ + { text: "Carry", startTime: 19.95, endTime: 20 }, + { text: "on.", startTime: 20, endTime: 20.5 }, + { text: "Gamma", startTime: 37.9, endTime: 38 }, + ], + }, + ], + [ + 37.9, + { + text: "Gamma. Tail resumes. 
Omega.", + words: [ + { text: "Gamma.", startTime: 37.9, endTime: 39 }, + { text: "Tail", startTime: 39.2, endTime: 39.6 }, + { text: "resumes.", startTime: 39.6, endTime: 40.1 }, + { text: "Omega.", startTime: 40.1, endTime: 40.45 }, + ], + }, + ], + ]); + const model = { + async transcribe(_inputs, options) { + calls.push(options); + const item = outputsByOffset.get(options.timeOffset); + if (!item) { + throw new Error(`Unexpected timeOffset ${options.timeOffset}`); + } + return { + text: item.text, + utteranceTimestamp: [item.words[0].startTime, item.words[item.words.length - 1].endTime], + words: item.words, + }; + }, + }; + + await expect( + runPipeline({ + model, + audio: new Float32Array(40.5 * SAMPLING_RATE), + kwargs: { return_timestamps: "word", chunk_length_s: 2 }, + }), + ).resolves.toEqual({ + text: "Alpha. Beta. Carry on. Gamma. Tail resumes. Omega.", + chunks: [ + { text: "Alpha.", timestamp: [0, 1] }, + { text: "Beta.", timestamp: [17, 18] }, + { text: "Carry", timestamp: [19.95, 20] }, + { text: "on.", timestamp: [20, 20.5] }, + { text: "Gamma.", timestamp: [37.9, 39] }, + { text: "Tail", timestamp: [39.2, 39.6] }, + { text: "resumes.", timestamp: [39.6, 40.1] }, + { text: "Omega.", timestamp: [40.1, 40.45] }, + ], + }); + expect(calls.map((x) => x.timeOffset)).toEqual([0, 19.95, 37.9]); + expect(calls[0]).toMatchObject({ returnTimestamps: true, returnWords: true, returnMetrics: false, timeOffset: 0 }); + expect(calls[1]).toMatchObject({ returnTimestamps: true, returnWords: true, returnMetrics: false, timeOffset: 19.95 }); + expect(calls[2]).toMatchObject({ returnTimestamps: true, returnWords: true, returnMetrics: false, timeOffset: 37.9 }); + }); + + it("replaces boundary-truncated sentences with the longer retranscribed sentence", async () => { + const calls = []; + const outputsByOffset = new Map([ + [ + 0, + { + text: "Alpha. Beta. 
It won't run away, and it won't come to life.", + words: [ + { text: "Alpha.", startTime: 0, endTime: 1 }, + { text: "Beta.", startTime: 11, endTime: 12 }, + { text: "It", startTime: 17.2, endTime: 17.5 }, + { text: "won't", startTime: 17.5, endTime: 17.9 }, + { text: "run", startTime: 17.9, endTime: 18.2 }, + { text: "away,", startTime: 18.2, endTime: 18.6 }, + { text: "and", startTime: 18.6, endTime: 18.8 }, + { text: "it", startTime: 18.8, endTime: 19.0 }, + { text: "won't", startTime: 19.0, endTime: 19.3 }, + { text: "come", startTime: 19.3, endTime: 19.5 }, + { text: "to", startTime: 19.5, endTime: 19.65 }, + { text: "life.", startTime: 19.65, endTime: 19.8 }, + ], + }, + ], + [ + 17.2, + { + text: "It won't run away, and it won't come to life until someone finds it. Omega.", + words: [ + { text: "It", startTime: 17.2, endTime: 17.5 }, + { text: "won't", startTime: 17.5, endTime: 17.9 }, + { text: "run", startTime: 17.9, endTime: 18.2 }, + { text: "away,", startTime: 18.2, endTime: 18.6 }, + { text: "and", startTime: 18.6, endTime: 18.8 }, + { text: "it", startTime: 18.8, endTime: 19.0 }, + { text: "won't", startTime: 19.0, endTime: 19.3 }, + { text: "come", startTime: 19.3, endTime: 19.5 }, + { text: "to", startTime: 19.5, endTime: 19.65 }, + { text: "life", startTime: 19.65, endTime: 19.95 }, + { text: "until", startTime: 19.95, endTime: 20.4 }, + { text: "someone", startTime: 20.4, endTime: 21.0 }, + { text: "finds", startTime: 21.0, endTime: 21.5 }, + { text: "it.", startTime: 21.5, endTime: 22.0 }, + { text: "Omega.", startTime: 28, endTime: 29 }, + ], + }, + ], + ]); + const model = { + async transcribe(_inputs, options) { + calls.push(options); + const item = outputsByOffset.get(options.timeOffset); + if (!item) { + throw new Error(`Unexpected timeOffset ${options.timeOffset}`); + } + return { + text: item.text, + utteranceTimestamp: [item.words[0].startTime, item.words[item.words.length - 1].endTime], + words: item.words, + }; + }, + }; + + await 
expect( + runPipeline({ + model, + audio: new Float32Array(Math.ceil(31 * SAMPLING_RATE)), + kwargs: { return_timestamps: true, chunk_length_s: 20 }, + }), + ).resolves.toEqual({ + text: "Alpha. Beta. It won't run away, and it won't come to life until someone finds it. Omega.", + chunks: [ + { text: "Alpha.", timestamp: [0, 1] }, + { text: "Beta.", timestamp: [11, 12] }, + { text: "It won't run away, and it won't come to life until someone finds it.", timestamp: [17.2, 22] }, + { text: "Omega.", timestamp: [28, 29] }, + ], + }); + expect(calls.map((x) => x.timeOffset)).toEqual([0, 17.2]); + }); + + it("retranscribes the dropped last sentence from its start without stale carry", async () => { + const calls = []; + const outputsByOffset = new Map([ + [ + 0, + { + text: "Alpha. The pressure gauge mark. He watched as the fruit", + words: [ + { text: "Alpha.", startTime: 0, endTime: 1 }, + { text: "The", startTime: 16.8, endTime: 17.0 }, + { text: "pressure", startTime: 17.0, endTime: 17.4 }, + { text: "gauge", startTime: 17.4, endTime: 17.76 }, + { text: "mark.", startTime: 17.76, endTime: 18.56 }, + { text: "He", startTime: 18.56, endTime: 18.72 }, + { text: "watched", startTime: 18.72, endTime: 18.96 }, + { text: "as", startTime: 18.96, endTime: 19.04 }, + { text: "the", startTime: 19.04, endTime: 19.2 }, + { text: "fruit", startTime: 19.2, endTime: 19.36 }, + ], + }, + ], + [ + 18.56, + { + text: "He watched as the fluid.", + words: [ + { text: "He", startTime: 18.56, endTime: 18.72 }, + { text: "watched", startTime: 18.72, endTime: 19.12 }, + { text: "as", startTime: 19.12, endTime: 19.28 }, + { text: "the", startTime: 19.28, endTime: 19.36 }, + { text: "fluid.", startTime: 19.36, endTime: 20 }, + ], + }, + ], + ]); + const model = { + async transcribe(_inputs, options) { + calls.push(options); + const item = outputsByOffset.get(options.timeOffset); + if (!item) { + throw new Error(`Unexpected timeOffset ${options.timeOffset}`); + } + return { + text: item.text, + 
utteranceTimestamp: [item.words[0].startTime, item.words[item.words.length - 1].endTime], + words: item.words, + }; + }, + }; + + await expect( + runPipeline({ + model, + audio: new Float32Array(Math.ceil(21 * SAMPLING_RATE)), + kwargs: { return_timestamps: "word", chunk_length_s: 20 }, + }), + ).resolves.toEqual({ + text: "Alpha. The pressure gauge mark. He watched as the fluid.", + chunks: [ + { text: "Alpha.", timestamp: [0, 1] }, + { text: "The", timestamp: [16.8, 17] }, + { text: "pressure", timestamp: [17, 17.4] }, + { text: "gauge", timestamp: [17.4, 17.76] }, + { text: "mark.", timestamp: [17.76, 18.56] }, + { text: "He", timestamp: [18.56, 18.72] }, + { text: "watched", timestamp: [18.72, 19.12] }, + { text: "as", timestamp: [19.12, 19.28] }, + { text: "the", timestamp: [19.28, 19.36] }, + { text: "fluid.", timestamp: [19.36, 20] }, + ], + }); + expect(calls.map((x) => x.timeOffset)).toEqual([0, 18.56]); + }); + + it("preserves the pending prefix when cursor snapping restarts inside the last sentence", async () => { + const calls = []; + const outputsByOffset = new Map([ + [ + 0, + { + text: "Alpha. Carry on", + words: [ + { text: "Alpha.", startTime: 0, endTime: 19.65 }, + { text: "Carry", startTime: 19.7, endTime: 20.0 }, + { text: "on.", startTime: 20.5, endTime: 20.8 }, + ], + }, + ], + [ + 20, + { + text: "on. 
Gamma.", + words: [ + { text: "on.", startTime: 20.5, endTime: 20.8 }, + { text: "Gamma.", startTime: 28, endTime: 29 }, + ], + }, + ], + ]); + const model = { + async transcribe(_inputs, options) { + calls.push(options); + const item = outputsByOffset.get(options.timeOffset); + if (!item) { + throw new Error(`Unexpected timeOffset ${options.timeOffset}`); + } + return { + text: item.text, + utteranceTimestamp: [item.words[0].startTime, item.words[item.words.length - 1].endTime], + words: item.words, + }; + }, + }; + + await expect( + runPipeline({ + model, + audio: new Float32Array(Math.ceil(31 * SAMPLING_RATE)), + kwargs: { return_timestamps: true, chunk_length_s: 20 }, + }), + ).resolves.toEqual({ + text: "Alpha. Carry on. Gamma.", + chunks: [ + { text: "Alpha.", timestamp: [0, 19.65] }, + { text: "Carry on.", timestamp: [19.7, 20.8] }, + { text: "Gamma.", timestamp: [28, 29] }, + ], + }); + expect(calls.map((x) => x.timeOffset)).toEqual([0, 20]); + }); + + it("reconstructs windowed Nemo text from merged words when token decode drops spaces", async () => { + const calls = []; + const model = { + async transcribe(_inputs, options) { + calls.push(options); + if (options.timeOffset === 0) { + return { + text: "score. 48-year-old", + words: [ + { text: "score.", startTime: 0, endTime: 0.4 }, + { text: "48-year-old", startTime: 0.5, endTime: 1.3 }, + ], + }; + } + return { + text: "with 0.5", + words: [ + { text: "with", startTime: 1.4, endTime: 1.7 }, + { text: "0.5", startTime: 1.8, endTime: 2.05 }, + ], + }; + }, + }; + const tokenizer = { + decode(ids) { + const pieces = { + 1: "score", + 2: ".", + 3: "48", + 4: "-", + 5: "year", + 6: "old", + 7: "with", + 8: "0", + 9: "5", + }; + return ids.map((id) => pieces[id] ?? "").join(""); + }, + }; + + const output = await runPipeline({ + model, + tokenizer, + audio: new Float32Array(Math.ceil(20.1 * SAMPLING_RATE)), + kwargs: { return_timestamps: "word", chunk_length_s: 20 }, + }); + + expect(output.text).toBe("score. 
48-year-old with 0.5"); + expect(output.chunks).toEqual([ + { text: "score.", timestamp: [0, 0.4] }, + { text: "48-year-old", timestamp: [0.5, 1.3] }, + { text: "with", timestamp: [1.4, 1.7] }, + { text: "0.5", timestamp: [1.8, 2.05] }, + ]); + expect(calls.map((x) => x.timeOffset)).toEqual([0, 10]); + }); + + it("auto-windows long Nemo audio with 90s sentence windows", async () => { + const calls = []; + const outputsByOffset = new Map([ + [ + 0, + { + text: "Alpha. Beta. Gamma. Carry", + words: [ + { text: "Alpha.", startTime: 0, endTime: 1 }, + { text: "Beta.", startTime: 30, endTime: 31 }, + { text: "Gamma.", startTime: 69, endTime: 70 }, + { text: "Carry", startTime: 84, endTime: 85 }, + ], + }, + ], + [ + 84, + { + text: "Carry on. Delta. Epsilon. Tail", + words: [ + { text: "Carry", startTime: 84, endTime: 85 }, + { text: "on.", startTime: 86, endTime: 87 }, + { text: "Delta.", startTime: 110, endTime: 111 }, + { text: "Epsilon.", startTime: 139, endTime: 140 }, + { text: "Tail", startTime: 154, endTime: 155 }, + ], + }, + ], + [ + 154, + { + text: "Tail resumes. Zeta. Eta. Final", + words: [ + { text: "Tail", startTime: 154, endTime: 155 }, + { text: "resumes.", startTime: 156, endTime: 157 }, + { text: "Zeta.", startTime: 180, endTime: 181 }, + { text: "Eta.", startTime: 209, endTime: 210 }, + { text: "Final", startTime: 224, endTime: 225 }, + ], + }, + ], + [ + 224, + { + text: "Final line. 
Omega.", + words: [ + { text: "Final", startTime: 224, endTime: 225 }, + { text: "line.", startTime: 226, endTime: 227 }, + { text: "Omega.", startTime: 250, endTime: 251 }, + ], + }, + ], + ]); + const model = { + async transcribe(_inputs, options) { + calls.push(options); + const item = outputsByOffset.get(options.timeOffset); + if (!item) { + throw new Error(`Unexpected timeOffset ${options.timeOffset}`); + } + return { + text: item.text, + utteranceTimestamp: [item.words[0].startTime, item.words[item.words.length - 1].endTime], + words: item.words, + }; + }, + }; + + await expect( + runPipeline({ + model, + audio: new Float32Array(300 * SAMPLING_RATE), + kwargs: { return_timestamps: "word" }, + }), + ).resolves.toEqual({ + text: "Alpha. Beta. Gamma. Carry on. Delta. Epsilon. Tail resumes. Zeta. Eta. Final line. Omega.", + chunks: [ + { text: "Alpha.", timestamp: [0, 1] }, + { text: "Beta.", timestamp: [30, 31] }, + { text: "Gamma.", timestamp: [69, 70] }, + { text: "Carry", timestamp: [84, 85] }, + { text: "on.", timestamp: [86, 87] }, + { text: "Delta.", timestamp: [110, 111] }, + { text: "Epsilon.", timestamp: [139, 140] }, + { text: "Tail", timestamp: [154, 155] }, + { text: "resumes.", timestamp: [156, 157] }, + { text: "Zeta.", timestamp: [180, 181] }, + { text: "Eta.", timestamp: [209, 210] }, + { text: "Final", timestamp: [224, 225] }, + { text: "line.", timestamp: [226, 227] }, + { text: "Omega.", timestamp: [250, 251] }, + ], + }); + expect(calls).toHaveLength(4); + expect(calls.map((x) => x.timeOffset)).toEqual([0, 84, 154, 224]); + for (const call of calls) { + expect(call).toMatchObject({ + returnTimestamps: true, + returnWords: true, + returnMetrics: false, + }); + } + }); + + it("does not truncate long audio when sentence cursor advances one second at a time", async () => { + const calls = []; + const expectedChunks = Array.from({ length: 13 }, (_, index) => ({ + text: `Alpha${index}.`, + timestamp: [index, index + 0.2], + })).concat([{ text: 
"Omega.", timestamp: [180, 180.5] }]); + const expectedText = expectedChunks.map((chunk) => chunk.text).join(" "); + + const model = { + async transcribe(_inputs, options) { + calls.push(options); + + if (Number.isInteger(options.timeOffset) && options.timeOffset >= 0 && options.timeOffset < 12) { + const offset = options.timeOffset; + return { + text: `Alpha${offset}. Carry`, + utteranceTimestamp: [offset, offset + 1.2], + words: [ + { text: `Alpha${offset}.`, startTime: offset, endTime: offset + 0.2 }, + { text: "Carry", startTime: offset + 1, endTime: offset + 1.2 }, + ], + }; + } + + if (options.timeOffset === 12) { + return { + text: "Alpha12. Omega.", + utteranceTimestamp: [12, 180.5], + words: [ + { text: "Alpha12.", startTime: 12, endTime: 12.2 }, + { text: "Omega.", startTime: 180, endTime: 180.5 }, + ], + }; + } + + if (options.timeOffset === 180) { + return { + text: "Omega.", + utteranceTimestamp: [180, 180.5], + words: [{ text: "Omega.", startTime: 180, endTime: 180.5 }], + }; + } + + throw new Error(`Unexpected timeOffset ${options.timeOffset}`); + }, + }; + + await expect( + runPipeline({ + model, + audio: new Float32Array(181 * SAMPLING_RATE), + kwargs: { return_timestamps: true }, + }), + ).resolves.toEqual({ + text: expectedText, + chunks: expectedChunks, + }); + + expect(calls.map((x) => x.timeOffset)).toEqual([...Array.from({ length: 13 }, (_, index) => index), 180]); + }); + + it("returns sentence chunks for auto-windowed long Nemo audio", async () => { + const calls = []; + const outputsByOffset = new Map([ + [ + 0, + { + text: "Alpha. Beta. Gamma. Carry", + words: [ + { text: "Alpha.", startTime: 0, endTime: 1 }, + { text: "Beta.", startTime: 30, endTime: 31 }, + { text: "Gamma.", startTime: 69, endTime: 70 }, + { text: "Carry", startTime: 84, endTime: 85 }, + ], + }, + ], + [ + 84, + { + text: "Carry on. Delta. Epsilon. 
Tail", + words: [ + { text: "Carry", startTime: 84, endTime: 85 }, + { text: "on.", startTime: 86, endTime: 87 }, + { text: "Delta.", startTime: 110, endTime: 111 }, + { text: "Epsilon.", startTime: 139, endTime: 140 }, + { text: "Tail", startTime: 154, endTime: 155 }, + ], + }, + ], + [ + 154, + { + text: "Tail resumes. Zeta. Eta. Final", + words: [ + { text: "Tail", startTime: 154, endTime: 155 }, + { text: "resumes.", startTime: 156, endTime: 157 }, + { text: "Zeta.", startTime: 180, endTime: 181 }, + { text: "Eta.", startTime: 209, endTime: 210 }, + { text: "Final", startTime: 224, endTime: 225 }, + ], + }, + ], + [ + 224, + { + text: "Final line. Omega.", + words: [ + { text: "Final", startTime: 224, endTime: 225 }, + { text: "line.", startTime: 226, endTime: 227 }, + { text: "Omega.", startTime: 250, endTime: 251 }, + ], + }, + ], + ]); + const model = { + async transcribe(_inputs, options) { + calls.push(options); + const item = outputsByOffset.get(options.timeOffset); + if (!item) { + throw new Error(`Unexpected timeOffset ${options.timeOffset}`); + } + return { + text: item.text, + utteranceTimestamp: [item.words[0].startTime, item.words[item.words.length - 1].endTime], + words: item.words, + }; + }, + }; + + await expect( + runPipeline({ + model, + audio: new Float32Array(300 * SAMPLING_RATE), + kwargs: { return_timestamps: true }, + }), + ).resolves.toEqual({ + text: "Alpha. Beta. Gamma. Carry on. Delta. Epsilon. Tail resumes. Zeta. Eta. Final line. 
Omega.", + chunks: [ + { text: "Alpha.", timestamp: [0, 1] }, + { text: "Beta.", timestamp: [30, 31] }, + { text: "Gamma.", timestamp: [69, 70] }, + { text: "Carry on.", timestamp: [84, 87] }, + { text: "Delta.", timestamp: [110, 111] }, + { text: "Epsilon.", timestamp: [139, 140] }, + { text: "Tail resumes.", timestamp: [154, 157] }, + { text: "Zeta.", timestamp: [180, 181] }, + { text: "Eta.", timestamp: [209, 210] }, + { text: "Final line.", timestamp: [224, 227] }, + { text: "Omega.", timestamp: [250, 251] }, + ], + }); + expect(calls.map((x) => x.timeOffset)).toEqual([0, 84, 154, 224]); + }); + + it("rejects non-finite audio samples before Nemo decoding", async () => { + const model = { + async transcribe() { + return { text: "hello world" }; + }, + }; + await expect( + runPipeline({ + model, + audio: Float32Array.from([0, Number.NaN, 0]), + kwargs: { return_timestamps: false }, + }), + ).rejects.toThrow("finite audio samples"); + }); + + it("disposes processor tensors after Nemo transcription when feature cache is disabled", async () => { + let disposeCalls = 0; + const model = { + async transcribe() { + return { text: "ok" }; + }, + }; + const processor = makeProcessor(async () => { + const input_features = new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]); + const attention_mask = new Tensor("int64", BigInt64Array.from([1n]), [1, 1]); + const trackDispose = (tensor) => { + const originalDispose = tensor.dispose.bind(tensor); + tensor.dispose = () => { + disposeCalls += 1; + originalDispose(); + }; + }; + trackDispose(input_features); + trackDispose(attention_mask); + return withNemoTensorOwnership({ input_features, attention_mask }, false); + }); + + await expect(runPipeline({ model, processor })).resolves.toEqual({ text: "ok" }); + expect(disposeCalls).toBe(2); + }); + + it("keeps processor tensors alive when Nemo feature cache owns tensor lifetimes", async () => { + let disposeCalls = 0; + let releaseCalls = 0; + let lastInputs = null; + const 
model = { + async transcribe() { + return { text: "ok" }; + }, + }; + const processor = makeProcessor(async () => { + const input_features = new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]); + const attention_mask = new Tensor("int64", BigInt64Array.from([1n]), [1, 1]); + const trackDispose = (tensor) => { + const originalDispose = tensor.dispose.bind(tensor); + tensor.dispose = () => { + disposeCalls += 1; + originalDispose(); + }; + }; + trackDispose(input_features); + trackDispose(attention_mask); + lastInputs = withNemoTensorOwnership({ input_features, attention_mask }, true, () => { + releaseCalls += 1; + }); + return lastInputs; + }); + + try { + await expect(runPipeline({ model, processor })).resolves.toEqual({ text: "ok" }); + expect(disposeCalls).toBe(0); + expect(releaseCalls).toBe(1); + } finally { + lastInputs?.input_features.dispose(); + lastInputs?.attention_mask.dispose(); + } + }); + + it("disposes processor tensors when Nemo feature cache limits disable caching", async () => { + let disposeCalls = 0; + const model = { + async transcribe() { + return { text: "ok" }; + }, + }; + const processor = makeProcessor(async () => { + const input_features = new Tensor("float32", new Float32Array([0, 0]), [1, 1, 2]); + const attention_mask = new Tensor("int64", BigInt64Array.from([1n]), [1, 1]); + const trackDispose = (tensor) => { + const originalDispose = tensor.dispose.bind(tensor); + tensor.dispose = () => { + disposeCalls += 1; + originalDispose(); + }; + }; + trackDispose(input_features); + trackDispose(attention_mask); + return withNemoTensorOwnership({ input_features, attention_mask }, false); + }); + + await expect(runPipeline({ model, processor })).resolves.toEqual({ text: "ok" }); + expect(disposeCalls).toBe(2); + }); + }); +};