From 85f940663f2e23df2629dd0f1e02cc3fe3f23a3d Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 11 Mar 2026 23:24:14 +0000 Subject: [PATCH 01/40] add changes from duplex-realtime-inference branch, except duplex_stt_model.py modification for function_head Signed-off-by: Elena Rastorgueva --- .../conf/s2s_streaming.yaml | 105 + .../s2s_streaming_infer.py | 268 +++ .../triton/client_streaming.py | 247 +++ .../voicechat/1/infer_streaming.py | 345 +++ .../model_repo_s2s/voicechat/config.pbtxt | 90 + .../triton/start_triton.sh | 88 + .../speechlm2/inference/__init__.py | 13 + .../speechlm2/inference/factory/__init__.py | 13 + .../inference/factory/s2s_pipeline_builder.py | 47 + .../inference/model_wrappers/__init__.py | 13 + .../inference/model_wrappers/model_factory.py | 1093 +++++++++ .../nemotron_voicechat_inference_wrapper.py | 1947 +++++++++++++++++ .../model_wrappers/perception_cache.py | 537 +++++ .../speechlm2/inference/pipelines/__init__.py | 13 + .../pipelines/s2s_pipeline_interface.py | 71 + .../pipelines/streaming_s2s_pipeline.py | 763 +++++++ .../speechlm2/inference/streaming/__init__.py | 13 + .../inference/streaming/framing/__init__.py | 13 + .../streaming/framing/s2s_request_options.py | 27 + .../inference/streaming/state/__init__.py | 13 + .../streaming/state/s2s_context_manager.py | 292 +++ .../inference/streaming/state/s2s_state.py | 146 ++ .../speechlm2/inference/utils/__init__.py | 13 + .../inference/utils/pipeline_utils.py | 62 + .../speechlm2/inference/vllm/__init__.py | 13 + .../vllm/scripts/convert_eartts_checkpoint.py | 256 +++ .../scripts/convert_nemotronllm_checkpoint.py | 261 +++ .../inference/vllm/streaming_llm_engine.py | 480 ++++ .../speechlm2/inference/vllm/vllm_patch.py | 59 + 29 files changed, 7301 insertions(+) create mode 100644 examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml create mode 100644 examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py create mode 100644 
examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py create mode 100644 examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py create mode 100644 examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/config.pbtxt create mode 100755 examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh create mode 100644 nemo/collections/speechlm2/inference/__init__.py create mode 100644 nemo/collections/speechlm2/inference/factory/__init__.py create mode 100644 nemo/collections/speechlm2/inference/factory/s2s_pipeline_builder.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/__init__.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/model_factory.py create mode 100755 nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py create mode 100644 nemo/collections/speechlm2/inference/pipelines/__init__.py create mode 100644 nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py create mode 100644 nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py create mode 100644 nemo/collections/speechlm2/inference/streaming/__init__.py create mode 100644 nemo/collections/speechlm2/inference/streaming/framing/__init__.py create mode 100644 nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py create mode 100644 nemo/collections/speechlm2/inference/streaming/state/__init__.py create mode 100644 nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py create mode 100644 nemo/collections/speechlm2/inference/streaming/state/s2s_state.py create mode 100644 nemo/collections/speechlm2/inference/utils/__init__.py create mode 100644 nemo/collections/speechlm2/inference/utils/pipeline_utils.py create mode 100644 nemo/collections/speechlm2/inference/vllm/__init__.py 
create mode 100644 nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py create mode 100644 nemo/collections/speechlm2/inference/vllm/scripts/convert_nemotronllm_checkpoint.py create mode 100644 nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py create mode 100644 nemo/collections/speechlm2/inference/vllm/vllm_patch.py diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml new file mode 100644 index 000000000000..bded9501448f --- /dev/null +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -0,0 +1,105 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Runtime args (overridden via CLI) +# audio_file: "/path/to/your/audio_or_directory" +audio_file: ??? +output_dir: ./generated + +# Input audio padding (set at most one; null = disabled) +pad_audio_to_sec: null # Pad each audio to this fixed duration (seconds) +pad_silence_ratio: null # Append silence = ratio * original duration (e.g. 0.2 = 20%) +pad_audio_by_sec: null # Append this fixed number of extra seconds of silence + +pipeline_type: s2s_streaming + +# S2S model block +s2s: + model_path: ??? + llm_checkpoint_path: ??? + speaker_reference: ??? + engine_type: ??? 
# Engine type: 'native' or 'vllm_llm' or 'vllm_eartts' or 'vllm_llm_vllm_eartts' + vllm_llm_config: + model_path: ${s2s.model_path} # Inherits from s2s.model_path + max_model_len: 8192 # Maximum sequence length for vLLM + gpu_memory_utilization: 0.35 # GPU memory utilization (0.0-1.0) + dtype: bfloat16 # Data type for vLLM inference + engine_path: null # Path to vLLM engine (null = auto-convert if needed) + pretrained_llm: ${s2s.llm_checkpoint_path} # Inherits from s2s.llm_checkpoint_path + + vllm_tts_config: + model_path: ${s2s.vllm_llm_config.model_path} # Inherits from s2s.model_path + max_model_len: ${s2s.vllm_llm_config.max_model_len} + gpu_memory_utilization: ${s2s.vllm_llm_config.gpu_memory_utilization} + dtype: float32 # EarTTS requires float32 for proper audio quality (bfloat16 causes hallucinations) + engine_path: null + pretrained_llm: null + skip_tokenizer_init: true + + device: cuda + # ======================== + # Device Configuration + # ======================== + device_id: 0 # GPU device ID + compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, + # 'float16' for older GPUs + # 'float32' + # ======================== + # Inference settings + # ======================== + codec_token_history_size: 60 # Sliding-window buffer size; ignored when use_codec_cache is true + use_perception_cache: true # Enable cache-aware streaming for perception encoder + use_perception_cudagraph: true # Enable CUDA graph-accelerated perception encoder + use_codec_cache: true # Incremental codec decode to remove clicking sounds and wasted computation + # (when true, codec_token_history_size is unused) + + # Deterministic inference (native engine only). Ensures identical results across + # runs by disabling FlashAttention and forcing deterministic CUDA algorithms. + # Trade-offs: slower inference, might be worse results than non-deterministic mode, since + # non-deterministic mode was used in training. + deterministic: false + + # sampling parameters. 
if set all to 1.0, it will be greedy decoding. + top_p: 0.5 + repetition_penalty: 1.1 + temperature: 0.3 + force_turn_taking: true + force_turn_taking_threshold: 40 + force_turn_taking_pad_window: 25 + + # Inference logit boosts (applied to model logits at inference time) + inference_user_pad_boost: 0.8 # Boost ASR pad logit + inference_user_bos_boost: null # Boost ASR BOS logit + inference_user_eos_boost: null # Boost ASR EOS logit + + system_prompt: ??? + tts_system_prompt: null # TTS system prompt - conditions TTS generation style + # Requires a checkpoint trained with individual TTS prompts + + + +# ======================== +# Pipeline settings +# ======================== +matmul_precision: medium # Matrix multiplication precision: highest, high, medium + +streaming: + input_sample_rate: 16000 # Audio sample rate in Hz + output_sample_rate: 22050 # Audio sample rate in Hz + batch_size: 1 # Number of audio frames per batch + att_context_size: [70,0] # Attention context size: [70,13],[70,6],[70,2],[70,0] + chunk_size_in_secs: ??? # Needs to be multiple of 80ms + buffer_size_in_secs: ??? # Audio buffer size in seconds (larger = more context, better quality) + request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming + max_len: 8192 diff --git a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py new file mode 100644 index 000000000000..44a1979a1123 --- /dev/null +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -0,0 +1,268 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +S2S Streaming Inference Client + +Usage: + python s2s_streaming_infer.py \ + audio_file=/path/to/audio_or_directory \ + s2s.model_path=/path/to/eartts_ckpt \ + s2s.llm_checkpoint_path=/path/to/llm_ckpt \ + s2s.speaker_reference=/path/to/speaker.wav \ + streaming.chunk_size_in_secs=0.08 \ + streaming.buffer_size_in_secs=5.6 +""" + +import json +import os +import re +from time import time +from typing import List, Optional + +import hydra +import soundfile as sf +from jiwer import wer as compute_wer +from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions +from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput, clean_pred_text +from nemo.utils import logging +from omegaconf import DictConfig +import torch + + +def prepare_audio_data( + audio_file: str, + default_system_prompt: str | None = None, + sort_by_duration: bool = True, +) -> tuple[List[str], List[S2SRequestOptions], List[str | None]]: + """ + Get audio filepaths and per-stream options from a folder, single file, or manifest. + + When the input is a JSON manifest, each line may contain: + {"audio_filepath": "clip.wav", "text": "...", "system_prompt": "..."} + If ``system_prompt`` is absent on a line, *default_system_prompt* is used. 
+ + Returns: + (filepaths, options, ground_truths) -- parallel lists of audio paths, + per-stream options, and ground-truth texts (None when unavailable). + """ + audio_file = audio_file.strip() + if not os.path.isabs(audio_file): + audio_file = os.path.abspath(audio_file) + + options: List[S2SRequestOptions] = [] + ground_truths: List[str | None] = [] + + if os.path.isdir(audio_file): + filepaths = [os.path.join(audio_file, x) for x in os.listdir(audio_file) if x.endswith(".wav")] + options = [S2SRequestOptions(system_prompt=default_system_prompt) for _ in filepaths] + ground_truths = [None] * len(filepaths) + elif audio_file.endswith(".wav"): + filepaths = [audio_file] + options = [S2SRequestOptions(system_prompt=default_system_prompt)] + ground_truths = [None] + elif audio_file.endswith((".json", ".jsonl")): + samples = [] + with open(audio_file, 'r') as f: + for line in f.readlines(): + if line.strip(): + samples.append(json.loads(line)) + filepaths = [get_full_path(entry["audio_filepath"], audio_file) for entry in samples] + options = [ + S2SRequestOptions( + system_prompt=entry.get("system_prompt", default_system_prompt), + ) + for entry in samples + ] + ground_truths = [entry.get("text", None) for entry in samples] + else: + raise ValueError(f"audio_file `{audio_file}` needs to be a folder, audio file, or manifest file") + + if sort_by_duration: + durations = [sf.SoundFile(fp).frames for fp in filepaths] + order = sorted(range(len(filepaths)), key=lambda i: durations[i]) + filepaths = [filepaths[i] for i in order] + options = [options[i] for i in order] + ground_truths = [ground_truths[i] for i in order] + + return filepaths, options, ground_truths + + +def calculate_duration(audio_filepaths: List[str]) -> float: + """Calculate the total duration of the audio files in seconds.""" + total_dur = 0 + for audio_filepath in audio_filepaths: + sound = sf.SoundFile(audio_filepath) + total_dur += sound.frames / sound.samplerate + return total_dur + + +def 
calculate_padded_duration( + audio_filepaths: List[str], + pad_audio_to_sec: float | None = None, + pad_silence_ratio: float | None = None, + pad_audio_by_sec: float | None = None, +) -> float: + """Calculate total duration including padding for RTFX reporting.""" + total = 0.0 + for fp in audio_filepaths: + sound = sf.SoundFile(fp) + orig = sound.frames / sound.samplerate + if pad_audio_to_sec is not None: + total += max(orig, pad_audio_to_sec) + elif pad_silence_ratio is not None: + total += orig * (1 + pad_silence_ratio) + elif pad_audio_by_sec is not None: + total += orig + pad_audio_by_sec + else: + total += orig + return total + + +def dump_output( + audio_filepaths: List[str], + output: PipelineOutput, + output_dir: str, + options: List[S2SRequestOptions], + ground_truths: List[str | None], +) -> None: + """ + Dump inference results to output_processed.json and output_raw.json. + + output_processed.json uses the same schema as the standalone wrapper's + output_results_processed.json (timestamps in pred_text via <|t|> / <$t$>). + + output_raw.json preserves all tokens including (pad tokens), + matching the standalone wrapper's output_results_raw.json. + + CTM files are still written for per-word audio-sample-based timing. 
+ + Args: + audio_filepaths: List of audio file paths + output: Pipeline output + output_dir: Directory for all output files + options: Per-stream request options (carries the system prompt) + ground_truths: Ground-truth texts (None when unavailable) + """ + output_processed_path = os.path.join(output_dir, "output_processed.json") + output_raw_path = os.path.join(output_dir, "output_raw.json") + output_ctm_dir = os.path.join(output_dir, "ctm") + + os.makedirs(output_ctm_dir, exist_ok=True) + + asr_texts_ts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths) + texts_ts = output.texts_with_timestamps or [""] * len(audio_filepaths) + raw_texts = output.raw_texts or [""] * len(audio_filepaths) + raw_asr_texts = output.raw_asr_texts or [""] * len(audio_filepaths) + + with open(output_processed_path, 'w') as f_proc, open(output_raw_path, 'w') as f_raw: + for audio_filepath, words, opts, gt, pred_text_ts, pred_src_text_ts, pred_text_raw, pred_src_text_raw in zip( + audio_filepaths, output.words, options, ground_truths, + texts_ts, asr_texts_ts, raw_texts, raw_asr_texts, + ): + stem = os.path.splitext(os.path.basename(audio_filepath))[0] + ctm_filepath = os.path.abspath(os.path.join(output_ctm_dir, f"{stem}.ctm")) + with open(ctm_filepath, 'w') as ctm_fout: + for word in words: + ctm_line = f"A {round(word.start, 2)} {round(word.duration, 2)} {word.text} {word.conf}" + ctm_fout.write(f"{stem} {ctm_line}\n") + + pred_audio_path = os.path.join(output_dir, "wav", f"{stem}.wav") + + record_processed = { + "id": stem, + "target_text": "", + "pred_audio": pred_audio_path, + "src_text": gt or "", + "pred_src_text": pred_src_text_ts or "", + "pred_text": pred_text_ts or "", + "system_prompt": opts.system_prompt or "", + } + json.dump(record_processed, f_proc, ensure_ascii=False) + f_proc.write('\n') + f_proc.flush() + + record_raw = { + "id": stem, + "target_text": "", + "pred_audio": pred_audio_path, + "src_text": gt or "", + "pred_src_text": pred_src_text_raw or 
"", + "pred_text": pred_text_raw or "", + "system_prompt": opts.system_prompt or "", + } + json.dump(record_raw, f_raw, ensure_ascii=False) + f_raw.write('\n') + f_raw.flush() + + +@hydra.main(config_path="./conf", config_name="s2s_streaming", version_base=None) +def main(cfg: DictConfig): + default_system_prompt = cfg.get("s2s", {}).get("system_prompt", None) + audio_filepaths, options, ground_truths = prepare_audio_data( + cfg.audio_file, default_system_prompt=default_system_prompt, sort_by_duration=False, + ) + logging.info(f"Found {len(audio_filepaths)} audio files to generate") + + # Set matmul precision + matmul_precision = cfg.get("matmul_precision", "high") + torch.set_float32_matmul_precision(matmul_precision) + logging.info(f"Using matmul precision: {matmul_precision}") + + pipeline = S2SPipelineBuilder.build_pipeline(cfg) + + start = time() + output = pipeline.run(audio_filepaths, options=options) + exec_dur = time() - start + logging.info(f"Generated {len(audio_filepaths)} files in {exec_dur:.2f}s") + + # Log RTFX (accounts for padding when configured) + pad_to = cfg.get("pad_audio_to_sec", None) + pad_ratio = cfg.get("pad_silence_ratio", None) + pad_by = cfg.get("pad_audio_by_sec", None) + if pad_to or pad_ratio or pad_by: + data_dur = calculate_padded_duration(audio_filepaths, pad_to, pad_ratio, pad_by) + else: + data_dur = calculate_duration(audio_filepaths) + rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf') + logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)") + + # Compute WER when ground-truth texts are available. 
+ # Use asr_texts_with_timestamps (from tokens_to_str with full post-processing) + asr_texts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths) + wer_scores = [] + for gt, asr_text in zip(ground_truths, asr_texts): + if gt and asr_text: + cleaned_gt = clean_pred_text(gt) + cleaned_pred = clean_pred_text(asr_text) + if cleaned_gt.strip() and cleaned_pred.strip(): + wer_scores.append(compute_wer(cleaned_gt, cleaned_pred)) + if wer_scores: + avg_wer = sum(wer_scores) / len(wer_scores) + logging.info( + f"WER: avg={avg_wer:.4f} ({avg_wer * 100:.2f}%), " + f"n={len(wer_scores)}, " + f"min={min(wer_scores):.4f}, max={max(wer_scores):.4f}" + ) + + # Dump the transcriptions and CTMs + output_dir = cfg.get("output_dir", "./generated") + dump_output(audio_filepaths, output, output_dir, options, ground_truths) + logging.info(f"Transcriptions written to {output_dir}/output_processed.json and {output_dir}/output_raw.json") + + +if __name__ == "__main__": + main() diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py new file mode 100644 index 000000000000..78ca02373a77 --- /dev/null +++ b/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py @@ -0,0 +1,247 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Streaming Triton client for the S2S voicechat model. 
+ +Usage: + python client_streaming.py \ + --host localhost --port 8001 \ + --audio_filename /path/to/input.wav \ + --dur_test_audio 30 +""" + +import argparse +import uuid +import random +import sys + +import librosa +import numpy as np +import soundfile as sf +import time +import tritonclient.grpc as grpcclient +from tqdm import tqdm +from tritonclient.utils import * + +# Use Python's built-in logging so this script can run without NeMo installed +import logging +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger(__name__) + +# Default values +DEFAULT_HOST = "localhost" +DEFAULT_PORT = 8001 +DEFAULT_NUM_FRAMES_PER_CHUNK = 1 +DEFAULT_DUR_TEST_AUDIO = 30 + +parser = argparse.ArgumentParser(description="Streaming client for voicechat model") +parser.add_argument("--host", type=str, default=DEFAULT_HOST, help=f"Triton server host (default: {DEFAULT_HOST})") +parser.add_argument("--port", type=int, default=DEFAULT_PORT, help=f"Triton server port (default: {DEFAULT_PORT})") +parser.add_argument("--num_frames_per_chunk", type=int, default=DEFAULT_NUM_FRAMES_PER_CHUNK, help=f"Number of 80ms frames per inference step (default: {DEFAULT_NUM_FRAMES_PER_CHUNK})") +parser.add_argument("--audio_filename", type=str, required=True, help="Path to input audio file") +parser.add_argument("--dur_test_audio", type=int, default=DEFAULT_DUR_TEST_AUDIO, help=f"Duration of test audio in seconds; audio will be padded or trimmed to this length (default: {DEFAULT_DUR_TEST_AUDIO})") +parser.add_argument("--output_dir", type=str, default=".", help="Directory to save output audio files (default: current directory)") +parser.add_argument("--system_prompt", type=str, default=None, help="System prompt to send to the model on the first request (overrides server default)") +args = parser.parse_args() + +model_name = "voicechat" +audio_file = args.audio_filename + +NUM_FRAMES_PER_CHUNK = args.num_frames_per_chunk +DUR_TEST_AUDIO = args.dur_test_audio 
+INPUT_CHUNK_SIZE_SAMPLES = int(16000 * 0.08) * NUM_FRAMES_PER_CHUNK # number of samples per input chunk +NUM_CHUNKS_TEST_AUDIO = int(DUR_TEST_AUDIO / (0.08 * NUM_FRAMES_PER_CHUNK)) +print(f"{NUM_CHUNKS_TEST_AUDIO=}") + +times_spend_on_inference = [] + + +def get_audio_as_chunks(audio_file): + audio_signal, sr = librosa.load(audio_file, sr=16000) + audio_signal = np.expand_dims(audio_signal, axis=0) + + padded_len_samples = int(NUM_CHUNKS_TEST_AUDIO * INPUT_CHUNK_SIZE_SAMPLES) + audio_signal_padded = np.zeros((1, padded_len_samples), dtype=np.float32) + + if padded_len_samples > audio_signal.shape[1]: # actually doing padding + audio_signal_padded[:, : audio_signal.shape[1]] = audio_signal + else: # actually need to trim (because audio is longer than maxlen) + audio_signal_padded = audio_signal[:, :padded_len_samples] + + audio_signal_chunks = [ + audio_signal_padded[:, i : i + INPUT_CHUNK_SIZE_SAMPLES] + for i in range(0, audio_signal_padded.shape[1], INPUT_CHUNK_SIZE_SAMPLES) + ] + + return audio_signal_chunks + + +def send_sequence_end(client, sequence_id): + """Send a final request with sequence_end=True to properly clean up the sequence""" + try: + logger.info(f"Sending sequence_end=True for sequence_id={sequence_id}") + + # Send empty audio chunk with sequence_end=True + empty_audio = np.zeros((1, INPUT_CHUNK_SIZE_SAMPLES), dtype=np.float32) + + inputs = [ + grpcclient.InferInput( + "audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype) + ), + ] + inputs[0].set_data_from_numpy(empty_audio) + + outputs = [ + grpcclient.InferRequestedOutput("output_text"), + grpcclient.InferRequestedOutput("output_audio"), + ] + + response = client.infer( + model_name, + inputs, + request_id=str(uuid.uuid4()), + outputs=outputs, + sequence_id=sequence_id, + sequence_start=False, + sequence_end=True, # This is the key - properly end the sequence + ) + logger.info("Sequence ended successfully") + + except Exception as e: + logger.error(f"Error ending sequence: 
{e}") + +with grpcclient.InferenceServerClient(f"{args.host}:{args.port}") as client: + audio_signal_chunks = get_audio_as_chunks(audio_file) + + generated_text = [] + generated_asr_text = [] + generated_audio = [] + + # Generate a numeric sequence ID instead of string UUID to match UINT64 type + sequence_id = random.randint(1, 2**63 - 1) # Generate random uint64 value + + try: + # If a system prompt is provided, send a separate prefill request first: + # zero-length audio + system_prompt, with sequence_start=True. + prefill_sent = False + if args.system_prompt is not None: + logger.info(f"Sending prefill request with system_prompt ({len(args.system_prompt)} chars)") + empty_audio = np.zeros((1, 0), dtype=np.float32) + prefill_inputs = [ + grpcclient.InferInput( + "audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype) + ), + ] + prefill_inputs[0].set_data_from_numpy(empty_audio) + + prompt_np = np.array([args.system_prompt.encode("utf-8")], dtype=object) + prompt_input = grpcclient.InferInput("system_prompt", prompt_np.shape, "BYTES") + prompt_input.set_data_from_numpy(prompt_np) + prefill_inputs.append(prompt_input) + + prefill_outputs = [ + grpcclient.InferRequestedOutput("output_text"), + grpcclient.InferRequestedOutput("output_asr_text"), + grpcclient.InferRequestedOutput("output_audio"), + ] + + prefill_start = time.time() + client.infer( + model_name, + prefill_inputs, + request_id=str(uuid.uuid4()), + outputs=prefill_outputs, + sequence_id=sequence_id, + sequence_start=True, + sequence_end=False, + ) + logger.info(f"Prefill completed in {time.time() - prefill_start:.3f}s") + prefill_sent = True + + for idx, audio_chunk in tqdm(enumerate(audio_signal_chunks)): + inputs = [ + grpcclient.InferInput( + "audio_signal", audio_chunk.shape, np_to_triton_dtype(audio_chunk.dtype) + ), + ] + + inputs[0].set_data_from_numpy(audio_chunk) + + outputs = [ + grpcclient.InferRequestedOutput("output_text"), + 
grpcclient.InferRequestedOutput("output_asr_text"), + grpcclient.InferRequestedOutput("output_audio"), + ] + + start_time = time.time() + response = client.infer( + model_name, + inputs, + request_id=str(uuid.uuid4()), + outputs=outputs, + sequence_id=sequence_id, + sequence_start=(idx == 0 and not prefill_sent), + sequence_end=idx == len(audio_signal_chunks) - 1, + ) + end_time = time.time() + + result = response.get_response() + output_text = response.as_numpy("output_text") + output_asr_text = response.as_numpy("output_asr_text") + output_audio = response.as_numpy("output_audio") + + generated_text.extend([i.decode("utf-8") for i in output_text]) + generated_asr_text.extend([i.decode("utf-8") for i in output_asr_text]) + + if output_audio.shape[1] > 0: + times_spend_on_inference.append(end_time - start_time) + generated_audio.append(output_audio) + + except KeyboardInterrupt: + logger.info("\nKeyboardInterrupt received! Calling send_sequence_end...") + send_sequence_end(client, sequence_id) + logger.info("Sequence cleanup completed. 
Exiting...") + sys.exit(0) + + logger.info("Agent text: " + "".join([str(i) for i in generated_text])) + logger.info("ASR text (user's speech): " + "".join([str(i) for i in generated_asr_text])) + generated_audio = np.concatenate(generated_audio, axis=1) + + import os + os.makedirs(args.output_dir, exist_ok=True) + + output_audio_path = os.path.join(args.output_dir, "output_audio.wav") + sf.write(output_audio_path, generated_audio.squeeze(0), 22050) + logger.info(f"Agent audio written to {output_audio_path}") + + # Save audio file with both input and output in each channel + # Resample input to 22050 Hz, and pad shorter file to the same length as the longer one + input_audio, sr = librosa.load(audio_file, sr=22050) + generated_audio_1d = generated_audio.squeeze(0) # Convert from [1, T] to [T] + maxlen = max(input_audio.shape[0], generated_audio_1d.shape[0]) + input_audio = np.pad(input_audio, (0, maxlen - input_audio.shape[0]), mode="constant") + generated_audio_1d = np.pad(generated_audio_1d, (0, maxlen - generated_audio_1d.shape[0]), mode="constant") + both_audio = np.column_stack([input_audio, generated_audio_1d]) # Create stereo: [T, 2] + combined_path = os.path.join(args.output_dir, "input_and_output_combined.wav") + sf.write(combined_path, both_audio, 22050) + logger.info(f"Input and output combined audio written to {combined_path}") + + logger.info(f"Average time spend on inference: {np.mean(times_spend_on_inference)}") + logger.info(f"std of time spend on inference: {np.std(times_spend_on_inference)}") + logger.info(f"Median time spend on inference: {np.median(times_spend_on_inference)}") + logger.info(f"Min time spend on inference: {np.min(times_spend_on_inference)}") + logger.info(f"Max time spend on inference: {np.max(times_spend_on_inference)}") + logger.info(f"All times spend on inference: {[round(i, 4) for i in times_spend_on_inference]}") + logger.info(f"Number of chunks: {len(times_spend_on_inference)}") diff --git 
class TritonPythonModel:
    """Triton Python backend model for streaming speech-to-speech generation.

    Bridges Triton sequence-batched requests to the NeMo streaming S2S
    pipeline: per-request audio chunks go in, generated audio / agent text /
    ASR text come back out. Triton requires this exact class name.
    """

    def _resolve_env_overrides(self, cfg):
        """Resolve ??? placeholders in the config from environment variables.

        Lets start_triton.sh control model paths and settings via env vars
        while sharing the same s2s_streaming.yaml used by the CLI.

        Env var mapping (cfg key -> env var, default):
            s2s.model_path                -> S2S_MODEL_PATH (required)
            s2s.llm_checkpoint_path       -> S2S_LLM_CHECKPOINT_PATH (required)
            s2s.speaker_reference         -> S2S_SPEAKER_REFERENCE (required)
            s2s.engine_type               -> S2S_ENGINE_TYPE (default: native)
            s2s.system_prompt             -> S2S_SYSTEM_PROMPT (default: none)
            s2s.tts_system_prompt         -> S2S_TTS_SYSTEM_PROMPT (default: none)
            s2s.use_codec_cache           -> S2S_USE_CODEC_CACHE (default: true)
            streaming.chunk_size_in_secs  -> S2S_CHUNK_SIZE_IN_SECS (default: 0.08)
            streaming.buffer_size_in_secs -> S2S_BUFFER_SIZE_IN_SECS (default: 5.6)
        """
        overrides = {
            # Required (no default: the cfg key is left untouched when unset)
            "s2s.model_path": ("S2S_MODEL_PATH", None),
            "s2s.llm_checkpoint_path": ("S2S_LLM_CHECKPOINT_PATH", None),
            "s2s.speaker_reference": ("S2S_SPEAKER_REFERENCE", None),
            # Optional (with defaults)
            "s2s.engine_type": ("S2S_ENGINE_TYPE", "native"),
            "s2s.system_prompt": ("S2S_SYSTEM_PROMPT", None),
            "s2s.tts_system_prompt": ("S2S_TTS_SYSTEM_PROMPT", None),
            "s2s.use_codec_cache": ("S2S_USE_CODEC_CACHE", True),
            "streaming.chunk_size_in_secs": ("S2S_CHUNK_SIZE_IN_SECS", 0.08),
            "streaming.buffer_size_in_secs": ("S2S_BUFFER_SIZE_IN_SECS", 5.6),
        }
        for cfg_key, (env_var, default) in overrides.items():
            raw = os.environ.get(env_var)
            if raw is not None:
                # Cast the env string to the default's type; bool must be
                # checked before int/float because bool is an int subclass.
                if isinstance(default, bool):
                    raw = raw.lower() in ("true", "1", "yes")
                elif isinstance(default, float):
                    raw = float(raw)
                elif isinstance(default, int):
                    raw = int(raw)
                OmegaConf.update(cfg, cfg_key, raw, force_add=True)
            elif default is not None:
                OmegaConf.update(cfg, cfg_key, default, force_add=True)

    def load_model(self, config_path: str):
        """Build and open the S2S pipeline from a YAML config file.

        Args:
            config_path: Path to a shared YAML config file (s2s_streaming.yaml).
                Fields marked ??? are resolved from environment variables
                exported by start_triton.sh.
        """
        cfg = OmegaConf.load(config_path)
        self._resolve_env_overrides(cfg)

        self.pipeline = S2SPipelineBuilder.build_pipeline(cfg)
        self.pipeline.open_session()

        # Expected number of input samples per request, derived from the pipeline.
        self.chunk_size = int(self.pipeline.chunk_size_in_secs * self.pipeline.input_sample_rate)

        # Per-stream text cursors so each response only carries the text
        # produced since the previous response.
        self.text_positions = {}  # stream_id -> last agent-text length
        self.asr_text_positions = {}  # stream_id -> last ASR-text length

    def initialize(self, args):
        """Called exactly once by Triton when the model is loaded.

        Args:
            args: Dict of strings supplied by Triton (model_config,
                model_instance_kind, model_instance_device_id,
                model_repository, model_version, model_name).
        """
        # Config path comes from S2S_TRITON_CONFIG_PATH (start_triton.sh sets it).
        config_path = os.environ.get("S2S_TRITON_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "S2S_TRITON_CONFIG_PATH environment variable is not set. "
                "Use start_triton.sh or set it to the path of s2s_streaming.yaml."
            )
        logging.info(f"Loading S2S Triton model from config: {config_path}")
        self.load_model(config_path)

        # Throwaway prefill so the first real client request doesn't pay
        # one-time initialization cost.
        self.pipeline.warmup()

    def finalize(self) -> None:
        """Release the pipeline session and free cached CUDA memory."""
        self.pipeline.close_session()
        torch.cuda.empty_cache()

    def validate_and_convert_audio(self, audio_signal: np.ndarray) -> torch.Tensor:
        """Check the chunk holds exactly ``self.chunk_size`` samples and return it as float32 tensor.

        2-D input is flattened first. Raises ValueError on size mismatch.
        """
        if audio_signal.ndim == 2:
            audio_signal = audio_signal.flatten()

        if len(audio_signal) != self.chunk_size:
            expected_frames = self.pipeline.num_frames_per_chunk
            actual_secs = len(audio_signal) / self.pipeline.input_sample_rate
            raise ValueError(
                f"Audio chunk size mismatch: received {len(audio_signal)} samples ({actual_secs:.3f}s) "
                f"but server expects {self.chunk_size} samples "
                f"({self.pipeline.chunk_size_in_secs}s = {expected_frames} frame(s)). "
                f"Make sure the client's num_frames_per_chunk matches the server's "
                f"chunk_size_in_secs={self.pipeline.chunk_size_in_secs}."
            )

        return torch.tensor(audio_signal, dtype=torch.float32)

    @staticmethod
    def _control_scalar(request, name):
        """Best-effort read of the first element of an optional input tensor.

        Returns None when the tensor is absent or any lookup error occurs
        (matching Triton's behavior for unset sequence controls).
        """
        try:
            tensor = pb_utils.get_input_tensor_by_name(request, name)
            if tensor is not None:
                return tensor.as_numpy()[0]
        except Exception:
            pass
        return None

    def triton_requests_to_frames(self, requests: Iterable) -> List[Frame]:
        """Convert Triton inference requests into streaming audio Frames.

        Extracts audio data and sequence batching controls (START, END, CORRID)
        from each request and wraps them in Frame dataclasses for the streaming
        S2S pipeline. With max_batch_size=0, each request carries one frame.

        Returns:
            List of Frame objects (one per request).
        """
        frames = []

        for request in requests:
            audio_signal = pb_utils.get_input_tensor_by_name(request, "audio_signal").as_numpy()

            # Sequence-batching metadata, populated automatically by Triton
            # when the client uses sequence_start/end/id.
            start_val = self._control_scalar(request, "START")
            end_val = self._control_scalar(request, "END")
            corr_val = self._control_scalar(request, "CORRID")
            is_first = bool(start_val) if start_val is not None else False
            is_last = bool(end_val) if end_val is not None else False
            stream_id = int(corr_val) if corr_val is not None else 0

            # Optional per-stream system prompt, honored only on the first request.
            frame_options = None
            if is_first:
                system_prompt = None
                try:
                    prompt_tensor = pb_utils.get_input_tensor_by_name(request, "system_prompt")
                    if prompt_tensor is not None:
                        raw = prompt_tensor.as_numpy()[0]
                        system_prompt = raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
                except Exception:
                    pass
                if system_prompt is None:
                    system_prompt = self.pipeline.system_prompt
                frame_options = S2SRequestOptions(system_prompt=system_prompt)

            # Zero-length audio = prefill-only frame; bypass size validation.
            if audio_signal.size == 0:
                samples = torch.empty(0, dtype=torch.float32)
            else:
                samples = self.validate_and_convert_audio(audio_signal)

            frames.append(
                Frame(
                    samples=samples,
                    stream_id=stream_id,
                    is_first=is_first,
                    is_last=is_last,
                    options=frame_options,
                )
            )

        return frames

    def get_generations(self, frames: List[Frame]) -> List[Tuple]:
        """Run one pipeline step and collect per-frame results.

        Uses StreamingS2SPipeline.generate_step() to update internal state,
        then pulls results from per-stream S2SStreamingState objects.
        Zero-length first frames are prefill-only and yield empty results.

        Returns:
            List of tuples per frame:
                - generated audio tensor
                - generated text (incremental: only new text since last response)
                - generated ASR text (incremental as well)
        """
        t_step_start = time.time()
        self.pipeline.generate_step(frames)
        t_step_end = time.time()

        t_collect_start = time.time()
        generations = []

        for frame in frames:
            stream_id = frame.stream_id

            # Prefill-only frames don't produce audio/text output.
            if frame.is_first and frame.samples.numel() == 0:
                generations.append((torch.empty(1, 0), "", ""))
                continue

            state = self.pipeline.get_or_create_state(stream_id)
            audio = state.audio_buffer

            full_text = state.get_output_text()
            full_asr_text = state.get_output_asr_text()

            # Slice off only the text produced since the previous response,
            # then advance the per-stream cursors.
            text_cursor = self.text_positions.setdefault(stream_id, 0)
            asr_cursor = self.asr_text_positions.setdefault(stream_id, 0)
            generations.append((audio, full_text[text_cursor:], full_asr_text[asr_cursor:]))
            self.text_positions[stream_id] = len(full_text)
            self.asr_text_positions[stream_id] = len(full_asr_text)

            state.cleanup_after_response()

            if frame.is_last:
                # Stream finished: drop all per-stream bookkeeping.
                self.pipeline.delete_state(stream_id)
                self.text_positions.pop(stream_id, None)
                self.asr_text_positions.pop(stream_id, None)
        t_collect_end = time.time()

        logging.info(
            f"get_generations breakdown: generate_step={(t_step_end - t_step_start)*1000:.2f}ms, "
            f"extract+cleanup={(t_collect_end - t_collect_start)*1000:.2f}ms"
        )

        return generations

    def execute(self, requests: Iterable) -> List[pb_utils.InferenceResponse]:
        """Execute the model and return one response per request.

        Zero-length audio with ``sequence_start=True`` and a ``system_prompt``
        is treated as a prefill-only request by the pipeline (no fake audio
        needed). All other requests are normal audio generation.

        Returns per response:
            - output_audio: float32 array of generated audio samples
            - output_text: UTF-8 encoded generated text (agent's response)
            - output_asr_text: UTF-8 encoded ASR text (user's transcribed speech)
        """
        t_start = time.time()

        t_frames_start = time.time()
        frames = self.triton_requests_to_frames(requests)
        t_frames_end = time.time()

        t_gen_start = time.time()
        generations = self.get_generations(frames)
        t_gen_end = time.time()

        responses = []
        for audio, text, asr_text in generations:
            if isinstance(audio, torch.Tensor):
                audio_np = audio.detach().cpu().numpy().astype(np.float32)
                # Triton expects 2-D [channels, samples] output.
                if audio_np.ndim == 1:
                    audio_np = audio_np.reshape(1, -1)
            else:
                audio_np = np.zeros((1, 0), dtype=np.float32)

            text_np = np.array([text.encode('utf-8')], dtype=object)
            asr_text_np = np.array([asr_text.encode('utf-8')], dtype=object)

            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor("output_audio", audio_np),
                        pb_utils.Tensor("output_text", text_np),
                        pb_utils.Tensor("output_asr_text", asr_text_np),
                    ]
                )
            )

        t_end = time.time()
        logging.info(f"TritonPythonModel.execute time: {t_end - t_start:.2f} seconds")
        logging.info(
            f"execute() breakdown: triton_requests_to_frames={(t_frames_end - t_frames_start)*1000:.2f}ms, "
            f"get_generations={(t_gen_end - t_gen_start)*1000:.2f}ms"
        )

        return responses
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Triton model configuration for the streaming S2S "voicechat" model,
# served by the Python backend (infer_streaming.py).

name: "voicechat"
default_model_filename: "infer_streaming.py"
backend: "python"
# No Triton-side batching; requests are routed by the sequence batcher below
# and processed one at a time by the Python model.
max_batch_size: 0

# One chunk of input audio per request; zero-length audio is treated by the
# backend as a prefill-only request.
input {
  name: "audio_signal"
  data_type: TYPE_FP32
  dims: [-1, -1]
}

# Optional per-stream system prompt; the backend only honors it on the first
# request of a sequence.
input {
  name: "system_prompt"
  data_type: TYPE_STRING
  dims: [-1]
  optional: true
}

# Incremental agent response text.
output {
  name: "output_text"
  data_type: TYPE_STRING
  dims: [-1]
}

# Incremental ASR transcript of the user's speech.
output {
  name: "output_asr_text"
  data_type: TYPE_STRING
  dims: [-1]
}

# Generated audio samples for this step.
output [
  {
    name: "output_audio"
    data_type: TYPE_FP32
    dims: [-1, -1]
  }
]

# Sequence batching: each client conversation is a Triton sequence; Triton
# injects START/END/CORRID control tensors which infer_streaming.py reads.
sequence_batching {
  # Drop a sequence after 30s of inactivity.
  max_sequence_idle_microseconds: 30000000
  oldest
  {
    max_candidate_sequences: 1
  }
  control_input [
    {
      name: "START"
      control [
        {
          kind: CONTROL_SEQUENCE_START
          fp32_false_true: [ 0, 1 ]
        }
      ]
    },
    {
      name: "END"
      control [
        {
          kind: CONTROL_SEQUENCE_END
          fp32_false_true: [ 0, 1 ]
        }
      ]
    },
    {
      name: "CORRID"
      control [
        {
          kind: CONTROL_SEQUENCE_CORRID
          data_type: TYPE_UINT64
        }
      ]
    }
  ]
}

# Single GPU instance; adjust gpus[] to pin a different device.
instance_group [{ kind: KIND_GPU, gpus: [0] }]
#!/bin/bash
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Start Triton Inference Server for S2S voicechat model.
#
# Shares the same s2s_streaming.yaml config used by s2s_streaming_infer.py.
# Fields marked ??? in the YAML are resolved from environment variables below.
#
# Usage:
#   S2S_MODEL_PATH=/path/to/eartts_ckpt \
#   S2S_LLM_CHECKPOINT_PATH=/path/to/llm_ckpt \
#   S2S_SPEAKER_REFERENCE=/path/to/speaker.wav \
#   ./start_triton.sh
#
# Environment variables (required):
#   S2S_MODEL_PATH           - Path to the EarTTS / S2S checkpoint
#   S2S_LLM_CHECKPOINT_PATH  - Path to the LLM checkpoint
#   S2S_SPEAKER_REFERENCE    - Path to a speaker reference .wav file
#
# Environment variables (optional):
#   S2S_ENGINE_TYPE          - Engine type (default: native)
#   S2S_SYSTEM_PROMPT        - LLM system prompt text (default: none)
#   S2S_TTS_SYSTEM_PROMPT    - TTS system prompt (default: none)
#   S2S_CHUNK_SIZE_IN_SECS   - Chunk size in seconds, multiple of 0.08 (default: 0.08)
#   S2S_BUFFER_SIZE_IN_SECS  - Audio buffer size in seconds (default: 5.6)
#   S2S_USE_CODEC_CACHE      - "true"/"false": incremental codec decode (default: true)
#   S2S_TRITON_CONFIG_PATH   - Override the YAML config file path
#   MODEL_REPO_DIR           - Override the Triton model repository path

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# All variables below are exported so they are visible to the Triton Python
# backend (infer_streaming.py reads them via os.environ).

# ========================
# Model paths (required)
# ========================
# ":?" aborts with the given message when the variable is unset or empty.
export S2S_MODEL_PATH="${S2S_MODEL_PATH:?Please set S2S_MODEL_PATH to the EarTTS / S2S checkpoint path}"
export S2S_LLM_CHECKPOINT_PATH="${S2S_LLM_CHECKPOINT_PATH:?Please set S2S_LLM_CHECKPOINT_PATH to the LLM checkpoint path}"
export S2S_SPEAKER_REFERENCE="${S2S_SPEAKER_REFERENCE:?Please set S2S_SPEAKER_REFERENCE to a speaker reference .wav file}"

# ========================
# Optional overrides
# ========================
export S2S_ENGINE_TYPE="${S2S_ENGINE_TYPE:-native}"
export S2S_SYSTEM_PROMPT="${S2S_SYSTEM_PROMPT:-}"
export S2S_TTS_SYSTEM_PROMPT="${S2S_TTS_SYSTEM_PROMPT:-}"
export S2S_CHUNK_SIZE_IN_SECS="${S2S_CHUNK_SIZE_IN_SECS:-0.08}"
export S2S_BUFFER_SIZE_IN_SECS="${S2S_BUFFER_SIZE_IN_SECS:-5.6}"
export S2S_USE_CODEC_CACHE="${S2S_USE_CODEC_CACHE:-true}"
export S2S_TRITON_CONFIG_PATH="${S2S_TRITON_CONFIG_PATH:-${SCRIPT_DIR}/../conf/s2s_streaming.yaml}"
export MODEL_REPO_DIR="${MODEL_REPO_DIR:-${SCRIPT_DIR}/model_repo_s2s}"


# Echo the effective configuration for easier debugging of env wiring.
echo "=== S2S Triton Server ==="
echo " S2S_MODEL_PATH: ${S2S_MODEL_PATH}"
echo " S2S_LLM_CHECKPOINT_PATH: ${S2S_LLM_CHECKPOINT_PATH}"
echo " S2S_SPEAKER_REFERENCE: ${S2S_SPEAKER_REFERENCE}"
echo " S2S_ENGINE_TYPE: ${S2S_ENGINE_TYPE}"
echo " S2S_CHUNK_SIZE_IN_SECS: ${S2S_CHUNK_SIZE_IN_SECS}"
echo " S2S_BUFFER_SIZE_IN_SECS: ${S2S_BUFFER_SIZE_IN_SECS}"
echo " S2S_USE_CODEC_CACHE: ${S2S_USE_CODEC_CACHE}"
echo " S2S_SYSTEM_PROMPT: ${S2S_SYSTEM_PROMPT:-}"
echo " S2S_TTS_SYSTEM_PROMPT: ${S2S_TTS_SYSTEM_PROMPT:-}"
echo " MODEL_REPO_DIR: ${MODEL_REPO_DIR}"
echo " S2S_TRITON_CONFIG_PATH: ${S2S_TRITON_CONFIG_PATH}"
echo "========================="

# Sanity-check the server binary before launching.
TRITON_BIN="${TRITON_BIN:-/opt/tritonserver/bin/tritonserver}"
if [ ! -x "${TRITON_BIN}" ]; then
    echo "ERROR: Triton server not found at ${TRITON_BIN}"
    echo " Are you running inside a Triton container?"
    exit 1
fi

"${TRITON_BIN}" --model-repository="${MODEL_REPO_DIR}"
class S2SPipelineBuilder:
    """Factory that builds a streaming S2S pipeline."""

    @classmethod
    def build_pipeline(
        cls,
        cfg: DictConfig
    ) -> StreamingS2SPipeline:
        """Build the streaming S2S pipeline based on the config.

        Args:
            cfg: (DictConfig) Config with an ``s2s`` section describing the model.
        Returns:
            A ready-to-use StreamingS2SPipeline object.
        """
        # Wrap the S2S model first, then hand it to the pipeline.
        wrapper = NemotronVoicechatInferenceWrapper(model_cfg=cfg.s2s)
        logger.info(f"S2S model `{cfg.s2s.model_path}` loaded")

        pipeline = StreamingS2SPipeline(cfg, wrapper)
        logger.info(f"`{type(pipeline).__name__}` pipeline loaded")
        return pipeline
class ModelInterface(ABC):
    """Base interface for model inference engines with shared sampling utilities.

    Defines the contract every engine implementation (native PyTorch, vLLM, ...)
    must follow, and provides concrete implementations of the sampling helpers
    (top-p, repetition penalty, temperature) shared across all engines.
    """

    def __init__(
        self,
        special_token_ids: Optional[Set[int]] = None,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
    ):
        """Initialize the base interface with sampling parameters.

        Args:
            special_token_ids: Token IDs (pad, eos, bos) that bypass sampling;
                they use greedy decoding and are never penalized.
            top_p: Top-p (nucleus) sampling threshold; 1.0 disables it (greedy).
            repetition_penalty: Penalty for repeated tokens; 1.0 disables it.
            temperature: Sampling temperature; 1.0 = unchanged, <1.0 = sharper,
                >1.0 = flatter, 0.0 = greedy (argmax).

        Raises:
            ValueError: If temperature is non-finite or negative.
        """
        if not math.isfinite(temperature):
            raise ValueError(f"temperature must be finite, got {temperature}")
        if temperature < 0.0:
            raise ValueError(f"temperature must be >= 0.0, got {temperature}")

        self.special_token_ids = set() if special_token_ids is None else special_token_ids
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature

    def _penalize_repeats(self, row: torch.Tensor, prev_tokens: torch.Tensor) -> torch.Tensor:
        """Apply the repetition penalty in place to one logit row.

        Positive logits of previously generated tokens are divided by the
        penalty, negative ones multiplied, so repeats become less likely
        either way. Special tokens are exempt.
        """
        distinct = prev_tokens.unique()
        if self.special_token_ids:
            # Build the mask on distinct's device: generated tokens may live on
            # a different device than the logits (e.g. vLLM returns CPU logits).
            specials = torch.tensor(list(self.special_token_ids), device=distinct.device)
            distinct = distinct[~torch.isin(distinct, specials)]

        for tid in distinct.tolist():
            if row[tid] > 0:
                row[tid] = row[tid] / self.repetition_penalty
            else:
                row[tid] = row[tid] * self.repetition_penalty
        return row

    def _nucleus_filter(self, row: torch.Tensor) -> torch.Tensor:
        """Mask (in place) every logit outside the top-p nucleus with -inf.

        At least the highest-probability token always survives.
        """
        sorted_logits, sorted_idx = torch.sort(row, descending=True)
        cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        drop = cumulative > self.top_p
        # Shift right so the first token crossing the threshold is kept.
        drop[1:] = drop[:-1].clone()
        drop[0] = False

        row[sorted_idx[drop]] = float('-inf')
        return row

    def _sample_text_token(
        self,
        logits: torch.Tensor,
        generated_tokens: torch.Tensor,
        current_step: int,
    ) -> torch.Tensor:
        """Sample text tokens with optional top-p sampling and repetition penalty.

        Special tokens (pad, eos, bos) bypass sampling: whenever the greedy
        choice is a special token, it is returned directly.

        Args:
            logits: Logits tensor of shape (B, V) for vocabulary V.
            generated_tokens: Previously generated tokens of shape (B, T).
            current_step: Current decoding step (used to slice generated_tokens).

        Returns:
            Sampled token ids of shape (B,).
        """
        batch_size = logits.shape[0]

        # Greedy choice computed on the unmodified logits.
        greedy = logits.argmax(dim=-1)  # (B,)

        # Everything disabled, or temperature forces greedy -> return argmax.
        all_disabled = (
            self.top_p >= 1.0
            and self.repetition_penalty == 1.0
            and self.temperature in (0.0, 1.0)
        )
        if all_disabled or self.temperature == 0.0:
            return greedy

        sampled = greedy.clone()
        for idx in range(batch_size):
            # Greedy special token wins outright -- no sampling for this row.
            if greedy[idx].item() in self.special_token_ids:
                continue

            row = logits[idx].clone()  # (V,)

            if self.repetition_penalty != 1.0 and current_step > 0:
                row = self._penalize_repeats(row, generated_tokens[idx, :current_step])

            if self.temperature != 1.0:
                row = row / self.temperature

            if self.top_p < 1.0:
                row = self._nucleus_filter(row)

            # Draw one token from the filtered distribution.
            probs = torch.softmax(row, dim=-1)
            sampled[idx] = torch.multinomial(probs, num_samples=1).item()

        return sampled

    @abstractmethod
    def __call__(
        self,
        input_embeds: torch.Tensor,
        cache: Optional[Any] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Perform model inference.

        Args:
            input_embeds: Input embeddings of shape [batch, seq_len, hidden_dim].
            cache: Optional cache object (e.g. DynamicCache for transformers).
            **kwargs: Additional model-specific arguments.

        Returns:
            Dict with at least 'text_logits' [batch, seq_len, vocab_size] and
            'cache' (updated cache, when one was provided), plus
            engine-specific extras.
        """
        pass

    @abstractmethod
    def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'ModelInterface':
        """Move the model to a device or convert it to a dtype."""
        pass

    @abstractmethod
    def eval(self) -> 'ModelInterface':
        """Set the model to evaluation mode."""
        pass

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """The device the model lives on."""
        pass


class VllmLLMModel(ModelInterface):
    """vLLM-based model interface built on LLMStreamingEngine.

    Wraps the streaming engine to provide async inference behind the
    ModelInterface contract. Multiple concurrent requests can share one
    engine instance: call ``_async_inference`` with distinct request IDs
    inside a single event loop (e.g. via ``asyncio.gather``).
    """
+ + async def process_stream(embeds, stream_id): + # Use the async engine directly + result = await model._async_inference(embeds, f"stream_{stream_id}", seq_len) + return result + + # Run multiple streams concurrently in same event loop + async def main(): + results = await asyncio.gather( + process_stream(embeds1, 1), + process_stream(embeds2, 2), + process_stream(embeds3, 3) + ) + + asyncio.run(main()) + """ + + def __init__( + self, + model_path: str, + max_model_len: int = 1024, + gpu_memory_utilization: float = 0.8, + trust_remote_code: bool = True, + dtype: str = "bfloat16", + engine_path: Optional[str] = None, + pretrained_llm: Optional[str] = None, + special_token_ids: Optional[Set[int]] = None, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + model_type: str = "llm", + **sampling_kwargs + ): + """ + Initialize vLLM model interface with LLMStreamingEngine. + + Args: + model_path: Path to the vLLM-compatible model checkpoint + max_model_len: Maximum sequence length + gpu_memory_utilization: GPU memory utilization ratio (0.0-1.0) + trust_remote_code: Whether to trust remote code in model + dtype: Data type for embeddings (e.g., "bfloat16", "float16") + engine_path: Optional path to pre-converted vLLM model + pretrained_llm: Optional path to pretrained LLM for conversion + special_token_ids: Set of special token IDs (for potential post-processing) + top_p: Top-p sampling (currently vLLM uses greedy decoding) + repetition_penalty: Repetition penalty (currently not used by vLLM engine) + temperature: Temperature for sampling. Applied in _sample_text_token, not in vLLM engine. + model_type: Type of model for vLLM engine ("llm", "chatglm", etc.) + **sampling_kwargs: Additional sampling parameters passed to vLLM engine. 
+ By default, vLLM uses greedy decoding (temperature=0) + """ + # Initialize base class with sampling parameters + super().__init__( + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + + import asyncio + from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import LLMStreamingEngine + + self.model_path = model_path + self.pretrained_llm = pretrained_llm + self._dtype = dtype + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Force greedy decoding in vLLM by setting temperature=0 if not specified + if 'temperature' not in sampling_kwargs: + sampling_kwargs['temperature'] = 0.0 + + if engine_path is None: + # convert model to vLLM format if needed + dir_name = os.path.basename(os.path.normpath(model_path)) + engine_path = "/tmp/" + dir_name + f"_vllm_converted_{model_type}" + if os.path.exists(engine_path): + logging.info(f"Found existing vLLM converted model at {engine_path}") + else: + self._convert_ckpt( + save_path=engine_path + ) + + from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import create_engine + # Initialize the streaming engine + self.engine = create_engine( + engine_type=model_type, + model_path=engine_path, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + trust_remote_code=trust_remote_code, + dtype=dtype, + **sampling_kwargs + ) + # Track request counter + self._request_counter = 0 + + # Get or create event loop + try: + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + # Initialize engine immediately to avoid first-call latency + logging.info("Initializing vLLM engine (this may take a moment)...") + self._loop.run_until_complete(self.engine.initialize()) + + if self.engine.engine.tokenizer is not None and not self.special_token_ids: + self.special_token_ids = 
self._get_special_token_ids_from_vllm_tokenizer(self.engine.engine.tokenizer)

        logging.debug(f"Special token IDs: {self.special_token_ids}")
        logging.info("vLLM engine ready!")

    @staticmethod
    def _get_special_token_ids_from_vllm_tokenizer(tokenizer) -> Set[int]:
        """
        Extract special token IDs from a vLLM tokenizer.
        Looks for: '' (bos), '' (eos), '' (pad).

        Args:
            tokenizer: A vLLM CachedTokenizer instance.

        Returns:
            Set of special token IDs.
        """
        special_ids = set()
        # NOTE(review): this tuple contains three EMPTY strings, and the docstring
        # likewise shows empty quotes — the special-token literals (presumably
        # '<s>', '</s>', '<pad>') appear to have been stripped at some point.
        # As written, convert_tokens_to_ids('') is queried three times; verify
        # the intended token strings against the tokenizer's vocabulary.
        for token in ('', '', ''):
            try:
                tid = tokenizer.convert_tokens_to_ids(token)
                if isinstance(tid, int):
                    special_ids.add(tid)
            except Exception:
                # Tokenizer may not expose convert_tokens_to_ids; fall through
                # and return whatever IDs were collected so far.
                pass
        return special_ids

    def _convert_ckpt(self, save_path: str):
        """Convert existing checkpoint to vLLM format and save."""
        from nemo.collections.speechlm2.inference.vllm.scripts.convert_nemotronllm_checkpoint import convert_nemo_to_hf_format

        convert_nemo_to_hf_format(
            checkpoint_path=self.model_path,
            output_dir=save_path,
            pretrained_llm=self.pretrained_llm,
            dtype=self._dtype
        )
        logging.info(f"Converted model saved to {save_path}")

    def _generate_request_id(self) -> str:
        """Generate a unique request ID."""
        # Monotonically increasing counter; not thread-safe — assumes a single
        # caller thread (TODO confirm).
        self._request_counter += 1
        return f"vllm_request_{self._request_counter}"

    def __call__(
        self,
        input_embeds: torch.Tensor,
        request_id: Optional[str] = "request_id_1",
        **kwargs
    ) -> Dict[str, Any]:
        """
        Perform inference using vLLM streaming engine.
+ + Args: + inputs: + cache: Optional cache object (currently not used for streaming) + generated_tokens: Optional tensor of generated tokens + current_step: Current decoding step + request_id: Unique request identifier for this generation + **kwargs: Additional model-specific arguments + + Returns: + Dictionary containing: + - predicted_token: Last generated text token + - asr_predicted_token: Last generated ASR token + - cache: None (vLLM manages cache internally) + - is_finished: Whether generation is complete + - request_id: The request identifier + """ + # Run async inference + result = self._loop.run_until_complete( + self._async_inference(input_embeds, request_id, **kwargs) + ) + return result + + async def _async_inference( + self, + inputs: Union[torch.Tensor, list[torch.Tensor]], + request_id: str, + **kwargs + ) -> Dict[str, Any]: + """ + Async inference using the streaming engine. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + request_id: Unique request identifier + seq_len: Number of decoding steps to perform + + Returns: + Dictionary with text_logits and other outputs + """ + # Check request status and restart if needed + from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import StreamStatus + + if request_id not in self.engine.requests: + await self.engine.start_generation(request_id=request_id) + else: + # Check if request is finished and needs restart + request_state = self.engine.requests[request_id] + if request_state.status in (StreamStatus.FINISHED, StreamStatus.ABORTED): + logging.warning( + f"Request {request_id} was {request_state.status.value}. " + f"Generated {len(request_state.generated_tokens)} tokens before stopping. " + "Cleaning up and restarting..." 
+ ) + # Try to abort cleanly first + try: + await self.engine.abort_generation(request_id) + except Exception: + pass + # Start fresh + await self.engine.start_generation(request_id=request_id) + + # Process embeddings to generate tokens + return await self._process_inputs_to_outputs(inputs, request_id, **kwargs) + + async def _process_inputs_to_outputs( + self, + input_embeds: torch.Tensor, + request_id: str, + decode_steps: int = 1, + prompt_token_ids: Optional[list] = None, + generated_tokens: Optional[torch.Tensor] = None, + current_step: int = 0 + ) -> Dict[str, Any]: + """ + Process embeddings sequentially to generate text and ASR tokens. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + request_id: Request identifier + decode_steps: Number of decoding steps to perform; decode steps = 0 means prefill + prompt_token_ids: Optional list of prompt token IDs for prefill + generated_tokens: Previously generated tokens [batch, num_generated]. + Required for repetition_penalty. If None, creates empty tensor. + current_step: Current decoding step. Used for repetition penalty. 
+ """ + + if decode_steps == 0: + # prefill only, no token generation + input_embeds = input_embeds.flatten(0, 1) # [seq_len, hidden_dim] + result = await self.engine.generate_next_token([input_embeds], + prompt_token_ids, + request_id=request_id) + return True if result is not None else False + + # Process each embedding in sequence + text_token_ids = [] + asr_token_ids = [] + result = None + for i in range(decode_steps): + # Extract single embedding [1, hidden_dim] + single_embed = input_embeds[:, i:i+1, :].squeeze(1) # [batch, hidden_dim] + + # Generate next token + result = await self.engine.generate_next_token([single_embed], request_id=request_id) + if result is None: + # No token generated (finished or error) + break + + text_token_ids.append(result.token_id) + asr_token_ids.append(result.custom_outputs["asr_tokens"]) # Assuming custom_outputs contains asr tokens + + if result.is_finished: + break + + assert len(text_token_ids) <= decode_steps, "Generated more tokens than input embeddings" + # Handle case when no tokens were generated + is_finished = False + if text_token_ids: + is_finished = len(text_token_ids) < decode_steps or (result and result.is_finished) + + text_logits = result.custom_outputs["text_logits"] if result else None + + predicted_token = text_token_ids[-1] + if self.top_p < 1.0 or self.repetition_penalty != 1.0 or (self.temperature != 1.0 and self.temperature != 0.0): + # Use provided generated_tokens or create empty tensor + batch_size = text_logits.shape[0] + if generated_tokens is None: + gen_tokens = torch.empty(batch_size, 0, device=text_logits.device, dtype=torch.long) + else: + gen_tokens = generated_tokens + + # Apply sampling with top-p and repetition penalty + predicted_token = self._sample_text_token( + logits=text_logits, + generated_tokens=gen_tokens, + current_step=current_step, + ) + + ans = { + "predicted_token": predicted_token, + "asr_predicted_token": asr_token_ids[-1], + "cache": None, # vLLM manages cache internally + 
"is_finished": is_finished, + "request_id": request_id + } + if result and result.custom_outputs and "function_tokens" in result.custom_outputs: + ans["function_predicted_token"] = result.custom_outputs["function_tokens"] + return ans + + + def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'VLLMModel': + """ + Move model to specified device or convert to specified dtype. + + Note: vLLM manages device placement internally, this is for compatibility. + """ + if isinstance(device_or_dtype, torch.device): + self._device = device_or_dtype + elif isinstance(device_or_dtype, torch.dtype): + # dtype conversion not directly supported, update config + pass + return self + + def eval(self) -> 'VLLMModel': + """Set model to evaluation mode (vLLM is always in eval mode).""" + return self + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return self._device + + def abort_request(self, request_id: str) -> bool: + """ + Abort a specific generation request. + + Args: + request_id: Request identifier to abort + + Returns: + bool: True if abort was successful + """ + return self._loop.run_until_complete( + self.engine.abort_generation(request_id) + ) + + def restart_request(self, request_id: str) -> bool: + """ + Restart a finished or aborted generation request. + + Args: + request_id: Request identifier to restart + + Returns: + bool: True if restart was successful + """ + # First abort if active + if request_id in self.engine.requests: + self.abort_request(request_id) + + # Start new generation + return self._loop.run_until_complete( + self.engine.start_generation(request_id=request_id) + ) + + def get_request_status(self, request_id: Optional[str] = None) -> Dict[str, Any]: + """ + Get status of a specific request or all requests. + + Args: + request_id: Optional request ID. If None, returns all requests. 
+ + Returns: + Status dictionary + """ + return self.engine.get_status(request_id) + + def shutdown(self): + """Shutdown the vLLM engine and cleanup resources.""" + self._loop.run_until_complete(self.engine.shutdown()) + + def __del__(self): + """Cleanup on deletion.""" + try: + self.shutdown() + except Exception: + pass + +@dataclass +class TTSGenerationResult: + codes: torch.Tensor # Generated acoustic tokens + past_key_values: Optional[Any] # Updated cache (if applicable) + + def __getitem__(self, item: str | int): + """Allows for accessing attributes by key or index.""" + if isinstance(item, str): + return getattr(self, item) + else: + # Access fields in the order they are defined in the dataclass + return getattr(self, fields(self)[item].name) + + +class VllmEARTTSModel(VllmLLMModel): + """ + vLLM-based model interface specialized for EARTTS models. + + Inherits from VllmLLMModel and sets EARTTS-specific configurations. + """ + + def __init__(self, **kwargs): + """ + Initialize vLLM EARTTS model interface. + + Args: + **kwargs: Arguments passed to the VllmLLMModel constructor + """ + super().__init__(**kwargs) + logging.info("VllmEARTTSModel initialized with EARTTS-specific settings.") + + def _convert_ckpt(self, save_path: str): + """Convert EARTTS checkpoint to vLLM format.""" + from nemo.collections.speechlm2.inference.vllm.scripts.convert_eartts_checkpoint import convert + ckpt_dir = os.path.normpath(self.model_path) + config_file = os.path.join(ckpt_dir, "config.json") + model_ckpt = os.path.join(ckpt_dir, "model.safetensors") + convert(save_path, config_file, model_ckpt) + + def __call__( + self, + inputs: Optional[Dict[str, torch.Tensor]] = None, + request_id: Optional[str] = None, + prompt_token_ids: Optional[list] = None, + **kwargs + ) -> TTSGenerationResult: + """ + Perform TTS inference using vLLM streaming engine. + + Supports two calling conventions: + 1. model(inputs_dict, request_id="id") - pass dict as first positional arg + 2. 
model(**inputs_dict) - unpack dict as keyword arguments + + Args: + inputs: Optional dict of model inputs (if None, uses **kwargs) + request_id: Optional request identifier + **kwargs: Model inputs as keyword arguments (used if inputs is None): + - code: prev_audio_tokens + - context_hidden_state: context_hidden_state (must be None) + - subword_ids: current_subword_id + - subword_mask: current_subword_mask + - past_key_values: past_key_values + - use_cache: True + - guidance_enabled: guidance_enabled + - generation_config: generation_config + - ignore_eos_flag_stop: ignore_eos_flag_stop + + Returns: + TTSGenerationResult containing generated acoustic tokens and cache + """ + # Handle both calling conventions + if inputs is not None: + # Called as model(inputs_dict, request_id="id") + input_dict = inputs + else: + # Called as model(**inputs_dict) + # Extract request_id from kwargs if present + if request_id is None: + request_id = kwargs.pop('request_id', None) + input_dict = kwargs + + # Use default request_id if still None + if request_id is None: + request_id = 'tts_request_id_1' + + # Run async inference + result = self._loop.run_until_complete( + self._async_inference(input_dict, request_id, prompt_token_ids=prompt_token_ids) + ) + + return result + + async def _process_inputs_to_outputs( + self, + inputs: Dict[str, torch.Tensor], + request_id: str, + prompt_token_ids: Optional[list] = None, + ) -> Dict[str, Any]: + """ + Process embeddings sequentially to generate text and ASR tokens. 

        Args:
            inputs = {
                "code": prev_audio_tokens,
                "context_hidden_state": context_hidden_state,
                "subword_ids": current_subword_id,
                "subword_mask": current_subword_mask,
                "past_key_values": past_key_values,
                "use_cache": True,
                "guidance_enabled": guidance_enabled,
                "generation_config": generation_config,
                "ignore_eos_flag_stop": ignore_eos_flag_stop,
            }
        Returns:
            step_acoustic_tokens: Generated acoustic tokens for the current step
            cache: None (vLLM manages cache internally)
        """

        assert inputs["context_hidden_state"] is None, "EARTTS vllm model does not support context_hidden_state input"

        codes = inputs["code"].squeeze(0)  # T x 31
        if codes.shape[0] > 1:
            # In prefill stage, we need to shift acoustic tokens for vllm,
            # replicating the NeMo logic from here:
            # https://github.com/erastorgueva-nv/NeMo/blob/duplex-realtime-inference/nemo/collections/speechlm2/modules/ear_tts_model.py#L1357
            codes = torch.nn.functional.pad(codes[:-1], [0, 0, 1, 0])
        input_tensors = [
            codes,
            inputs["subword_ids"].squeeze(0),
            inputs["subword_mask"].squeeze(0),
        ]

        if "non_prompt_mask" in inputs:
            # Apply edge detection to match native model's BOS placement logic:
            # BOS should only be applied at the FIRST position where non_prompt_mask is True
            non_prompt_mask = inputs["non_prompt_mask"].squeeze(0)  # T
            # Compute edge: positions where mask is True AND previous position is False
            padded_prev = torch.nn.functional.pad(non_prompt_mask[:-1], [1, 0], value=False)
            bos_mask = (non_prompt_mask & (~padded_prev)).to(dtype=getattr(torch, self._dtype))
            input_tensors.append(bos_mask)


        else:
            current_subword_id = input_tensors[1]
            # Use a tiny epsilon instead of exact 0 so the vLLM model's
            # (bos_mask == 0) check is False during decoding. This prevents
            # use_audio_prompt_frozen_projection from incorrectly applying the
            # speaker-prompt projection to every decoding step.
The epsilon is + # small enough that bos_mask * bos_emb remains negligible. + bos_mask = torch.full_like(current_subword_id, 1e-20, dtype=getattr(torch, self._dtype)) + input_tensors.append(bos_mask) + + result = await self.engine.generate_next_token(input_tensors, prompt_token_ids=prompt_token_ids, request_id=request_id) + acoustic_tokens = result.custom_outputs["acoustic_tokens"] # T x 31 + step_acoustic_tokens = acoustic_tokens[-1:] # 1 x 31 + return TTSGenerationResult( + codes=step_acoustic_tokens.unsqueeze(0).cuda(), # Add batch dim back: 1 x 1 x 31 + past_key_values=None # vLLM manages cache internally + ) + +class NativeModel(ModelInterface): + """ + Native PyTorch model interface. + + This wraps the existing DuplexS2SExternalSpeechDecoderModel to conform + to the ModelInterface contract. Supports top-k, top-p sampling and repetition penalty. + """ + + def __init__( + self, + model, + special_token_ids: Optional[Set[int]] = None, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + ): + """ + Initialize with an existing model. + + Args: + model: The DuplexS2SExternalSpeechDecoderModel instance + special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. + These tokens will use greedy decoding and won't be penalized. + If None, will try to extract from model.tokenizer for tokens: + '' (bos), '' (eos), '' (pad). + You can also manually provide: {tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id} + top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 + repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 + Recommended value when enabling: 1.2 + temperature: Temperature for sampling. 1.0 = no change, <1.0 = sharper, >1.0 = flatter. + 0.0 = greedy (argmax). 
Default: 1.0 + """ + # Default special token IDs: bos=1, eos=2, pad=12 + DEFAULT_SPECIAL_TOKEN_IDS = {1, 2, 12} + + # Try to extract special token IDs from model if not provided + if special_token_ids is None: + special_token_ids = self._extract_special_token_ids_from_nemo(model) + # Fallback to default if extraction failed + if not special_token_ids: + special_token_ids = DEFAULT_SPECIAL_TOKEN_IDS + # Initialize base class with sampling parameters + super().__init__( + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + + self.model = model + + logging.debug(f"Special token IDs: {self.special_token_ids}") + + # Validate: if sampling is enabled, special_token_ids should be set + sampling_active = top_p < 1.0 or repetition_penalty != 1.0 or (temperature != 1.0 and temperature != 0.0) + if sampling_active and not self.special_token_ids: + import warnings + warnings.warn( + "Sampling is enabled but special_token_ids is empty. " + "Could not auto-extract from model.tokenizer. " + "Please provide special_token_ids manually to ensure special tokens use greedy decoding. " + "Otherwise, EOS tokens may be randomly sampled and generation may not stop properly!" + ) + + def __call__( + self, + input_embeds: torch.Tensor, + cache: Optional[Any] = None, + generated_tokens: Optional[torch.Tensor] = None, + current_step: int = 0, + **kwargs + ) -> Dict[str, Any]: + """ + Perform inference using the native model. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + cache: Optional DynamicCache for transformers + generated_tokens: Previously generated tokens [batch, num_generated]. + Required for repetition_penalty. If None, creates empty tensor. + current_step: Current decoding step. Used for repetition penalty. 
+ **kwargs: Additional arguments passed to the model + + Returns: + Dictionary with 'predicted_token', 'asr_predicted_token', and 'cache' + """ + # Call the underlying model + result = self.model.stt_model(input_embeds, cache=cache, **kwargs) + + # Ensure consistent return format + if not isinstance(result, dict): + raise TypeError(f"Model returned {type(result)}, expected dict") + + if 'text_logits' not in result: + raise KeyError("Model output must contain 'text_logits' key") + + text_logits = result["text_logits"][:, -1] # [batch, vocab_size] + batch_size = text_logits.shape[0] + + # Use provided generated_tokens or create empty tensor + if generated_tokens is None: + gen_tokens = torch.empty(batch_size, 0, device=text_logits.device, dtype=torch.long) + else: + gen_tokens = generated_tokens + + # Apply sampling with top-p and repetition penalty + predicted_token = self._sample_text_token( + logits=text_logits, + generated_tokens=gen_tokens, + current_step=current_step, + ) + + # ASR tokens use greedy decoding (no sampling) + asr_predicted_token = result["asr_logits"][:, -1].argmax(dim=-1) + + ans = { + "predicted_token": predicted_token, + "asr_predicted_token": asr_predicted_token, + "cache": result.get("cache", None), + } + if "function_logits" in result: + ans["function_predicted_token"] = result["function_logits"][:, -1].argmax(dim=-1) + return ans + + @staticmethod + def _extract_special_token_ids_from_nemo(model) -> Set[int]: + """ + Extract special token IDs from NeMo model's tokenizer. + + NeMo tokenizer uses bos_token, eos_token, pad_token (not bos_token_id). + Then converts token strings to IDs using token_to_id method. 

        Args:
            model: The DuplexS2SExternalSpeechDecoderModel instance

        Returns:
            Set of special token IDs, or empty set if extraction fails
        """
        special_ids = set()
        try:
            tokenizer = model.stt_model.tokenizer

            # Get token strings (NeMo uses bos_token, not bos_token_id)
            bos_token = getattr(tokenizer, 'bos_token', None)
            eos_token = getattr(tokenizer, 'eos_token', None)
            pad_token = getattr(tokenizer, 'pad_token', None)

            # Convert token strings to IDs
            if hasattr(tokenizer, 'token_to_id'):
                for token in [bos_token, eos_token, pad_token]:
                    if token is not None:
                        tid = tokenizer.token_to_id(token)
                        if tid is not None and isinstance(tid, int):
                            special_ids.add(tid)
        except Exception as e:
            pass  # Return empty set on failure

        return special_ids

    def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'NativeModel':
        """Move underlying model to device or convert dtype."""
        self.model = self.model.to(device_or_dtype)
        return self

    def eval(self) -> 'NativeModel':
        """Set underlying model to eval mode."""
        self.model.eval()
        return self

    @property
    def device(self) -> torch.device:
        """Get device of the underlying model."""
        # Try to get device from model parameters
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            # No parameters, return CPU
            return torch.device('cpu')

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the underlying model.

        This allows transparent access to model attributes like
        perception, tokenizer, etc.
        """
        # Avoid infinite recursion for special attributes
        if name in ('model', '__dict__', '__class__'):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        # Delegate to wrapped model
        return getattr(self.model, name)


def create_model(
    model=None,
    engine_type: str = "native",
    vllm_config: Optional[Dict[str, Any]] = None,
    special_token_ids: Optional[Set[int]] = None,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    temperature: float = 1.0,
    **kwargs
) -> ModelInterface:
    """
    Factory function to create appropriate model interface.

    This is the main entry point for creating model interfaces.

    Args:
        model: The base model to wrap (required for "native" engine, optional for "vllm")
        engine_type: Type of engine ("native", "vllm")
        vllm_config: Configuration dict for vLLM engines (required for "vllm")
        special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling.
            If None (default), will auto-extract from model.tokenizer for tokens:
            '' (bos), '' (eos), '' (pad).
            You can manually provide: {tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id}
        top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0
        repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0
        temperature: Temperature for sampling. 1.0 = no change, 0.0 = greedy.
Default: 1.0 + **kwargs: Additional arguments passed to the interface constructor + + Returns: + ModelInterface instance + + Example: + >>> # Use native PyTorch model with greedy decoding (default) + >>> interface = create_model(model, engine_type="native") + >>> + >>> # Use native with top-p sampling (special_token_ids auto-extracted from model.tokenizer) + >>> # Auto-extracts IDs for: '', '', '' + >>> interface = create_model( + >>> model, + >>> engine_type="native", + >>> top_p=0.9 + >>> ) + >>> + >>> # Use native with top-p and repetition penalty (auto-extract special tokens) + >>> interface = create_model( + >>> model, + >>> engine_type="native", + >>> top_p=0.9, + >>> repetition_penalty=1.2 + >>> ) + >>> + >>> # Manually provide special_token_ids (if auto-extraction fails or you want custom tokens) + >>> special_ids = { + >>> tokenizer.pad_token_id, + >>> tokenizer.eos_token_id, + >>> tokenizer.bos_token_id + >>> } + >>> interface = create_model( + >>> model, + >>> engine_type="native", + >>> special_token_ids=special_ids, + >>> top_p=0.9, + >>> repetition_penalty=1.2 + >>> ) + >>> + >>> # Use vLLM with streaming engine + >>> vllm_cfg = { + >>> "model_path": "/path/to/vllm/checkpoint", + >>> "max_model_len": 10240, + >>> "gpu_memory_utilization": 0.8, + >>> "dtype": "bfloat16" + >>> } + >>> interface = create_model( + >>> engine_type="vllm", + >>> vllm_config=vllm_cfg + >>> ) + >>> + >>> # Perform inference + >>> result = interface(input_embeds, cache=cache) + >>> + >>> # For repetition penalty, pass generated_tokens and current_step + >>> result = interface(input_embeds, cache=cache, generated_tokens=prev_tokens, current_step=step) + """ + engine_type = engine_type.lower() + + if engine_type == "native": + if model is None: + raise ValueError("model must be provided for native engine") + return NativeModel( + model=model, + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + + elif 
engine_type == "vllm_eartts": + if vllm_config is None: + raise ValueError("vllm_config must be provided for vLLM EARTTS engine") + # VllmEARTTSModel for TTS inference + return VllmEARTTSModel( + **vllm_config, + model_type="eartts", + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + **kwargs + ) + + elif engine_type.startswith("vllm"): + if vllm_config is None: + raise ValueError("vllm_config must be provided for vLLM engine") + # VllmLLMModel doesn't need the PyTorch model, only the config + return VllmLLMModel( + **vllm_config, + model_type="llm", + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + **kwargs + ) + + else: + raise ValueError( + f"Unknown engine_type: {engine_type}. " + f"Supported types: 'native', 'vllm', 'vllm_llm', 'vllm_eartts'" + ) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py new file mode 100755 index 000000000000..6aabcaa41a15 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -0,0 +1,1947 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import yaml +from omegaconf import OmegaConf, DictConfig +import numpy as np +import librosa +import time +from transformers import DynamicCache +import re +import os +import sys +import argparse +import math +import torchaudio +import functools +from dataclasses import dataclass +from typing import Optional, Tuple +from nemo.utils import logging +from jiwer import wer + +import gc +import types + + +# Set environment variables (use existing env vars if set, otherwise use defaults) +_default_cache = "/tmp/cache" +os.environ.setdefault("HF_HOME", _default_cache) +os.environ.setdefault("TORCH_HOME", _default_cache) +os.environ.setdefault("NEMO_CACHE_DIR", _default_cache) +os.environ.setdefault("NEMO_NLP_TMP", os.path.join(_default_cache, "nemo_nlp_tmp")) + +from nemo.collections.speechlm2.models.nemotron_voicechat import NemotronVoiceChat + +from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.inference.model_wrappers.model_factory import create_model +from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import ( + PerceptionCacheState, + PerceptionCacheManager, +) +from nemo.collections.speechlm2.inference.utils.pipeline_utils import clean_pred_text + + +def tokens_to_str_raw(tokens: torch.Tensor, lengths: torch.Tensor, tokenizer, pad_id: int) -> list: + """ + Convert token IDs to text strings, preserving ALL special tokens including (pad token). + + Unlike tokens_to_str, this function uses ids_to_tokens which preserves special tokens, + and does NOT filter out any tokens (including pad tokens like ). 
+ + Args: + tokens: Token IDs tensor (B, T) + lengths: Length of each sequence (B,) + tokenizer: Tokenizer for decoding + pad_id: Pad token ID (not used for filtering in raw mode, kept for API compatibility) + + Returns: + List of decoded text strings with ALL special tokens preserved (including ) + """ + ans = [] + for hyp_ids, hyp_len in zip(tokens.cpu(), lengths.cpu()): + hyp_ids = hyp_ids[:hyp_len] + # Do NOT filter out any tokens - keep everything including pad tokens () + hyp_ids_list = hyp_ids.tolist() + + # Use ids_to_tokens which preserves special tokens like + toks = tokenizer.ids_to_tokens(hyp_ids_list) + + # Only replace 'Ġ' with space for proper word boundaries, keep all special tokens + toks = [tok.replace('Ġ', ' ') for tok in toks] + + ans.append("".join(toks)) + return ans + + + +# --- Configuration --- +DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# --- Streaming Parameters --- +SAMPLE_RATE = 16000 +FRAME_SIZE_SEC = 0.08 # 80ms per frame +FRAME_SIZE_SAMPLES = int(SAMPLE_RATE * FRAME_SIZE_SEC) # 1280 samples + +TTS_SAMPLE_RATE = 22050 + + +# Default hyper-parameters that can be overridden via `model_cfg` +DEFAULT_BUFFER_SIZE_FRAMES = 71 +DEFAULT_NUM_FRAMES_PER_CHUNK = 1 +# Only used when use_codec_cache=False (sliding-window fallback). +# Ignored when the codec streaming cache is enabled. +DEFAULT_CODEC_TOKEN_HISTORY_SIZE = 600 + + +class NemotronVoicechatInferenceWrapper: + """ + Inference wrapper for NemotronVoiceChat models. + Uses a sliding window buffer and processes audio frame by frame. + """ + + def __init__(self, model_cfg: DictConfig): + """ + Initialize the model for realtime streaming inference. + + Args: + model_cfg (DictConfig): Configuration describing the model paths and runtime parameters. 
+ """ + if model_cfg is None: + raise ValueError("model_cfg must be provided") + if not isinstance(model_cfg, DictConfig): + model_cfg = OmegaConf.create(model_cfg) + + + logging.info(f"pythonpath: {sys.path}") + + + logging.info(f"before setting - torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}") + logging.info(f"before setting - torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}") + logging.info(f"before setting - torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}") + + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision("medium") + + self._deterministic = bool(model_cfg.get("deterministic", False)) + if self._deterministic: + engine_type = model_cfg.get("engine_type", "native") + if "vllm" in engine_type.lower(): + raise ValueError( + "`deterministic` is not compatible with vLLM engines because vLLM uses custom " + "CUDA kernels (PagedAttention, FlashAttention) that do not support deterministic mode. " + f"Got engine_type='{engine_type}'. Use engine_type='native' for deterministic inference." + ) + + # Required by torch.use_deterministic_algorithms for cuBLAS reproducibility + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.use_deterministic_algorithms(True, warn_only=False) + + logging.info("Deterministic mode ENABLED") + logging.info(f" CUBLAS_WORKSPACE_CONFIG={os.environ.get('CUBLAS_WORKSPACE_CONFIG')}") + logging.info(f" flash_sdp enabled: {torch.backends.cuda.flash_sdp_enabled()}") + logging.info(f" mem_efficient_sdp enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}") + logging.info( + " NOTE: deterministic mode uses different CUDA kernels (e.g. 
math SDPA instead of " + "FlashAttention), so results may differ slightly from non-deterministic mode. " + "Inference will also be slower." + ) + + logging.info(f"torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}") + logging.info(f"torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}") + logging.info(f"torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}") + + self.model_cfg = model_cfg + + self.model_path = model_cfg.get("model_path") + if not self.model_path: + raise ValueError("`model_cfg.model_path` must be provided.") + + self.llm_checkpoint_path = model_cfg.get("llm_checkpoint_path") + if not self.llm_checkpoint_path: + raise ValueError("`model_cfg.llm_checkpoint_path` must be provided.") + + self.decode_audio = bool(model_cfg.get("decode_audio", True)) + # Number of past codec tokens kept in the sliding-window decode buffer. + # Only used when use_codec_cache=False (the fallback path). When the + # codec cache is enabled, context is maintained incrementally inside + # CausalConv1dCache and this value is ignored. 
+ self.codec_token_history_size = int( + model_cfg.get("codec_token_history_size", DEFAULT_CODEC_TOKEN_HISTORY_SIZE) + ) + + self.speaker_reference = model_cfg.get("speaker_reference") + if self.decode_audio and not self.speaker_reference: + raise ValueError("`model_cfg.speaker_reference` must be provided when decode_audio is enabled.") + + self.tts_system_prompt = model_cfg.get("tts_system_prompt", None) + logging.info(f"TTS system prompt: {self.tts_system_prompt}") + + compute_dtype = model_cfg.get("compute_dtype", "bfloat16") + self.dtype = self._resolve_dtype(compute_dtype) + + self.device = self._resolve_device( + device=model_cfg.get("device"), + device_id=model_cfg.get("device_id"), + ) + + logging.info("=" * 70) + logging.info("INITIALIZING REALTIME STREAMING INFERENCE") + logging.info("=" * 70) + logging.info(f"Frame size: {FRAME_SIZE_SEC}s ({FRAME_SIZE_SAMPLES} samples @ {SAMPLE_RATE}Hz)") + logging.info(f"Device: {self.device}") + logging.info(f"Compute dtype: {self.dtype}") + logging.info(f"Decode audio: {self.decode_audio}") + logging.info(f"Engine type: {model_cfg.get('engine_type', 'native')}") + logging.info(f"Sampling - top_p: {model_cfg.get('top_p', 1.0)}, repetition_penalty: {model_cfg.get('repetition_penalty', 1.0)}, temperature: {model_cfg.get('temperature', 1.0)}") + logging.info("=" * 70) + + # Cached TTS helpers populated during initialization/warmup + self.first_context_subword_id = None + self.generation_config = None + self.first_tts_code_input = None + self.first_tts_past_key_values_input = None + + + self.model = None + self.model_llm_interface = None + self.tokenizer = None + + # vLLM configuration + self.engine_type = model_cfg.get("engine_type", "native") + self.use_vllm_llm = "vllm_llm" in self.engine_type.lower() + self.use_vllm_eartts = "vllm_eartts" in self.engine_type.lower() + self.vllm_llm_config = model_cfg.get("vllm_llm_config", None) + self.vllm_tts_config = model_cfg.get("vllm_tts_config", None) + self.request_id = 
"streaming_request_0" # For vLLM streaming + + # Sampling parameters + self.top_p = float(model_cfg.get("top_p", 1.0)) + self.repetition_penalty = float(model_cfg.get("repetition_penalty", 1.0)) + self.temperature = float(model_cfg.get("temperature", 1.0)) + + # Codec streaming cache: decode only new tokens each step using the + # codec's CausalConv1dCache, which maintains ConvNeXt and ISTFT state + # across calls for sample-continuous audio. When enabled, the + # codec_token_history_size parameter and audio_toks_buffer are unused. + # When disabled, falls back to the sliding-window decode that re-decodes + # codec_token_history_size tokens each step and extracts the tail. + self.use_codec_cache = bool(model_cfg.get("use_codec_cache", True)) + if self.use_codec_cache and self.decode_audio: + configured_history = model_cfg.get("codec_token_history_size", None) + if configured_history is not None: + logging.info( + f"use_codec_cache is enabled — codec_token_history_size ({configured_history}) " + f"will be ignored (context is maintained incrementally by the codec cache)." + ) + + # Perception cache configuration + self.use_perception_cache = bool(model_cfg.get("use_perception_cache", False)) + use_perception_cudagraph = bool(model_cfg.get("use_perception_cudagraph", False)) + if use_perception_cudagraph and not self.use_perception_cache: + raise ValueError( + "use_perception_cudagraph requires use_perception_cache to be enabled. " + "Please also set use_perception_cache=True." 
            )
        self.perception_cache_mgr: Optional[PerceptionCacheManager] = None
        self._use_perception_cudagraph = use_perception_cudagraph

        self._initialize_model()

        logging.info("NemotronVoicechatInferenceWrapper initialized successfully.")

        logging.info(f"{self.model.stt_model.perception.encoder._cfg = }")
        logging.info(f"{self.model.stt_model.perception.encoder.streaming_cfg = }")

    @staticmethod
    def _resolve_dtype(compute_dtype):
        """Normalize ``compute_dtype`` into a ``torch.dtype``.

        Accepts an actual ``torch.dtype`` (returned unchanged), ``None``
        (defaults to ``torch.bfloat16``), or a case-insensitive string alias:
        "bfloat16"/"bf16", "float16"/"fp16"/"half", "float32"/"fp32"/"full".

        Raises:
            ValueError: for any other value (including unknown strings).
        """
        if isinstance(compute_dtype, torch.dtype):
            return compute_dtype
        if compute_dtype is None:
            return torch.bfloat16
        if isinstance(compute_dtype, str):
            key = compute_dtype.lower()
            mapping = {
                "bfloat16": torch.bfloat16,
                "bf16": torch.bfloat16,
                "float16": torch.float16,
                "fp16": torch.float16,
                "half": torch.float16,
                "float32": torch.float32,
                "fp32": torch.float32,
                "full": torch.float32,
            }
            if key in mapping:
                return mapping[key]
        raise ValueError(f"Unsupported compute_dtype: {compute_dtype}")

    @staticmethod
    def _resolve_device(device=None, device_id=None):
        """Resolve ``device``/``device_id`` into a ``torch.device``.

        A ``torch.device`` instance is returned as-is and ``device_id`` is
        ignored. ``None`` falls back to ``DEFAULT_DEVICE``. For string
        devices, ``device_id`` is appended only to a "cuda"-prefixed string
        that carries no explicit index yet (e.g. "cuda" + 1 -> "cuda:1";
        "cuda:0" and non-CUDA strings are left untouched).
        """
        if isinstance(device, torch.device):
            resolved_device = device
        else:
            if device is None:
                resolved_device = DEFAULT_DEVICE
            else:
                device_str = str(device)
                base = device_str
                if device_id is not None and device_str.startswith("cuda") and ":" not in device_str:
                    base = f"{device_str}:{device_id}"
                resolved_device = torch.device(base)
        return resolved_device

    def _samples_per_audio_output_frame(self):
        """Return the number of output audio samples per streaming frame.

        Sample-rate resolution order: ``self.target_sample_rate`` (set once
        the TTS model is initialized), then the config keys
        ``tts_sample_rate`` and ``output_sample_rate``, finally the
        ``TTS_SAMPLE_RATE`` constant. Frame duration is ``FRAME_SIZE_SEC``.
        """
        rate = getattr(self, "target_sample_rate", None)
        if rate is None:
            cfg_rate = None
            # NOTE(review): the broad excepts appear to guard against
            # model_cfg not supporting .get — consider narrowing to
            # AttributeError once confirmed.
            try:
                cfg_rate = self.model_cfg.get("tts_sample_rate", None)
            except Exception:
                cfg_rate = None
            if cfg_rate is None:
                try:
                    cfg_rate = self.model_cfg.get("output_sample_rate", None)
                except Exception:
                    cfg_rate = None
            if cfg_rate is not None:
                rate = float(cfg_rate)
        if rate is None:
            rate = TTS_SAMPLE_RATE
        samples = int(float(rate) * FRAME_SIZE_SEC)
        return samples

    def _load_and_merge_configs(self):
        """Load and merge
configurations from both nano and eartts checkpoints.""" + logging.info("Loading and merging configurations...") + + # Load nano's config (for LLM, perception) + nano_config_file = os.path.join(self.llm_checkpoint_path, "config.json") + logging.info(f" Loading nano config: {nano_config_file}") + with open(nano_config_file, 'r') as f: + import json + nano_cfg_dict = json.load(f) + nano_cfg = DictConfig(nano_cfg_dict) + + # Load eartts's config (for TTS) + eartts_config_file = os.path.join(self.model_path, "config.json") + logging.info(f" Loading eartts config: {eartts_config_file}") + with open(eartts_config_file, 'r') as f: + eartts_cfg_dict = json.load(f) + eartts_cfg = DictConfig(eartts_cfg_dict) + + # Start with nano's config as base + merged_cfg = nano_cfg + + # Override TTS-related parts with eartts's config + logging.info(" Merging: Using nano's config for LLM/perception, eartts's for TTS") + if 'model' in eartts_cfg and 'speech_generation' in eartts_cfg.model: + merged_cfg.model.speech_generation = eartts_cfg.model.speech_generation + logging.info(" TTS config from eartts") + + # Set speaker reference + if 'model' not in merged_cfg: + merged_cfg.model = {} + merged_cfg.model.inference_speaker_reference = self.speaker_reference + + # Ensure data section has correct sample rates + if 'data' not in merged_cfg: + merged_cfg.data = eartts_cfg.data + + logging.info(f" Final config:") + logging.info(f" - pretrained_llm: {merged_cfg.model.stt.model.pretrained_llm}") + logging.info(f" - perception.d_model: {merged_cfg.model.stt.model.perception.modality_adapter.d_model}") + logging.info(f" - speech_generation: {'present' if 'speech_generation' in merged_cfg.model else 'missing'}") + + return merged_cfg + + def _initialize_model(self): + """Initialize the NemotronVoiceChat with hybrid loading.""" + from safetensors.torch import load_file + from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init + + logging.info("Initializing model with 
hybrid loading strategy...") + + + # Step 1: Load and merge configs + cfg = self._load_and_merge_configs() + + # Step 2: DO NOT set pretrained_s2s_model - we'll load weights manually + cfg.model.stt.model.pretrained_s2s_model = None + cfg.model.speech_generation.model.pretrained_model = None + + # Convert to dict for model initialization + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + # Step 3: Initialize model structure + logging.info("Initializing model structure...") + start_DuplexS2S_init = time.time() + self.model = NemotronVoiceChat(cfg_dict) + logging.info(f"Time taken to initialize NemotronVoiceChat: {time.time() - start_DuplexS2S_init} seconds") + logging.info(" Model structure initialized") + + # Step 4: Load nano's checkpoint (LLM + perception) + if self.llm_checkpoint_path is not None: + logging.info("Loading LLM + perception:") + logging.info(f" Path: {self.llm_checkpoint_path}") + + nano_state_dict = load_file(os.path.join(self.llm_checkpoint_path, "model.safetensors")) + + # Filter to non-TTS weights + tts_keys = ['tts_model.', 'speech_generation.'] + + # If using vLLM for LLM, also exclude LLM weights to save memory + # vLLM will load its own copy of the LLM + if self.use_vllm_llm: + llm_keys = ['stt_model.llm.'] + exclude_keys = tts_keys + llm_keys + logging.info(f" Using vLLM - excluding LLM weights from nano checkpoint") + else: + exclude_keys = tts_keys + + nano_filtered = {k: v for k, v in nano_state_dict.items() + if not any(k.startswith(prefix) for prefix in exclude_keys)} + + logging.info(f" Loading {len(nano_filtered)} parameters (excluded: {exclude_keys})...") + + # Free the full state dict immediately to save CPU memory + del nano_state_dict + gc.collect() + + nano_filtered = set_model_dict_for_partial_init(nano_filtered, self.model.state_dict()) + missing, unexpected = self.model.load_state_dict(nano_filtered, strict=False) + + # Free filtered dict + del nano_filtered + gc.collect() + + missing_non_excluded = [k for k in 
missing if not any(k.startswith(prefix) for prefix in exclude_keys)] + unexpected_non_excluded = [k for k in unexpected if not any(k.startswith(prefix) for prefix in exclude_keys)] + + if missing_non_excluded: + logging.info(f" {len(missing_non_excluded)} keys missing (might be OK)") + if unexpected_non_excluded: + logging.info(f" {len(unexpected_non_excluded)} unexpected keys") + + # Step 5: Load eartts's checkpoint (TTS only) + if self.model_path is not None: + logging.info("Loading TTS checkpoint:") + logging.info(f" Path: {self.model_path}") + + eartts_state_dict = load_file(os.path.join(self.model_path, "model.safetensors")) + + # Filter to only TTS weights + tts_keys_filter = ['tts_model.'] + eartts_tts_only = {k: v for k, v in eartts_state_dict.items() + if any(k.startswith(prefix) for prefix in tts_keys_filter)} + + logging.info(f" Loading {len(eartts_tts_only)} TTS parameters...") + + start_tts_load_state_dict = time.time() + missing, unexpected = self.model.load_state_dict(eartts_tts_only, strict=False) + logging.info(f"Time taken to load TTS state dict: {time.time() - start_tts_load_state_dict} seconds") + + missing_tts = [k for k in missing if any(k.startswith(prefix) for prefix in tts_keys_filter)] + unexpected_tts = [k for k in unexpected if any(k.startswith(prefix) for prefix in tts_keys_filter)] + + if missing_tts: + logging.info(f" {len(missing_tts)} TTS keys missing") + for mk in missing_tts: + logging.info(f" missing: {mk}") + if unexpected_tts: + logging.info(f" {len(unexpected_tts)} unexpected TTS keys") + + if self.use_vllm_eartts: + # gonna convert and load vllm eartts engine + # Use object.__setattr__ to bypass PyTorch's module registration + # since VllmEARTTSModel is not a torch.nn.Module + del self.model.tts_model.tts_model + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + object.__setattr__( + self.model.tts_model, + 'tts_model', + create_model( + model=self.model_path, + engine_type="vllm_eartts", + 
vllm_config=self.vllm_tts_config) + ) + from nemo.collections.speechlm2.inference.vllm.vllm_patch import patched_infer_codes_one_step + self.model.tts_model.infer_codes_one_step = types.MethodType(patched_infer_codes_one_step, self.model.tts_model) + + logging.info(f" eartts checkpoint loaded (TTS only)") + + logging.info("\nHybrid loading completed!") + + # If using vLLM for LLM, delete native LLM BEFORE moving to device to save memory + if self.use_vllm_llm: + logging.info("\nDeleting native LLM before GPU transfer (will use vLLM instead)...") + if hasattr(self.model.stt_model, 'llm') and self.model.stt_model.llm is not None: + # Delete all submodules of LLM to free memory + for name, child in list(self.model.stt_model.llm.named_children()): + delattr(self.model.stt_model.llm, name) + del self.model.stt_model.llm + self.model.stt_model.llm = None + gc.collect() + torch.cuda.empty_cache() + logging.info(" Native LLM deleted") + + # Setup model + self.model.to(self.device) + self.model.eval() + + # Convert only the S2S components to the configured dtype, not the TTS model + logging.info(f"Converting S2S components to {self.dtype} (keeping TTS in float32)...") + if self.model.stt_model.llm is not None: + self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype) + self.model.stt_model.lm_head = self.model.stt_model.lm_head.to(self.dtype) + self.model.stt_model.embed_tokens = self.model.stt_model.embed_tokens.to(self.dtype) + self.model.stt_model.asr_head = self.model.stt_model.asr_head.to(self.dtype) + self.model.stt_model.embed_asr_tokens = self.model.stt_model.embed_asr_tokens.to(self.dtype) + if self.model.stt_model.function_head is not None: + self.model.stt_model.function_head = self.model.stt_model.function_head.to(self.dtype) + logging.info("function_head converted to %s", self.dtype) + #self.model.stt_model.perception = self.model.stt_model.perception.to(self.dtype) + logging.info("S2S components converted, TTS kept in float32") + logging.info("new 
update, perception also is kept in float32") + + # commenting this out to avoid error when try vllm tts + # and anyway - when sticking to "native", saw no difference in output + # with and without this call + #self.model.on_train_epoch_start() + self.tokenizer = self.model.stt_model.tokenizer + + + # allow overrides/additions from the self.model_cfg of nemotron_voicechat_inference_wrapper, + # into the model cfg that is read from config.json of the model. + # Specifically, this is so that we can specify inference_pad_boost, ... etc. + for key in ( + "inference_pad_boost", + "inference_bos_boost", + "inference_eos_boost", + "inference_user_pad_boost", + "inference_user_bos_boost", + "inference_user_eos_boost", + ): + val = self.model_cfg.get(key, None) + if val is not None: + OmegaConf.update(self.model.stt_model.cfg, key, val) + + # Print inference boost values + logging.info(f"inference_eos_boost: {self.model.stt_model.cfg.get('inference_eos_boost', None)}") + logging.info(f"inference_bos_boost: {self.model.stt_model.cfg.get('inference_bos_boost', None)}") + logging.info(f"inference_pad_boost: {self.model.stt_model.cfg.get('inference_pad_boost', None)}") + logging.info(f"inference_user_pad_boost: {self.model.stt_model.cfg.get('inference_user_pad_boost', None)}") + logging.info(f"inference_user_bos_boost: {self.model.stt_model.cfg.get('inference_user_bos_boost', None)}") + logging.info(f"inference_user_eos_boost: {self.model.stt_model.cfg.get('inference_user_eos_boost', None)}") + + # Wrap model with appropriate interface (Native or vLLM) + if self.use_vllm_llm: + logging.info("\nWrapping model with VllmLLMModel interface...") + if self.vllm_llm_config is None: + raise ValueError("vllm_llm_config must be provided when engine_type contains'vllm_llm'") + + # LLM already deleted above, just ensure cleanup + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Set logit boosts as env vars BEFORE creating the vLLM engine, + # so they are inherited by 
the forked worker process. The modified + # nemotron_h.py reads VLLM_ASR_BOOST_ and + # VLLM_TEXT_BOOST_ in __init__. + stt = self.model.stt_model + asr_boost_map = { + "inference_user_pad_boost": stt.text_pad_id, + "inference_user_bos_boost": stt.user_bos_id, + "inference_user_eos_boost": stt.text_eos_id, + } + for cfg_key, token_id in asr_boost_map.items(): + val = self.model_cfg.get(cfg_key, None) + if val is not None and float(val) != 0.0: + env_key = f"VLLM_ASR_BOOST_{token_id}" + os.environ[env_key] = str(float(val)) + logging.info(f"Set env {env_key}={val} (from {cfg_key})") + + text_boost_map = { + "inference_pad_boost": stt.text_pad_id, + "inference_bos_boost": stt.text_bos_id, + "inference_eos_boost": stt.text_eos_id, + } + for cfg_key, token_id in text_boost_map.items(): + val = self.model_cfg.get(cfg_key, None) + if val is not None and float(val) != 0.0: + env_key = f"VLLM_TEXT_BOOST_{token_id}" + os.environ[env_key] = str(float(val)) + logging.info(f"Set env {env_key}={val} (from {cfg_key})") + + self.model_llm_interface = create_model( + model=self.model_path, + engine_type="vllm_llm", + vllm_config=self.vllm_llm_config, + top_p=self.top_p, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + ) + + logging.info("VllmLLMModel interface created") + else: + logging.info("\nWrapping model with NativeModel interface...") + self.model_llm_interface = create_model( + model=self.model, + engine_type="native", + top_p=self.top_p, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + ) + logging.info("NativeModel interface created") + + # Get TTS info + if hasattr(self.model, 'tts_model'): + self.target_fps = self.model.tts_model.target_fps + self.target_sample_rate = self.model.tts_model.target_sample_rate + logging.info(f"\nTTS model initialized: target_fps={self.target_fps}, sample_rate={self.target_sample_rate}") + if self.decode_audio: + self._prepare_tts_initial_state() + else: + 
            logging.warning("Warning: TTS model not found in the model")

        # Setup perception cache if enabled
        if self.use_perception_cache:
            self.perception_cache_mgr = PerceptionCacheManager(
                model=self.model,
                device=self.device,
                dtype=self.dtype,
                use_cudagraph=self._use_perception_cudagraph,
            )
            # setup() returning False means the cache could not be built;
            # disable the feature and fall back to full-buffer perception
            # rather than failing hard.
            if not self.perception_cache_mgr.setup():
                self.use_perception_cache = False
                self.perception_cache_mgr = None

    def _get_bos_embedding(self):
        """Get beginning of sequence embedding.

        NOTE(review): despite the name, this embeds ``text_pad_id`` (not a
        dedicated BOS id) through ``embed_tokens`` — confirm this matches the
        model's training-time convention for the first text-channel position.

        Returns:
            Tensor of shape [1, H], cast to ``self.dtype``.
        """
        text_bos = torch.full((1,), fill_value=self.model.stt_model.text_pad_id, device=self.device)
        input_embeds = self.model.stt_model.embed_tokens(text_bos)
        return input_embeds.to(dtype=self.dtype)

    def _get_asr_bos_embedding(self) -> torch.Tensor:
        """Get ASR BOS embedding for AR decoding.

        Same convention as ``_get_bos_embedding`` (embeds ``text_pad_id``),
        but through the ASR-channel table ``embed_asr_tokens``.

        Returns:
            Tensor of shape [1, H], cast to ``self.dtype``.
        """
        text_bos = torch.full((1,), fill_value=self.model.stt_model.text_pad_id, device=self.device)
        input_embeds = self.model.stt_model.embed_asr_tokens(text_bos)
        return input_embeds.to(dtype=self.dtype)

    def _prepare_system_prompt_embeddings(
        self,
        system_prompt: str,
    ) -> Tuple[Optional[torch.Tensor], int]:
        """
        Prepare system prompt embeddings consistent with offline_inference.
+ + In offline_inference, prompt embeddings are structured as: + - Position 0: prompt_token_emb + bos_emb + asr_bos + - Position t > 0: prompt_token_emb + pad_emb + pad_asr + + Args: + system_prompt: The system prompt text + + Returns: + Tuple of (prompt_embedded [1, prompt_len, H], prompt_length) + Returns (None, 0) if system_prompt is empty + """ + + if not system_prompt or not system_prompt.strip(): + return None, 0 + + logging.info(f"Preparing system prompt: {system_prompt[:100]}...") + + # Step 1: Tokenize the prompt + # Format: [bos] + text_tokens + [eos] (consistent with collate_system_prompt) + prompt_token_ids = ( + [self.tokenizer.bos_id] + + self.tokenizer.text_to_ids(system_prompt) + + [self.tokenizer.eos_id] + ) + prompt_tokens = torch.tensor(prompt_token_ids, dtype=torch.long, device=self.device).unsqueeze(0) # [1, prompt_len] + prompt_len = prompt_tokens.shape[1] + + logging.info(f" Prompt length: {prompt_len} tokens") + + # Step 2: Embed the prompt tokens (this acts as the "audio channel" for prompt positions) + prompt_embedded = self.model.stt_model.embed_tokens(prompt_tokens) # [1, prompt_len, H] + prompt_embedded = prompt_embedded.to(dtype=self.dtype) + + # Step 3: Add pad embeddings for text and ASR channels (for positions t > 0) + # In offline_inference, prompt positions use gen_text[:, t-1] = pad_id + pad_id = self.model.stt_model.text_pad_id + pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) + pad_emb = self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) # [1, H] + pad_asr_emb = self.model.stt_model.embed_asr_tokens(pad_token).to(dtype=self.dtype) # [1, H] + + # For positions t > 0, add pad embeddings (simulating gen_text[:, t-1] = pad_id) + has_fc = self.model.stt_model.function_head is not None + if prompt_len > 1: + prompt_embedded[:, 1:, :] += pad_emb + prompt_embedded[:, 1:, :] += pad_asr_emb + if has_fc: + prompt_embedded[:, 1:, :] += pad_emb # FC channel also uses pad at t > 0 + + # 
Step 4: For position 0, add BOS embeddings
        bos_emb = self._get_bos_embedding()  # [1, H]
        asr_bos_emb = self._get_asr_bos_embedding()  # [1, H]
        prompt_embedded[:, 0, :] += bos_emb.squeeze(0)
        prompt_embedded[:, 0, :] += asr_bos_emb.squeeze(0)
        if has_fc:
            prompt_embedded[:, 0, :] += pad_emb.squeeze(0)  # FC channel uses pad at t=0

        logging.info(f"  System prompt embeddings prepared: shape {prompt_embedded.shape}")

        return prompt_embedded, prompt_len

    def _clone_cache(self, cache):
        """Deep clone cache structures to ensure complete isolation between streams.

        Recurses through tensors (detached + cloned), lists/tuples (type
        preserved), and dicts. Any other object carrying a ``__dict__``
        (e.g. a HuggingFace ``DynamicCache``) is ``copy.deepcopy``-ed; other
        leaves are returned as-is.

        Args:
            cache: Arbitrarily nested cache structure, or None.

        Returns:
            A fully independent copy of ``cache`` (None stays None).
        """
        if cache is None:
            return None
        if isinstance(cache, torch.Tensor):
            return cache.detach().clone()
        if isinstance(cache, (list, tuple)):
            return type(cache)(self._clone_cache(x) for x in cache)
        if isinstance(cache, dict):
            return {k: self._clone_cache(v) for k, v in cache.items()}
        # Handle complex objects (e.g., DynamicCache with __dict__ attributes)
        # Use deepcopy to ensure complete isolation between streams
        if hasattr(cache, '__dict__'):
            import copy
            return copy.deepcopy(cache)
        return cache

    def _prepare_tts_initial_state(self):
        """Run the TTS speaker-reference/system-prompt prefill once and cache its outputs.

        Loads the speaker reference audio, resamples it to the TTS target
        sample rate, and performs one prefill pass through the TTS model so
        that streaming inference can start from cached state
        (``generation_config`` plus, further below, the first context
        subword id / code / past_key_values). No-op when ``decode_audio`` is
        disabled or the model has no ``tts_model`` attribute.
        """
        if not self.decode_audio:
            return
        if not hasattr(self.model, 'tts_model'):
            return

        logging.info("Preparing TTS warmup state...")

        # Audio load + resample run under the fp32 precision context,
        # independent of the configured compute dtype.
        with fp32_precision():
            speaker_audio, speaker_sr = torchaudio.load(self.speaker_reference)
            speaker_audio = resample(speaker_audio, speaker_sr, self.model.tts_model.target_sample_rate)

        speaker_audio = speaker_audio.to(self.device)
        speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long()

        # init tts_model
        self.model.tts_model.set_init_inputs(
            speaker_audio=speaker_audio,
            speaker_audio_lens=speaker_audio_lens,
            system_prompt=self.tts_system_prompt,
        )
        init_inputs = self.model.tts_model.get_init_inputs(B=1)

        self.generation_config = self.model.tts_model._get_generation_config(guidance_enabled=True)
        init_inputs.update({"use_cache": True,
"past_key_values": None, "guidance_enabled": True}) + + with torch.no_grad(): + if self.use_vllm_eartts: + self.tts_prompt_token_ids = init_inputs["subword_ids"].squeeze().cpu().numpy().tolist() + self.tts_init_inputs = init_inputs + outputs = self.model.tts_model.tts_model( + self.tts_init_inputs, + request_id="tts_system_prompt_prefill_request", + prompt_token_ids=self.tts_prompt_token_ids + ) + # abort this request + self.model.tts_model.tts_model.abort_request("tts_system_prompt_prefill_request") + else: + outputs = self.model.tts_model.tts_model(**init_inputs) + + code = init_inputs["code"][:, -1:] + # code, _, _ = self.model.tts_model.tts_model.generate_step( + # outputs.hidden_states[:, -1:], **self.generation_config + # ) + + self.first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) + self.first_tts_code_input = code.detach().clone() + self.first_tts_past_key_values_input = self._clone_cache(outputs.past_key_values) + + + logging.info("TTS warmup state prepared") + + def _update_audio_buffer(self, audio_buffer, buffer_fill_level, new_audio, buffer_size_samples): + """ + Append incoming samples to the sliding-window buffer and produce the view used for inference. + + Parameters: + audio_buffer (torch.Tensor): Tensor of shape `[1, buffer_size_samples]` holding the latest audio samples. + buffer_fill_level (int): Number of valid samples currently stored in `audio_buffer`. + new_audio (torch.Tensor): Incoming samples of shape `[1, slice_n_samples]` for the current step. + buffer_size_samples (int): Total capacity of the buffer in samples. + + Returns: + Tuple[torch.Tensor, int, torch.Tensor]: + - Updated `audio_buffer` containing the newest samples (always capped to `buffer_size_samples`). + - Updated `buffer_fill_level`, reflecting how many contiguous samples are valid. + - `current_buffer`, a view over the valid portion of the buffer used for the model input. 
+ + Notes: + `audio_buffer` always retains the last `buffer_size_samples` samples even when overfilled, + whereas `current_buffer` may be shorter during the initial warm-up phase when the buffer + is not yet full. + """ + if new_audio.shape[1] == 0: + current_buffer = audio_buffer[:, :buffer_fill_level] + return audio_buffer, buffer_fill_level, current_buffer + + remaining = new_audio + + if buffer_fill_level < buffer_size_samples and remaining.shape[1] > 0: + warmup_take = min(buffer_size_samples - buffer_fill_level, remaining.shape[1]) + if warmup_take > 0: + audio_buffer[:, buffer_fill_level:buffer_fill_level + warmup_take] = remaining[:, :warmup_take] + buffer_fill_level += warmup_take + remaining = remaining[:, warmup_take:] + + if remaining.shape[1] > 0: + if remaining.shape[1] >= buffer_size_samples: + audio_buffer = remaining[:, -buffer_size_samples:] + else: + audio_buffer = torch.cat([ + audio_buffer[:, remaining.shape[1]:], + remaining + ], dim=1) + buffer_fill_level = buffer_size_samples + current_buffer = audio_buffer if buffer_fill_level == buffer_size_samples else audio_buffer[:, :buffer_fill_level] + return audio_buffer, buffer_fill_level, current_buffer + + def infer_one_step(self, + audio_input, + num_frames_per_chunk, + frame_idx, + gen_text, + audio_toks_buffer, + input_embeds_history, + dynamic_cache, + past_key_values=None, + code=None, + subword_mask=None, + gen_asr_text=None, + gen_function_text=None, + request_id: Optional[str] = None, + perception_cache: Optional[PerceptionCacheState] = None, + has_prompt: bool = False, + codec_cache=None): + + # Set up effective request ID for vLLM streaming + effective_request_id = request_id or self.request_id + + start_time_one_step = time.time() + use_cache = dynamic_cache is not None + batch_size = gen_text.shape[0] + + predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) + asr_predicted_tokens = torch.empty((batch_size, 
num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) + function_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) + + # Do "perception" step outside the for-loop + start_perception = time.time() + + if self.use_perception_cache and perception_cache is not None and perception_cache.is_initialized(): + # Cache-aware perception + source_encoded, perception_cache = self.perception_cache_mgr.step( + audio_input=audio_input, + frame_idx=frame_idx, + num_frames_per_chunk=num_frames_per_chunk, + perception_cache=perception_cache, + ) + else: + # Standard perception (full buffer processing) + buffer_len = torch.tensor([audio_input.shape[1]], dtype=torch.long, device=self.device) + source_encoded, _, _ = self.model.stt_model.perception( + input_signal=audio_input, + input_signal_length=buffer_len, + return_encoder_emb=True, + ) + + torch.cuda.synchronize() + time_perception = time.time() - start_perception + logging.info(f"Time taken for perception: {time_perception:.3f}s") + source_encoded = source_encoded.to(self.dtype) + total_encoded_frames = source_encoded.shape[1] + + # Determine embedding position based on whether we're using cache + if self.use_perception_cache and perception_cache is not None and perception_cache.is_initialized(): + # With cache: we get exactly num_frames_per_chunk output frames + # Use all of them directly + embedding_position = 0 + newest_frame_index = total_encoded_frames - 1 + base_frame_index = 0 + else: + # Without cache: Use the second-to-last encoded frame (-2) as the "newest" frame embedding. + # This is because the model's expects the chunk sizes to be size 10ms, 80ms, 80ms, 80ms, ...., + # but we pass in always 80ms, 80ms, 80ms.... + # e.g. + # (1) if we pass in just one 80ms chunk -> the model treats it as 10ms, then 70ms with 10ms silence padding at the end. 
+ # (2) if we pass 80ms, 80ms -> the model treats it as 10ms, 80ms, 70ms with 10ms silence padding at the end. + # => we do not want to use the final embedding due to containing silence padding. We want to use the second-to-last embedding. + embedding_position = -2 + newest_frame_index = total_encoded_frames + embedding_position + base_frame_index = newest_frame_index - (num_frames_per_chunk - 1) + base_frame_index = max(base_frame_index, 0) + + new_input_embeds = [] + new_codes_for_decode = [] + for frame_offset in range(num_frames_per_chunk): + current_frame_idx = frame_idx + frame_offset + current_frame_index = base_frame_index + frame_offset + current_frame_index = min(current_frame_index, total_encoded_frames - 1) + current_frame_embedding = source_encoded[:, current_frame_index:current_frame_index + 1, :] + + current_input_emb = current_frame_embedding.clone() + + has_fc = gen_function_text is not None + + if current_frame_idx == 0 and not has_prompt: + # Only add BOS if there's no prompt (BOS is already in prompt's position 0) + current_input_emb += self._get_bos_embedding() + current_input_emb += self._get_asr_bos_embedding() + if has_fc: + pad_id = self.model.stt_model.text_pad_id + fc_pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) + current_input_emb += self.model.stt_model.embed_tokens(fc_pad_token).to(dtype=self.dtype) + elif current_frame_idx == 0 and has_prompt: + # With prompt: first audio frame uses pad embedding (like offline_inference) + # gen_text[:, -1] from prompt positions is pad_id + pad_id = self.model.stt_model.text_pad_id + pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) + pad_emb = self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) + pad_asr_emb = self.model.stt_model.embed_asr_tokens(pad_token).to(dtype=self.dtype) + current_input_emb += pad_emb + current_input_emb += pad_asr_emb + if has_fc: + current_input_emb += 
self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) + else: + # t > 0: add embeddings from model's own predictions at t-1 + last_token_emb = self.model.stt_model.embed_tokens(gen_text[:, current_frame_idx - 1]) + current_input_emb += last_token_emb + last_asr_token_emb = self.model.stt_model.embed_asr_tokens(gen_asr_text[:, current_frame_idx - 1]) + current_input_emb += last_asr_token_emb + if has_fc: + last_fc_token_emb = self.model.stt_model.embed_tokens(gen_function_text[:, current_frame_idx - 1]) + current_input_emb += last_fc_token_emb.to(dtype=self.dtype) + + start_stt_model = time.time() + + if use_cache or self.use_vllm_llm: + if self.use_vllm_llm: + # vLLM requires request_id + ans = self.model_llm_interface( + current_input_emb, + request_id=effective_request_id, + generated_tokens=gen_text, + current_step=current_frame_idx + ) + else: + ans = self.model_llm_interface( + current_input_emb, + cache=dynamic_cache, + generated_tokens=gen_text, + current_step=current_frame_idx + ) + dynamic_cache = ans["cache"] + else: + new_input_embeds.append(current_input_emb) + full_input_embeds = torch.cat(input_embeds_history + new_input_embeds, dim=1) + ans = self.model_llm_interface( + full_input_embeds, + cache=None, + generated_tokens=gen_text, + current_step=current_frame_idx + ) + + torch.cuda.synchronize() + time_stt_model = time.time() - start_stt_model + logging.info(f"Time taken for stt_model: {time_stt_model:.3f}s") + + predicted_token = ans["predicted_token"] + asr_predicted_token = ans["asr_predicted_token"] + + gen_text[:, current_frame_idx] = predicted_token + predicted_tokens[:, frame_offset] = predicted_token + + gen_asr_text[:, current_frame_idx] = asr_predicted_token + asr_predicted_tokens[:, frame_offset] = asr_predicted_token + + if "function_predicted_token" in ans: + function_predicted_tokens[:, frame_offset] = ans["function_predicted_token"] + if gen_function_text is not None: + gen_function_text[:, current_frame_idx] = 
ans["function_predicted_token"] + + # Apply forced turn taking based on ASR results + self._maybe_apply_forced_turn_taking(current_frame_idx, gen_text, gen_asr_text) + # Update predicted_tokens with any changes made by forced turn taking + predicted_tokens[:, frame_offset] = gen_text[:, current_frame_idx] + + if self.decode_audio: + current_subword_id = gen_text[:, current_frame_idx].unsqueeze(-1) + + # do one step inference on Duplex TTS model + if current_frame_idx == 0: + if self.first_context_subword_id is None: + raise RuntimeError("first_context_subword_id is not initialized. Ensure TTS warmup ran successfully.") + prev_subword_id = self.first_context_subword_id + else: + prev_subword_id = gen_text[:, current_frame_idx-1].unsqueeze(-1) + + # create subword_mask + current_subword_mask = subword_mask[:, current_frame_idx].unsqueeze(-1) + + if self.generation_config is None: + raise RuntimeError("generation_config is not initialized. Ensure TTS warmup ran successfully.") + + start_tts_model = time.time() + inputs = { + "current_subword_id": current_subword_id, + "prev_subword_id": prev_subword_id, + "current_subword_mask": current_subword_mask, + "prev_audio_tokens": code, + "past_key_values": past_key_values, + "guidance_enabled": True, + "generation_config": self.generation_config, + "ignore_eos_flag_stop": True, + } + if self.use_vllm_eartts: + inputs["request_id"] = effective_request_id + + code, past_key_values = self.model.tts_model.infer_codes_one_step( + **inputs + ) + + torch.cuda.synchronize() + time_tts_model = time.time() - start_tts_model + logging.info(f"Time taken for tts_model: {time_tts_model:.3f}s") + + new_codes_for_decode.append(code.clone()) + # Update sliding-window buffer (only needed for fallback decode when codec_cache is off) + if audio_toks_buffer is not None: + audio_toks_buffer = torch.cat([audio_toks_buffer[:, 1:], code], dim=1) + + # now that we've saved audio_toks_buffer for audio decoding purposes, + # we can potentially 
overwrite the audio token with silence tokens (for feeding to the audio token predictor) + if self.model.cfg.get('inference_force_speech_silence_on_eos', None): + silence_codes = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand(code.shape) + code = torch.where( + current_subword_id.unsqueeze(-1) == self.model.tts_model.text_eos_id, + silence_codes, + code, + ) + + # exit for-loop & do audio decoding non-autoregressively (if decode_audio is True) + if self.decode_audio: + samples_per_audio_output_frame = self._samples_per_audio_output_frame() + logging.debug(f"\nDecoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") + + start_time_decode = time.time() + with fp32_precision(), torch.no_grad(): + if codec_cache is not None and new_codes_for_decode: + # Incremental decode: feed only the num_frames_per_chunk new tokens + # to the codec. CausalConv1dCache maintains all necessary ConvNeXt + # and ISTFT overlap state from prior calls, so no history buffer + # is needed — this replaces the sliding-window approach entirely. 
+ new_codes_tensor = torch.cat(new_codes_for_decode, dim=1) + if hasattr(self.model.tts_model, '_control_codes'): + from nemo.collections.speechlm2.models.duplex_ear_tts import replace_control_speech_codes + new_codes_tensor = replace_control_speech_codes( + new_codes_tensor, + self.model.tts_model._control_codes, + getattr(self.model.tts_model, 'codec_silence_tokens', None), + ) + new_code_len = torch.tensor( + [new_codes_tensor.shape[1]], dtype=torch.long, device=self.device + ) + decoded_audio_new, _ = self.model.tts_model.audio_codec.decode( + new_codes_tensor, new_code_len, cache=codec_cache, + ) + logging.debug(f" Incremental decode: {new_codes_tensor.shape[1]} new tokens -> {decoded_audio_new.shape}") + else: + # Fallback: full-buffer sliding-window decode (original behavior) + len_audio_toks_buffer = torch.tensor( + [self.codec_token_history_size], dtype=torch.long, device=self.device + ) + decoded_audio, decoded_audio_len = self.model.tts_model.audio_codec.decode( + audio_toks_buffer, len_audio_toks_buffer, + ) + decoded_audio_new = decoded_audio[:, :, -samples_per_audio_output_frame * num_frames_per_chunk:] + logging.debug(f" Sliding-window decode: extracted {decoded_audio_new.shape} from {decoded_audio.shape}") + + torch.cuda.synchronize() + time_audio_codec = time.time() - start_time_decode + logging.info(f"Time taken for audio_codec: {time_audio_codec:.3f}s") + + else: + audio_toks_buffer = None + decoded_audio_new = None + time_tts_model = 0 + time_audio_codec = 0 + + # Convert new text tokens to string via tokens_to_text (convert_tokens_to_string) + # so byte-level BPE is decoded properly (e.g. "é" → "é") and leading spaces + # from Ġ-prefixed tokens are preserved for correct concatenation of incremental + # chunks: " Musée" + " National" → " Musée National". + # NOTE: multi-byte UTF-8 characters whose BPE tokens span two frames will show + # as replacement chars (�) because each frame is decoded independently. 
A proper + # fix would require an incremental UTF-8 decoder that buffers incomplete trailing + # bytes across frames. + predicted_text_strs = [] + for predicted_tok_ids_b in predicted_tokens: + predicted_tok_ids_b = predicted_tok_ids_b.tolist() + predicted_toks_b = self.tokenizer.ids_to_tokens(predicted_tok_ids_b) + predicted_toks_b = [tok for tok in predicted_toks_b if tok != ''] + predicted_text_strs.append(self.tokenizer.tokens_to_text(predicted_toks_b)) + + # convert new ASR tokens to string + asr_predicted_text_strs = [] + for asr_predicted_tok_ids_b in asr_predicted_tokens: + asr_predicted_tok_ids_b = asr_predicted_tok_ids_b.tolist() + asr_predicted_toks_b = self.tokenizer.ids_to_tokens(asr_predicted_tok_ids_b) + asr_predicted_toks_b = [tok for tok in asr_predicted_toks_b if tok != ''] + asr_predicted_text_strs.append(self.tokenizer.tokens_to_text(asr_predicted_toks_b)) + + logging.info(f'frame {frame_idx}: USER\'s asr_predicted_text_strs: {asr_predicted_text_strs}') + logging.info(f'frame {frame_idx}: --------------------------------AGENT\'s predicted_text_strs: {predicted_text_strs}') + + torch.cuda.synchronize() + + time_for_one_step = time.time() - start_time_one_step + logging.info(f'frame {frame_idx}: Time taken for one step: {time_for_one_step:.3f}s') + + result = { + 'predicted_text_tokens': predicted_tokens, + 'asr_predicted_text_tokens': asr_predicted_tokens, + 'audio_toks_buffer': audio_toks_buffer, + 'decoded_audio_new': decoded_audio_new, + 'predicted_text_strs': predicted_text_strs, + 'asr_predicted_text_strs': asr_predicted_text_strs, + 'input_embeds_history': input_embeds_history + new_input_embeds if not use_cache else input_embeds_history, + 'dynamic_cache': dynamic_cache if use_cache else None, + 'past_key_values': past_key_values, + 'code': code, + 'perception_cache': perception_cache, + 'codec_cache': codec_cache, + } + if self.model.stt_model.function_head is not None: + result['function_predicted_text_tokens'] = 
function_predicted_tokens + return result + + def abort_request(self, request_id: Optional[str]) -> bool: + """ + Abort an in-flight vLLM streaming request if the backend supports it. + """ + if not request_id: + return False + + success = False + + # Abort LLM if applicable + if self.use_vllm_llm: + abort_fn = getattr(self.model_llm_interface, "abort_request", None) + if callable(abort_fn): + try: + if abort_fn(request_id): + success = True + logging.info(f"Aborted LLM request {request_id} successfully.") + except Exception as exc: + logging.warning(f"Failed to abort LLM request {request_id}: {exc}") + + # Abort EarTTS if applicable + if self.use_vllm_eartts: + abort_fn = getattr(self.model.tts_model.tts_model, "abort_request", None) + if callable(abort_fn): + try: + if abort_fn(request_id): + success = True + logging.info(f"Aborted EarTTS request {request_id} successfully.") + except Exception as exc: + logging.warning(f"Failed to abort EarTTS request {request_id}: {exc}") + + return success + + + def _maybe_apply_forced_turn_taking(self, t, gen_text, gen_asr): + """Apply forced turn-taking rules based on ASR channel tokens.""" + if not self.model_cfg.get("force_turn_taking", False): + return + + threshold = self.model_cfg.get("force_turn_taking_threshold", 40) + pad_window_steps = self.model_cfg.get("force_turn_taking_pad_window", 25) + + B = gen_text.size(0) + + for batch_idx in range(B): + lookback_start = max(0, t - threshold) + agent_text_window = gen_text[batch_idx, lookback_start:t] + current_asr_token = gen_asr[batch_idx, t] + + # ASR EOS or ~1 sec of pad tokens → insert agent BOS if not present in window + # Skip if we don't have enough tokens at the beginning + if t < pad_window_steps: + continue + + pad_lookback_start = t - pad_window_steps + asr_recent_tokens = gen_asr[batch_idx, pad_lookback_start:t] + has_pad_window = (asr_recent_tokens == self.model.stt_model.text_pad_id).all() if len(asr_recent_tokens) > 0 else False + + # Require that the pad 
window starts after a non-pad token + if has_pad_window and pad_lookback_start > 0: + token_before_window = gen_asr[batch_idx, pad_lookback_start - 1] + has_pad_window = (token_before_window != self.model.stt_model.text_pad_id) and (token_before_window != self.model.stt_model.user_bos_id) + elif has_pad_window and pad_lookback_start == 0: + # If the pad window starts at position 0, it doesn't meet the requirement + has_pad_window = False + + if has_pad_window: + if not (agent_text_window == self.model.stt_model.text_bos_id).any(): + gen_text[batch_idx, t] = self.model.stt_model.text_bos_id + logging.info(f"Forced turn-taking at frame {t}: inserted agent BOS (reason: pad window)") + + # ASR BOS → insert agent EOS if not present in window + elif current_asr_token == self.model.stt_model.user_bos_id: + if not (agent_text_window == self.model.stt_model.text_eos_id).any(): + gen_text[batch_idx, t] = self.model.stt_model.text_eos_id + logging.info(f"Forced turn-taking at frame {t}: inserted agent EOS (reason: user started speaking)") + + @torch.no_grad() + def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: int = None, request_id: Optional[str] = None, pad_audio_to_sec: Optional[float] = None, pad_silence_ratio: Optional[float] = None, pad_audio_by_sec: Optional[float] = None, system_prompt: Optional[str] = None): + """ + Perform realtime streaming inference simulating microphone capture. + + Args: + audio_path: Path to input audio file (simulates microphone input) + num_frames_per_chunk: Number of frames to process per inference step (default: 1) + request_id: Optional request ID for vLLM streaming + pad_audio_to_sec: Optional duration to pad audio to (in seconds) + pad_silence_ratio: Optional ratio of original duration to append as silence (e.g. 
0.2 = 20%) + pad_audio_by_sec: Optional fixed number of extra seconds of silence to append + system_prompt: Optional system prompt to provide context to the model + + Returns: + Dictionary with 'text', 'tokens_text', 'tokens_audio', 'audio', 'audio_len', 'system_prompt' + """ + # Use provided value or default + if num_frames_per_chunk is None: + num_frames_per_chunk = DEFAULT_NUM_FRAMES_PER_CHUNK + if num_frames_per_chunk < 1: + raise ValueError("num_frames_per_chunk must be at least 1") + start_time = time.time() + + logging.info("\n" + "=" * 70) + logging.info("STARTING REALTIME STREAMING INFERENCE") + logging.info("=" * 70) + + # Set up request ID for vLLM streaming + stream_request_id = request_id or self.request_id + + buffer_size_frames = int(self.model_cfg.get("buffer_size_frames", DEFAULT_BUFFER_SIZE_FRAMES)) + buffer_size_samples = buffer_size_frames * FRAME_SIZE_SAMPLES + if num_frames_per_chunk > buffer_size_frames: + raise ValueError( + f"num_frames_per_chunk ({num_frames_per_chunk}) must be " + f"less than or equal to buffer_size_frames ({buffer_size_frames})." + ) + + att_context_size = self.model.stt_model.perception.encoder._cfg.att_context_size + if self.use_perception_cache: + min_buffer = num_frames_per_chunk * (att_context_size[1] + 1) + 2 + reason = ( + f"must be >= num_frames_per_chunk * (att_context_size[1] + 1) + 2 = " + f"{num_frames_per_chunk} * ({att_context_size[1]} + 1) + 2 = {min_buffer} " + f"when using perception cache (+2 to minimize windowing artifacts)" + ) + else: + min_buffer = att_context_size[0] + att_context_size[1] + 1 + reason = ( + f"must be >= att_context_size[0] + att_context_size[1] + 1 = " + f"{att_context_size[0]} + {att_context_size[1]} + 1 = {min_buffer} " + f"without perception cache" + ) + if buffer_size_frames < min_buffer: + raise ValueError( + f"buffer_size_frames ({buffer_size_frames}) is too small: {reason}." 
+ ) + if self.decode_audio and not self.use_codec_cache and num_frames_per_chunk > self.codec_token_history_size: + raise ValueError( + f"num_frames_per_chunk ({num_frames_per_chunk}) must be " + f"<= codec_token_history_size ({self.codec_token_history_size}) when decode_audio=True " + f"and use_codec_cache=False. " + f"Either reduce num_frames_per_chunk, increase codec_token_history_size, or enable use_codec_cache." + ) + logging.info(f"Buffer size: {buffer_size_frames} frames ({buffer_size_frames * FRAME_SIZE_SEC}s)") + logging.info(f"Frames per inference step: {num_frames_per_chunk}") + + # Load audio file (simulating microphone stream) + logging.info(f"Loading audio file: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=SAMPLE_RATE) + total_samples = len(audio_signal) + total_duration = total_samples / SAMPLE_RATE + + logging.info(f" Total duration: {total_duration:.2f}s") + logging.info(f" Total samples: {total_samples}") + + # Optionally pad audio (at most one of these is set; enforced by caller) + if pad_audio_to_sec is not None and pad_audio_to_sec > total_duration: + target_samples = int(pad_audio_to_sec * SAMPLE_RATE) + audio_signal = np.pad(audio_signal, (0, target_samples - total_samples), mode='constant') + total_samples = len(audio_signal) + logging.info(f" Padded to {pad_audio_to_sec:.2f}s ({total_samples} samples)") + elif pad_silence_ratio is not None: + extra_samples = int(total_duration * pad_silence_ratio * SAMPLE_RATE) + audio_signal = np.pad(audio_signal, (0, extra_samples), mode='constant') + total_samples = len(audio_signal) + logging.info(f" Padded with {pad_silence_ratio*100:.1f}% extra silence ({extra_samples} samples)") + elif pad_audio_by_sec is not None: + extra_samples = int(pad_audio_by_sec * SAMPLE_RATE) + audio_signal = np.pad(audio_signal, (0, extra_samples), mode='constant') + total_samples = len(audio_signal) + logging.info(f" Padded with {pad_audio_by_sec:.2f}s extra silence ({extra_samples} samples)") + + # 
derive num_inference_steps + total_frames_maybe = int(np.ceil(total_samples / FRAME_SIZE_SAMPLES)) # "maybe" because we might need to add padding + num_inference_steps = (total_frames_maybe // num_frames_per_chunk) + if total_frames_maybe % num_frames_per_chunk != 0: + num_inference_steps += 1 + total_frames = num_inference_steps * num_frames_per_chunk + + # pad audio signal so that it is divisible by num_inference_steps + padded_total_samples = num_inference_steps * num_frames_per_chunk * FRAME_SIZE_SAMPLES + if padded_total_samples > total_samples: + audio_signal = np.pad(audio_signal, (0, padded_total_samples - total_samples), mode='constant') + logging.info(f" Padded to: {padded_total_samples} samples") + logging.info(f" {num_frames_per_chunk=} => {total_frames=}, {num_inference_steps=}") + + # convert audio signal to tensor + audio_signal_tensor = torch.tensor(audio_signal, dtype=self.dtype, device=self.device).unsqueeze(0) + + # Check if Nemotron (no cache support) + use_cache = 'Nemotron' not in self.model.stt_model.cfg.pretrained_llm + logging.info(f"Model: {self.model.stt_model.cfg.pretrained_llm}") + logging.info(f" Use cache: {use_cache}") + + # Initialize buffer and state + audio_buffer = torch.zeros(1, buffer_size_samples, dtype=self.dtype, device=self.device) + buffer_fill_level = 0 # How many samples currently in buffer + + # Initialize LLM cache + if use_cache: + llm_cache = DynamicCache() + else: + llm_cache = None + input_embeds_history = [] # For no-cache mode + + # Process system prompt if provided (before streaming audio) + prompt_embedded = None + prompt_len = 0 + + if system_prompt: + start_get_prompt_embeddings = time.time() + prompt_embedded, prompt_len = self._prepare_system_prompt_embeddings(system_prompt) + logging.info(f"Time taken to get prompt embeddings: {time.time() - start_get_prompt_embeddings:.3f}s") + if prompt_embedded is not None and "vllm" in self.engine_type.lower(): + # Prepare token IDs for the prompt + prompt_token_ids = 
( + [self.tokenizer.bos_id] + + self.tokenizer.text_to_ids(system_prompt) + + [self.tokenizer.eos_id] + ) + + # For vLLM mode: use efficient BATCH prefill (~20x faster than sequential) + logging.info(f" Batch prefilling {prompt_len} prompt embeddings...") + start_batch_prefill = time.time() + with torch.no_grad(): + success = self.model_llm_interface( + prompt_embedded, + request_id=stream_request_id, + decode_steps=0, + prompt_token_ids=prompt_token_ids, + ) + logging.info(f"Time taken to batch prefill stt model: {time.time() - start_batch_prefill:.3f}s") + if success: + logging.info(f" System prompt prefilled ({prompt_len} tokens)") + else: + raise RuntimeError("vLLM batch prefill for system prompt failed.") + elif prompt_embedded is not None and not use_cache: + # For no-cache mode (Nemotron): add prompt embeddings to history + # Split into individual frames for consistent processing + for t in range(prompt_len): + input_embeds_history.append(prompt_embedded[:, t:t+1, :]) + logging.info(f" Added {prompt_len} prompt embeddings to input_embeds_history") + elif prompt_embedded is not None and use_cache: + # For cache mode: process prompt through LLM to update cache + with torch.no_grad(): + ans = self.model.stt_model(prompt_embedded, cache=llm_cache) + llm_cache = ans.get("cache", llm_cache) + logging.info(f" System prompt processed, cache updated") + + # Initialize TTS + code = None + past_key_values = None + subword_mask = None + audio_toks_buffer = None + if self.decode_audio and hasattr(self.model, 'tts_model'): + + # Sliding-window buffer is only needed when codec_cache is off + if not self.use_codec_cache: + audio_toks_buffer = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand( + -1, self.codec_token_history_size, -1 + ).to(self.device) + + if ( + self.first_context_subword_id is None + or self.generation_config is None + or self.first_tts_code_input is None + or self.first_tts_past_key_values_input is None + ) and not self.use_vllm_eartts: + 
raise RuntimeError("TTS warmup state was not prepared during initialization.") + + if not self.use_vllm_eartts: + past_key_values = self._clone_cache(self.first_tts_past_key_values_input) + code = self.first_tts_code_input.detach().clone() + else: + start_batch_prefill = time.time() + logging.info(f" Batch prefilling TTS model with speaker embedding...") + # use speaker embedding to prefill EarTTS's vLLM + tts_result = self.model.tts_model.tts_model( + self.tts_init_inputs, + request_id=stream_request_id, + prompt_token_ids=self.tts_prompt_token_ids + ) + code = self.first_tts_code_input.detach().clone() + past_key_values = None + logging.info(f"Time taken to batch prefill tts model: {time.time() - start_batch_prefill:.3f}s") + # Initialize subword_mask for vLLM path as well + subword_mask = torch.ones(1, total_frames, device=self.device, dtype=torch.bool) + logging.info(f"TTS initialized") + + # Initialize perception cache if enabled + perception_cache = None + if self.use_perception_cache: + perception_cache = self.perception_cache_mgr.get_initial_state(batch_size=1) + logging.info(f"Perception cache initialized") + + # Initialize codec streaming cache to remove clicking sounds and wasted inference computation + codec_cache = None + if self.decode_audio and self.use_codec_cache: + from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache + codec_cache = CausalConv1dCache() + logging.info(f"Codec streaming cache initialized") + + gen_text = torch.full((1, total_frames), self.model.stt_model.text_pad_id, device=self.device, dtype=torch.long) + gen_asr_text = torch.full((1, total_frames), self.model.stt_model.text_pad_id, device=self.device, dtype=torch.long) + has_function_head = self.model.stt_model.function_head is not None + if has_function_head: + gen_function_text = torch.full((1, total_frames), self.model.stt_model.text_pad_id, device=self.device, dtype=torch.long) + + # initialize list to which we will append generated audio segments 
+ audio_segments = [] + + logging.info("\n" + "=" * 70) + logging.info("STARTING FRAME-BY-FRAME PROCESSING") + logging.info("=" * 70) + + # frame_idx corresponds to index of the first frame passed to infer_one_step + # (we need this distinction in the case that num_frames_per_chunk > 1) + frame_idx = 0 + while frame_idx < total_frames: + slice_start = frame_idx * FRAME_SIZE_SAMPLES + slice_n_samples = num_frames_per_chunk * FRAME_SIZE_SAMPLES + slice_end = slice_start + slice_n_samples + new_audio = audio_signal_tensor[:, slice_start:slice_end] + + audio_buffer, buffer_fill_level, current_buffer = self._update_audio_buffer( + audio_buffer, buffer_fill_level, new_audio, buffer_size_samples + ) + + result = self.infer_one_step( + audio_input=current_buffer, + num_frames_per_chunk=num_frames_per_chunk, + frame_idx=frame_idx, + gen_text=gen_text, + audio_toks_buffer=audio_toks_buffer if self.decode_audio else None, + input_embeds_history=input_embeds_history if not use_cache else [], + dynamic_cache=llm_cache if use_cache else None, + past_key_values=past_key_values if self.decode_audio else None, + code=code if self.decode_audio else None, + subword_mask=subword_mask if self.decode_audio else None, + gen_asr_text=gen_asr_text, + gen_function_text=gen_function_text if has_function_head else None, + request_id=stream_request_id, + perception_cache=perception_cache, + has_prompt=(prompt_len > 0), + codec_cache=codec_cache, + ) + + # handle results from infer_one_step + if has_function_head and 'function_predicted_text_tokens' in result: + for fi in range(num_frames_per_chunk): + gen_function_text[:, frame_idx + fi] = result['function_predicted_text_tokens'][:, fi] + input_embeds_history = result['input_embeds_history'] + llm_cache = result['dynamic_cache'] + if self.use_perception_cache: + perception_cache = result.get('perception_cache', perception_cache) + if self.decode_audio: + audio_toks_buffer = result['audio_toks_buffer'] + decoded_audio_new = 
result['decoded_audio_new'] + if decoded_audio_new is not None: + audio_segments.append(decoded_audio_new) + + past_key_values = result['past_key_values'] + code = result['code'] + codec_cache = result.get('codec_cache', codec_cache) + else: + decoded_audio_new = None + + if frame_idx % 10 == 0 or frame_idx < 3 or gen_text[:, frame_idx].item() == self.model.stt_model.text_eos_id: + token_str = self.tokenizer.ids_to_text([gen_text[0, frame_idx].item()]) + buffer_status = f"{buffer_fill_level}/{buffer_size_samples}" if buffer_fill_level < buffer_size_samples else "FULL" + special_label = "" + if gen_text[0, frame_idx].item() == self.model.stt_model.text_bos_id: + special_label = " [BOS]" + elif gen_text[0, frame_idx].item() == self.model.stt_model.text_eos_id: + special_label = " [EOS]" + elif gen_text[0, frame_idx].item() == self.model.stt_model.text_pad_id: + special_label = " [PAD]" + logging.info(f"Frame {frame_idx:3d}/{total_frames} | Buffer: {buffer_status:20s} | Token: {gen_text[0, frame_idx].item():5d}{special_label} | '{token_str}'") + + frame_idx += num_frames_per_chunk + + # Prepare results + elapsed_time = time.time() - start_time + logging.info("\n" + "=" * 70) + logging.info("STREAMING INFERENCE COMPLETED") + logging.info("=" * 70) + logging.info(f"Total time: {elapsed_time:.2f}s") + logging.info(f"Audio duration: {total_duration:.2f}s") + logging.info(f"RTF (Real-Time Factor): {elapsed_time / total_duration:.2f}x") + logging.info(f"Processed frames: {total_frames}") + + # Trim to actual length + # TODO: this is currently redundant since we iterate over all frames in the while loop + gen_text = gen_text[:, :total_frames] + gen_asr_text = gen_asr_text[:, :total_frames] + + # Decode text + lengths = torch.tensor([total_frames], dtype=torch.long, device=self.device) + text_output = tokens_to_str(gen_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id, eval_text_turn_taking=True) + + # Decode ASR text + asr_text_output = 
tokens_to_str(gen_asr_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id, eval_text_turn_taking=True)
+
+ # Also create raw versions, kept for comparison
+ text_output_raw = tokens_to_str_raw(gen_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id)
+ asr_text_output_raw = tokens_to_str_raw(gen_asr_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id)
+
+ logging.info(f"Generated text: {text_output[0]}")
+ logging.info(f"Generated ASR text: {asr_text_output[0]}")
+
+ # Decode function calling channel
+ if has_function_head:
+ gen_function_text = gen_function_text[:, :total_frames]
+ function_text_output = tokens_to_str(gen_function_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id, eval_text_turn_taking=False)
+ function_text_output_raw = tokens_to_str_raw(gen_function_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id)
+ logging.info(f"Generated function text: {function_text_output[0]}")
+
+ ans = {
+ "text": text_output,
+ "text_raw": text_output_raw,
+ "tokens_text": gen_text,
+ "tokens_len": lengths,
+ "audio": torch.cat(audio_segments, dim=-1) if audio_segments else None,
+ "asr_text": asr_text_output,
+ "asr_text_raw": asr_text_output_raw,
+ "asr_tokens": gen_asr_text,
+ "system_prompt": system_prompt if system_prompt else "",
+ }
+ if has_function_head:
+ ans["function_text"] = function_text_output
+ ans["function_text_raw"] = function_text_output_raw
+ ans["function_tokens"] = gen_function_text
+
+ if self.use_vllm_llm or self.use_vllm_eartts:
+ self.abort_request(stream_request_id)
+
+ return ans
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Realtime Streaming Inference")
+ parser.add_argument("--model_path", type=str, required=True,
+ help="Path to eartts's checkpoint with TTS (HF format)")
+ parser.add_argument("--llm_checkpoint_path", type=str, required=True,
+ help="Path to checkpoint 
with LLM/perception (HF format)") + parser.add_argument("--audio_path", type=str, default=None, + help="Path to input audio file (for single-file mode)") + parser.add_argument("--input_json", type=str, default=None, + help="Path to input JSON file containing list of records with audio_filepath and text fields (for batch mode)") + parser.add_argument("--output_json", type=str, default=None, + help="Path to output JSON file with predictions") + parser.add_argument("--output_dir", type=str, default="output_streaming", + help="Output directory for audio files and JSON results") + parser.add_argument("--pad_audio_to_sec", type=float, default=None, + help="Pad audio to this duration in seconds (useful for consistent buffer behavior)") + parser.add_argument("--pad_silence_ratio", type=float, default=None, + help="Append silence equal to this ratio of the original audio duration (e.g. 0.2 = 20%% extra)") + parser.add_argument("--pad_audio_by_sec", type=float, default=None, + help="Append this many seconds of extra silence after the audio") + parser.add_argument("--speaker_reference", type=str, required=True, + help="Path to speaker reference audio file") + parser.add_argument("--buffer_size_frames", type=int, default=DEFAULT_BUFFER_SIZE_FRAMES, + help=f"Size of audio buffer in frames (each frame = 80ms, default: {DEFAULT_BUFFER_SIZE_FRAMES})") + parser.add_argument("--num_frames_per_chunk", type=int, default=DEFAULT_NUM_FRAMES_PER_CHUNK, + help="Number of frames per inference step (default: 1)") + parser.add_argument("--decode_audio", action="store_true", + help="Whether to decode audio") + parser.add_argument("--combine_inp_out_audio", action="store_true", + help="Whether to combine input and output audio into a stereo file") + + # Deterministic inference + parser.add_argument("--deterministic", action="store_true", + help="Enable fully deterministic inference (disables FlashAttention, forces deterministic " + "CUDA algorithms). Useful for reproducible benchmarking. 
Not compatible with vLLM engines. "
+ "Note: results may differ slightly from non-deterministic mode due to different compute path.")
+
+ # Perception cache argument
+ parser.add_argument("--use_perception_cache", action="store_true",
+ help="Enable cache-aware streaming for perception encoder")
+ parser.add_argument("--use_perception_cudagraph", action="store_true",
+ help="Use CUDA graphs for perception encoder (requires --use_perception_cache)")
+ # Codec streaming cache argument
+ parser.add_argument("--use_codec_cache", action="store_true",
+ help="Enable incremental codec decode to remove clicking sounds and wasted inference computation (recommended)")
+
+ # vLLM arguments
+ parser.add_argument("--engine_type", type=str, default="native", choices=["native", "vllm_llm", "vllm_eartts", "vllm_llm_vllm_eartts"],
+ help="Engine type for inference (default: native)")
+ parser.add_argument("--vllm_llm_engine_path", type=str, default=None,
+ help="Path to vLLM-compatible model checkpoint; if the path does not exist, it will be auto-converted")
+ parser.add_argument("--vllm_max_model_len", type=int, default=768,
+ help="Maximum sequence length for vLLM (default: 768)")
+ parser.add_argument("--vllm_gpu_memory_utilization", type=float, nargs='+', default=[0.4],
+ help="GPU memory utilization for vLLM. Single value shared by both engines; two values are assigned to LLM and TTS respectively.")
+ parser.add_argument("--vllm_llm_dtype", type=str, default="bfloat16",
+ help="Data type for vLLM (default: bfloat16)")
+
+ # vLLM EarTTS arguments
+ parser.add_argument("--vllm_eartts_engine_path", type=str, default=None,
+ help="Path to vLLM-compatible EarTTS model checkpoint; if the path does not exist, it will be auto-converted")
+ parser.add_argument("--vllm_eartts_dtype", type=str, default="float32",
+ help="Data type for vLLM (default: float32)")
+
+ # Sampling parameters
+ parser.add_argument("--top_p", type=float, default=1.0,
+ help="Top-p (nucleus) sampling threshold. 
1.0 disables it (greedy). Default: 1.0") + parser.add_argument("--repetition_penalty", type=float, default=1.0, + help="Repetition penalty for generated tokens. 1.0 disables it. Default: 1.0. Recommended: 1.2") + parser.add_argument("--temperature", type=float, default=1.0, + help="Temperature for sampling. 1.0 = no change, <1.0 = sharper, >1.0 = flatter, 0.0 = greedy. Default: 1.0") + + # Turn-taking + parser.add_argument("--force_turn_taking", action="store_true", + help="Enable forced turn-taking based on ASR channel tokens") + parser.add_argument("--force_turn_taking_threshold", type=int, default=40, + help="Number of lookback steps for turn-taking detection (default: 40)") + parser.add_argument("--force_turn_taking_pad_window", type=int, default=25, + help="Number of consecutive ASR pad tokens to trigger turn-taking (default: 25)") + + # Inference logit boosts + parser.add_argument("--inference_pad_boost", type=float, default=None, + help="Boost for agent pad logit at inference time") + parser.add_argument("--inference_bos_boost", type=float, default=None, + help="Boost for agent BOS logit at inference time") + parser.add_argument("--inference_eos_boost", type=float, default=None, + help="Boost for agent EOS logit at inference time") + parser.add_argument("--inference_user_pad_boost", type=float, default=None, + help="Boost for ASR pad logit at inference time") + parser.add_argument("--inference_user_bos_boost", type=float, default=None, + help="Boost for ASR BOS logit at inference time") + parser.add_argument("--inference_user_eos_boost", type=float, default=None, + help="Boost for ASR EOS logit at inference time") + + # System prompt + parser.add_argument("--system_prompt", type=str, default=None, + help="System prompt to provide context to the model. 
Can also be specified per-record in input JSON.") + parser.add_argument("--tts_system_prompt", type=str, default=None, + help="System prompt for EARTTS model.") + args = parser.parse_args() + + # Validate arguments: either audio_path OR input_json must be provided + if args.audio_path is None and args.input_json is None: + parser.error("Either --audio_path (single-file mode) or --input_json (batch mode) must be provided") + if args.audio_path is not None and args.input_json is not None: + parser.error("Cannot use both --audio_path and --input_json at the same time") + + if sum(x is not None for x in [args.pad_audio_to_sec, args.pad_silence_ratio, args.pad_audio_by_sec]) > 1: + raise ValueError("Set at most one of: --pad_audio_to_sec, --pad_silence_ratio, --pad_audio_by_sec") + if not math.isfinite(args.temperature) or args.temperature < 0.0: + parser.error(f"--temperature must be a finite value >= 0.0, got {args.temperature}") + + try: + import json + import soundfile as sf + + model_cfg_dict = { + "model_path": args.model_path, + "llm_checkpoint_path": args.llm_checkpoint_path, + "speaker_reference": args.speaker_reference, + "buffer_size_frames": args.buffer_size_frames, + "decode_audio": bool(args.decode_audio), + "engine_type": args.engine_type, + "deterministic": bool(args.deterministic), + "use_perception_cache": bool(args.use_perception_cache), + "use_perception_cudagraph": bool(args.use_perception_cudagraph), + "use_codec_cache": bool(args.use_codec_cache), + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "temperature": args.temperature, + "tts_system_prompt": args.tts_system_prompt, + "force_turn_taking": args.force_turn_taking, + "force_turn_taking_threshold": args.force_turn_taking_threshold, + "force_turn_taking_pad_window": args.force_turn_taking_pad_window, + "inference_pad_boost": args.inference_pad_boost, + "inference_bos_boost": args.inference_bos_boost, + "inference_eos_boost": args.inference_eos_boost, + 
"inference_user_pad_boost": args.inference_user_pad_boost, + "inference_user_bos_boost": args.inference_user_bos_boost, + "inference_user_eos_boost": args.inference_user_eos_boost, + } + + # Pop GPU memory utilization values: first for LLM, second (or same) for TTS + _gpu_mem = list(args.vllm_gpu_memory_utilization) + gpu_mem_llm = _gpu_mem.pop(0) + gpu_mem_tts = _gpu_mem.pop(0) if _gpu_mem else gpu_mem_llm + + # Add vLLM configuration if using vLLM engine + if "vllm_llm" in args.engine_type: + model_cfg_dict["vllm_llm_config"] = { + "model_path": args.model_path, + "max_model_len": args.vllm_max_model_len, + "gpu_memory_utilization": gpu_mem_llm, + "dtype": args.vllm_llm_dtype, + "engine_path": args.vllm_llm_engine_path, # Will auto-convert if needed + "pretrained_llm": args.llm_checkpoint_path, + } + + if "vllm_eartts" in args.engine_type: + model_cfg_dict["vllm_tts_config"] = { + "model_path": args.model_path, # we use exactly the same whole duplexs2s ckpt + "max_model_len": args.vllm_max_model_len, + "gpu_memory_utilization": gpu_mem_tts, + "dtype": args.vllm_eartts_dtype, + "engine_path": args.vllm_eartts_engine_path, + "pretrained_llm": None, + "skip_tokenizer_init": True + } + + model_cfg = OmegaConf.create(model_cfg_dict) + + model = NemotronVoicechatInferenceWrapper(model_cfg=model_cfg) + + # ========================================= + # Load input records (from JSON manifest or single audio file) + # ========================================= + if args.input_json is not None: + logging.info(f"Loading input JSON: {args.input_json}") + with open(args.input_json, 'r') as f: + input_records = [json.loads(line) for line in f] + else: + input_records = [{"audio_filepath": args.audio_path, "text": ""}] + + logging.info(f"Found {len(input_records)} records to process") + + os.makedirs(args.output_dir, exist_ok=True) + + if args.output_json: + base_path = args.output_json.rsplit('.', 1)[0] if '.' 
in args.output_json else args.output_json + output_json_processed = f"{base_path}_processed.json" + output_json_raw = f"{base_path}_raw.json" + else: + output_json_processed = os.path.join(args.output_dir, "output_results_processed.json") + output_json_raw = os.path.join(args.output_dir, "output_results_raw.json") + + logging.info(f"Output will be saved incrementally to:") + logging.info(f" Processed: {output_json_processed}") + logging.info(f" Raw: {output_json_raw}") + output_file_processed = open(output_json_processed, 'w', encoding='utf-8') + output_file_raw = open(output_json_raw, 'w', encoding='utf-8') + + output_records = [] + wer_scores = [] + + try: + for idx, record in enumerate(input_records): + logging.info("\n" + "=" * 70) + logging.info(f"Processing record {idx + 1}/{len(input_records)}") + logging.info("=" * 70) + + audio_path = record.get('audio_filepath') + ground_truth_text = record.get('text', '') + record_system_prompt = record.get('system_prompt', args.system_prompt) + + if not audio_path: + logging.warning(f"Record {idx} missing audio_filepath, skipping...") + continue + + if not os.path.exists(audio_path): + logging.warning(f"Audio file not found: {audio_path}, skipping...") + continue + + logging.info(f" Audio: {audio_path}") + logging.info(f" Ground truth: {ground_truth_text}") + + audio_id = os.path.splitext(os.path.basename(audio_path))[0] + + results = model.inference_realtime_streaming( + audio_path, + num_frames_per_chunk=args.num_frames_per_chunk, + pad_audio_to_sec=args.pad_audio_to_sec, + pad_silence_ratio=args.pad_silence_ratio, + pad_audio_by_sec=args.pad_audio_by_sec, + request_id=f"streaming_request_{idx}", + system_prompt=record_system_prompt, + ) + + pred_asr_text = results['asr_text'][0] if 'asr_text' in results else '' + pred_asr_text_raw = results['asr_text_raw'][0] if 'asr_text_raw' in results else '' + pred_text = results['text'][0] if 'text' in results else '' + pred_text_raw = results['text_raw'][0] if 'text_raw' in 
results else '' + + try: + cleaned_pred = clean_pred_text(pred_asr_text) + cleaned_gt = clean_pred_text(ground_truth_text) + if cleaned_gt.strip() and cleaned_pred.strip(): + utterance_wer = wer(cleaned_gt, cleaned_pred) + wer_scores.append(utterance_wer) + else: + utterance_wer = None + except Exception as e: + utterance_wer = None + logging.warning(f"Error calculating WER: {e}") + + if utterance_wer is not None: + logging.info(f"WER for utterance {idx + 1}: {utterance_wer:.4f} ({utterance_wer * 100:.2f}%)") + + pred_audio_path = None + if args.decode_audio and 'audio' in results and results['audio'] is not None: + input_basename = os.path.splitext(os.path.basename(audio_path))[0] + audio_filename = f"{idx:04d}_{input_basename}_output.wav" + pred_audio_path = os.path.join(args.output_dir, audio_filename) + + audio_np = results['audio'].float().cpu().numpy().flatten() + + sf.write(pred_audio_path, audio_np, model.target_sample_rate) + logging.info(f"Audio saved: {pred_audio_path}") + + if args.combine_inp_out_audio: + stereo_filename = f"{idx:04d}_{input_basename}_combined.wav" + stereo_path_out = os.path.join(args.output_dir, stereo_filename) + + inp_audio, sr = librosa.load(audio_path, sr=model.target_sample_rate) + + delay_samples = int(args.num_frames_per_chunk * FRAME_SIZE_SEC * model.target_sample_rate) + out_audio_delayed = np.concatenate([np.zeros(delay_samples, dtype=audio_np.dtype), audio_np]) + + max_len = max(len(inp_audio), len(out_audio_delayed)) + inp_audio_padded = np.pad(inp_audio, (0, max_len - len(inp_audio))) + out_audio_padded = np.pad(out_audio_delayed, (0, max_len - len(out_audio_delayed))) + + stereo_audio = np.stack([inp_audio_padded, out_audio_padded], axis=1) + sf.write(stereo_path_out, stereo_audio, model.target_sample_rate) + logging.info(f"Stereo audio saved: {stereo_path_out}") + + result_system_prompt = results.get('system_prompt', '') + + output_record_processed = { + 'id': audio_id, + 'target_text': '', + 'pred_audio': 
pred_audio_path, + 'src_text': ground_truth_text, + 'pred_src_text': pred_asr_text, + 'pred_text': pred_text, + 'system_prompt': result_system_prompt, + } + + output_record_raw = { + 'id': audio_id, + 'target_text': '', + 'pred_audio': pred_audio_path, + 'src_text': ground_truth_text, + 'pred_src_text': pred_asr_text_raw, + 'pred_text': pred_text_raw, + 'system_prompt': result_system_prompt, + } + + output_records.append(output_record_processed) + + json.dump(output_record_processed, output_file_processed, ensure_ascii=False) + output_file_processed.write('\n') + output_file_processed.flush() + + json.dump(output_record_raw, output_file_raw, ensure_ascii=False) + output_file_raw.write('\n') + output_file_raw.flush() + + logging.info(f"Record {idx + 1} completed and saved") + + finally: + output_file_processed.close() + output_file_raw.close() + + logging.info("\n" + "=" * 70) + logging.info("ALL RESULTS SAVED") + logging.info("=" * 70) + logging.info(f"Results saved to:") + logging.info(f" Processed: {output_json_processed}") + logging.info(f" Raw: {output_json_raw}") + logging.info(f" Processed {len(output_records)}/{len(input_records)} records successfully") + + if wer_scores: + avg_wer = np.mean(wer_scores) + logging.info("\n" + "=" * 70) + logging.info("WER STATISTICS") + logging.info("=" * 70) + logging.info(f" Total utterances with WER: {len(wer_scores)}") + logging.info(f" Average WER: {avg_wer:.4f} ({avg_wer * 100:.2f}%)") + logging.info(f" Min WER: {np.min(wer_scores):.4f} ({np.min(wer_scores) * 100:.2f}%)") + logging.info(f" Max WER: {np.max(wer_scores):.4f} ({np.max(wer_scores) * 100:.2f}%)") + + logging.info("=" * 70) + logging.info("ALL DONE!") + logging.info("=" * 70) + + except Exception as e: + logging.error(f"ERROR during inference: {e}") + import traceback + traceback.print_exc() + return 1 + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py 
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Cache-aware perception encoder for streaming S2S inference.

Provides incremental mel-spectrogram encoding with optional CUDA graph
acceleration, so that only new audio needs to be processed each step
instead of re-encoding the entire buffer.
"""

import copy
import time
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from omegaconf import OmegaConf

from nemo.utils import logging


@dataclass
class PerceptionCacheState:
    """Cache state for streaming perception inference.

    Holds the cache tensors for the ASR encoder used in the perception module.
    This enables cache-aware streaming inference without needing the full audio buffer.
    """

    # Encoder caches as returned by ``encoder.get_initial_cache_state`` /
    # ``encoder.cache_aware_stream_step``; ``None`` until first initialized.
    cache_last_channel: Optional[torch.Tensor] = None
    cache_last_time: Optional[torch.Tensor] = None
    cache_last_channel_len: Optional[torch.Tensor] = None

    def is_initialized(self) -> bool:
        """Check if the cache has been initialized."""
        return self.cache_last_channel is not None


@dataclass
class PerceptionCUDAGraphState:
    """State for CUDA graph-accelerated perception encoder.

    Holds separate graphs for first chunk (different size) and subsequent chunks.
    Also holds static buffers for inputs/outputs to enable graph replay.
    """

    # CUDA graphs
    graph_first: Optional[torch.cuda.CUDAGraph] = None
    graph_subsequent: Optional[torch.cuda.CUDAGraph] = None

    # Static input buffers (for copying data before graph replay)
    static_mel_first: Optional[torch.Tensor] = None
    static_mel_subsequent: Optional[torch.Tensor] = None
    static_mel_len_first: Optional[torch.Tensor] = None
    static_mel_len_subsequent: Optional[torch.Tensor] = None

    # Static cache input buffers (shared by both graphs)
    static_cache_channel_in: Optional[torch.Tensor] = None
    static_cache_time_in: Optional[torch.Tensor] = None
    static_cache_channel_len_in: Optional[torch.Tensor] = None

    # Static output buffers (results are written here during replay)
    static_encoded_first: Optional[torch.Tensor] = None
    static_encoded_subsequent: Optional[torch.Tensor] = None
    static_encoded_len_first: Optional[torch.Tensor] = None
    static_encoded_len_subsequent: Optional[torch.Tensor] = None

    # Static cache output buffers - SEPARATE for first and subsequent graphs
    # (each graph writes to its own output tensors during replay)
    static_cache_channel_out_first: Optional[torch.Tensor] = None
    static_cache_time_out_first: Optional[torch.Tensor] = None
    static_cache_channel_len_out_first: Optional[torch.Tensor] = None
    static_cache_channel_out_subsequent: Optional[torch.Tensor] = None
    static_cache_time_out_subsequent: Optional[torch.Tensor] = None
    static_cache_channel_len_out_subsequent: Optional[torch.Tensor] = None

    def is_captured(self) -> bool:
        """Check if graphs have been captured."""
        return self.graph_first is not None and self.graph_subsequent is not None


class PerceptionCacheManager:
    """Manages cache-aware streaming perception encoding with optional CUDA graphs.

    This class encapsulates all perception cache setup, CUDA graph capture,
    and the incremental encoding step. It is created by the inference wrapper
    when ``use_perception_cache=True``.
    """

    def __init__(self, model, device: torch.device, dtype: torch.dtype, use_cudagraph: bool = False):
        """
        Args:
            model: Inference wrapper model exposing ``stt_model.perception``
                (encoder + modality_adapter + proj). NOTE(review): assumed
                shape of the model object — confirm against the wrapper.
            device: Device on which preprocessing / encoding is performed.
            dtype: Target dtype (stored for callers; not used directly here).
            use_cudagraph: If True, ``setup()`` captures CUDA graphs for the
                two streaming chunk sizes and ``step()`` replays them.
        """
        self.model = model
        self.device = device
        self.dtype = dtype
        self.use_cudagraph = use_cudagraph

        # Populated by setup(); None until then.
        self.streaming_cfg = None
        self.preprocessor = None
        self.subsampling_factor = None
        self.input_features = None
        self.sampling_frames = None
        self.cudagraph_state: Optional[PerceptionCUDAGraphState] = None

    def setup(self) -> bool:
        """Setup cache-aware streaming for the perception encoder.

        Returns:
            True if setup succeeded, False if the encoder doesn't support streaming.
        """
        from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder

        perception = self.model.stt_model.perception
        encoder = perception.encoder

        if not isinstance(encoder, StreamingEncoder):
            logging.warning("Perception encoder does not support streaming. Disabling perception cache.")
            return False

        if encoder.streaming_cfg is None:
            encoder.setup_streaming_params()

        self.streaming_cfg = encoder.streaming_cfg

        # Build a dedicated preprocessor with dithering and padding disabled so
        # that repeated mel computation over a growing buffer is deterministic
        # and chunk boundaries line up exactly between calls.
        cfg = copy.deepcopy(perception.cfg)
        OmegaConf.set_struct(cfg.preprocessor, False)
        cfg.preprocessor.dither = 0.0
        cfg.preprocessor.pad_to = 0

        self.preprocessor = perception.from_config_dict(cfg.preprocessor)
        self.preprocessor.to(self.device)

        self.subsampling_factor = encoder.subsampling_factor
        self.input_features = encoder._feat_in

        if hasattr(encoder, "pre_encode") and hasattr(encoder.pre_encode, "get_sampling_frames"):
            self.sampling_frames = encoder.pre_encode.get_sampling_frames()
        else:
            self.sampling_frames = None

        # One-time setup summary (fixed f-prefixes removed from literal-only messages).
        logging.info("Perception cache setup complete:")
        logging.info(
            f"  Streaming config: chunk_size={self.streaming_cfg.chunk_size}, "
            f"shift_size={self.streaming_cfg.shift_size}"
        )
        logging.info(f"  Pre-encode cache size: {self.streaming_cfg.pre_encode_cache_size}")
        logging.info(f"  Subsampling factor: {self.subsampling_factor}")

        if self.use_cudagraph:
            logging.info("  Setting up CUDA graphs for perception encoder...")
            self.capture_cudagraphs()
            logging.info("  CUDA graphs captured")

        return True

    def capture_cudagraphs(self):
        """Capture CUDA graphs for perception encoder with both chunk sizes.

        Note: "chunk" in the streaming encoder config (chunk_size, shift_size, etc.)
        follows NeMo's cache-aware streaming encoder API and is measured in
        mel-spectrogram time-steps, not audio samples or seconds.
        """
        encoder = self.model.stt_model.perception.encoder
        perception = self.model.stt_model.perception
        streaming_cfg = self.streaming_cfg

        # The streaming config may give [first, subsequent] pairs or a single scalar.
        if isinstance(streaming_cfg.chunk_size, list):
            chunk_size_first = streaming_cfg.chunk_size[0]
            chunk_size_subsequent = streaming_cfg.chunk_size[1]
        else:
            chunk_size_first = streaming_cfg.chunk_size
            chunk_size_subsequent = streaming_cfg.chunk_size

        if isinstance(streaming_cfg.pre_encode_cache_size, list):
            pre_encode_cache_first = streaming_cfg.pre_encode_cache_size[0]
            pre_encode_cache_subsequent = streaming_cfg.pre_encode_cache_size[1]
        else:
            pre_encode_cache_first = streaming_cfg.pre_encode_cache_size
            pre_encode_cache_subsequent = streaming_cfg.pre_encode_cache_size

        mel_len_first = chunk_size_first + pre_encode_cache_first
        mel_len_subsequent = chunk_size_subsequent + pre_encode_cache_subsequent

        logging.info(f"  CUDA graph mel lengths: first={mel_len_first}, subsequent={mel_len_subsequent}")

        cache_last_channel, cache_last_time, cache_last_channel_len = encoder.get_initial_cache_state(
            batch_size=1
        )

        state = PerceptionCUDAGraphState()

        # Static input buffers: graph replay reads from these fixed addresses,
        # so step() copies fresh data into them before every replay.
        state.static_mel_first = torch.zeros(
            (1, self.input_features, mel_len_first),
            dtype=torch.float32, device=self.device
        )
        state.static_mel_subsequent = torch.zeros(
            (1, self.input_features, mel_len_subsequent),
            dtype=torch.float32, device=self.device
        )
        state.static_mel_len_first = torch.tensor([mel_len_first], dtype=torch.long, device=self.device)
        state.static_mel_len_subsequent = torch.tensor([mel_len_subsequent], dtype=torch.long, device=self.device)

        if cache_last_channel is not None:
            state.static_cache_channel_in = cache_last_channel.clone()
        if cache_last_time is not None:
            state.static_cache_time_in = cache_last_time.clone()
        if cache_last_channel_len is not None:
            state.static_cache_channel_len_in = cache_last_channel_len.clone()

        # Warmup runs are required before capture so lazy CUDA initialization
        # (cuDNN plans, allocator pools, etc.) does not happen inside the graph.
        logging.info("  Warming up encoder for CUDA graph capture...")
        for _ in range(3):
            with torch.no_grad():
                _ = encoder.cache_aware_stream_step(
                    processed_signal=state.static_mel_first,
                    processed_signal_length=state.static_mel_len_first,
                    cache_last_channel=state.static_cache_channel_in.clone() if state.static_cache_channel_in is not None else None,
                    cache_last_time=state.static_cache_time_in.clone() if state.static_cache_time_in is not None else None,
                    cache_last_channel_len=state.static_cache_channel_len_in.clone() if state.static_cache_channel_len_in is not None else None,
                    keep_all_outputs=True,
                    drop_extra_pre_encoded=0,
                )
                _ = encoder.cache_aware_stream_step(
                    processed_signal=state.static_mel_subsequent,
                    processed_signal_length=state.static_mel_len_subsequent,
                    cache_last_channel=state.static_cache_channel_in.clone() if state.static_cache_channel_in is not None else None,
                    cache_last_time=state.static_cache_time_in.clone() if state.static_cache_time_in is not None else None,
                    cache_last_channel_len=state.static_cache_channel_len_in.clone() if state.static_cache_channel_len_in is not None else None,
                    keep_all_outputs=True,
                    drop_extra_pre_encoded=streaming_cfg.drop_extra_pre_encoded,
                )
        torch.cuda.synchronize()

        # Capture graph for FIRST chunk
        logging.info(f"  Capturing CUDA graph for first chunk (mel_len={mel_len_first})...")
        state.graph_first = torch.cuda.CUDAGraph()

        # Reset static cache inputs to the pristine initial state before capture.
        if state.static_cache_channel_in is not None:
            state.static_cache_channel_in.copy_(cache_last_channel)
        if state.static_cache_time_in is not None:
            state.static_cache_time_in.copy_(cache_last_time)
        if state.static_cache_channel_len_in is not None:
            state.static_cache_channel_len_in.copy_(cache_last_channel_len)

        with torch.cuda.graph(state.graph_first):
            (
                encoded_first,
                encoded_len_first,
                cache_channel_out_first,
                cache_time_out_first,
                cache_channel_len_out_first,
            ) = encoder.cache_aware_stream_step(
                processed_signal=state.static_mel_first,
                processed_signal_length=state.static_mel_len_first,
                cache_last_channel=state.static_cache_channel_in,
                cache_last_time=state.static_cache_time_in,
                cache_last_channel_len=state.static_cache_channel_len_in,
                keep_all_outputs=True,
                drop_extra_pre_encoded=0,
            )
            encoded_adapted_first, _ = perception.modality_adapter(audio_signal=encoded_first, length=encoded_len_first)
            encoded_chunk_first = perception.proj(encoded_adapted_first.transpose(1, 2))

        # The tensors produced inside the capture are the graph's fixed output
        # addresses; replay rewrites them in place.
        state.static_encoded_first = encoded_chunk_first
        state.static_encoded_len_first = encoded_len_first
        state.static_cache_channel_out_first = cache_channel_out_first
        state.static_cache_time_out_first = cache_time_out_first
        state.static_cache_channel_len_out_first = cache_channel_len_out_first

        # Capture graph for SUBSEQUENT chunks
        logging.info(f"  Capturing CUDA graph for subsequent chunks (mel_len={mel_len_subsequent})...")
        state.graph_subsequent = torch.cuda.CUDAGraph()

        if state.static_cache_channel_in is not None:
            state.static_cache_channel_in.copy_(cache_last_channel)
        if state.static_cache_time_in is not None:
            state.static_cache_time_in.copy_(cache_last_time)
        if state.static_cache_channel_len_in is not None:
            state.static_cache_channel_len_in.copy_(cache_last_channel_len)

        with torch.cuda.graph(state.graph_subsequent):
            (
                encoded_subsequent,
                encoded_len_subsequent,
                cache_channel_out_subsequent,
                cache_time_out_subsequent,
                cache_channel_len_out_subsequent,
            ) = encoder.cache_aware_stream_step(
                processed_signal=state.static_mel_subsequent,
                processed_signal_length=state.static_mel_len_subsequent,
                cache_last_channel=state.static_cache_channel_in,
                cache_last_time=state.static_cache_time_in,
                cache_last_channel_len=state.static_cache_channel_len_in,
                keep_all_outputs=True,
                drop_extra_pre_encoded=streaming_cfg.drop_extra_pre_encoded,
            )
            encoded_adapted_subsequent, _ = perception.modality_adapter(audio_signal=encoded_subsequent, length=encoded_len_subsequent)
            encoded_chunk_subsequent = perception.proj(encoded_adapted_subsequent.transpose(1, 2))

        state.static_encoded_subsequent = encoded_chunk_subsequent
        state.static_encoded_len_subsequent = encoded_len_subsequent
        state.static_cache_channel_out_subsequent = cache_channel_out_subsequent
        state.static_cache_time_out_subsequent = cache_time_out_subsequent
        state.static_cache_channel_len_out_subsequent = cache_channel_len_out_subsequent

        self.cudagraph_state = state
        logging.info("  CUDA graphs captured successfully")

    def get_initial_state(self, batch_size: int = 1) -> PerceptionCacheState:
        """Get initial cache state for perception encoder."""
        encoder = self.model.stt_model.perception.encoder
        cache_last_channel, cache_last_time, cache_last_channel_len = encoder.get_initial_cache_state(
            batch_size=batch_size
        )

        return PerceptionCacheState(
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
        )

    def step(
        self,
        audio_input: torch.Tensor,
        frame_idx: int,
        num_frames_per_chunk: int,
        perception_cache: PerceptionCacheState,
    ) -> Tuple[torch.Tensor, PerceptionCacheState]:
        """
        Perform cache-aware perception encoding for streaming inference.

        Note: "chunk" in this method (chunk_size, mel_chunk, etc.) follows NeMo's
        cache-aware streaming encoder API and is measured in mel-spectrogram time-steps,
        not audio samples or seconds.

        This method computes the full mel spectrogram from the audio buffer, then slices
        it appropriately based on the frame index. It supports processing multiple
        "base steps" in a single call, where each base step processes (lookahead + 1) frames.

        Processing logic per sub-step:
        - First sub-step (sub_frame_idx == 0): take first chunk_size_first columns,
          prepend zeros for pre_encode_cache
        - Subsequent sub-steps (sub_frame_idx > 0): take chunk_size columns starting from
          (shift_size_first + (step_number-1)*shift_size), prepend pre_encode_cache_size
          columns from mel spec

        The method loops over sub-steps, running the encoder for each and concatenating
        the outputs. This allows num_frames_per_chunk to be a multiple of (lookahead + 1).

        Args:
            audio_input: Audio buffer tensor [B, T] (full buffer with all samples)
            frame_idx: Current frame index in the stream
            num_frames_per_chunk: Number of 80ms frames to process. Must be a multiple
                of (lookahead + 1), i.e., encoder._cfg.att_context_size[1] + 1
            perception_cache: Current cache state containing encoder caches

        Returns:
            Tuple of (encoded_output [B, T_out, D], updated_perception_cache)
            where T_out = num_frames_per_chunk (one output frame per input frame)

        Raises:
            ValueError: If ``num_frames_per_chunk`` is not a multiple of (lookahead + 1).
        """
        perception = self.model.stt_model.perception
        encoder = perception.encoder
        streaming_cfg = self.streaming_cfg

        audio_len = torch.tensor([audio_input.shape[1]], dtype=torch.long, device=self.device)
        _t_start_preprocessor = time.time()
        processed_signal, _ = self.preprocessor(
            input_signal=audio_input,
            length=audio_len,
        )
        # Hot path: timing logs demoted from info to debug to avoid per-step log spam.
        logging.debug(f"preprocessor time: {time.time() - _t_start_preprocessor:.3f}s")

        if isinstance(streaming_cfg.chunk_size, list):
            chunk_size_first = streaming_cfg.chunk_size[0]
            chunk_size = streaming_cfg.chunk_size[1]
        else:
            chunk_size_first = streaming_cfg.chunk_size
            chunk_size = streaming_cfg.chunk_size

        if isinstance(streaming_cfg.shift_size, list):
            shift_size_first = streaming_cfg.shift_size[0]
            shift_size = streaming_cfg.shift_size[1]
        else:
            shift_size_first = streaming_cfg.shift_size
            shift_size = streaming_cfg.shift_size

        if isinstance(streaming_cfg.pre_encode_cache_size, list):
            pre_encode_cache_size_first = streaming_cfg.pre_encode_cache_size[0]
            pre_encode_cache_size = streaming_cfg.pre_encode_cache_size[1]
        else:
            pre_encode_cache_size_first = streaming_cfg.pre_encode_cache_size
            pre_encode_cache_size = streaming_cfg.pre_encode_cache_size

        cache_last_channel = perception_cache.cache_last_channel
        cache_last_time = perception_cache.cache_last_time
        cache_last_channel_len = perception_cache.cache_last_channel_len

        base_step_size = encoder._cfg.att_context_size[1] + 1
        if num_frames_per_chunk % base_step_size != 0:
            raise ValueError(
                f"num_frames_per_chunk must be a multiple of (lookahead + 1) = {base_step_size}. "
                f"Got num_frames_per_chunk={num_frames_per_chunk}"
            )
        num_sub_steps = num_frames_per_chunk // base_step_size

        start_time = time.time()

        encoded_chunks = []

        for sub_step in range(num_sub_steps):
            sub_step_start_time = time.time()

            sub_frame_idx = frame_idx + (sub_step * base_step_size)
            is_first_sub_step = (sub_frame_idx == 0)

            if is_first_sub_step:
                cur_chunk_size = chunk_size_first
                cur_pre_encode_cache_size = pre_encode_cache_size_first
                drop_extra_pre_encoded = 0

                mel_chunk = processed_signal[:, :, :cur_chunk_size]

                # There is no history yet, so the pre-encode cache region is zeros.
                if cur_pre_encode_cache_size > 0:
                    zeros_pad = torch.zeros(
                        (processed_signal.size(0), self.input_features, cur_pre_encode_cache_size),
                        device=self.device,
                        dtype=processed_signal.dtype,
                    )
                    mel_chunk = torch.cat([zeros_pad, mel_chunk], dim=-1)
            else:
                cur_chunk_size = chunk_size
                cur_pre_encode_cache_size = pre_encode_cache_size
                drop_extra_pre_encoded = streaming_cfg.drop_extra_pre_encoded

                mel_T = processed_signal.shape[-1]

                step_number = sub_frame_idx // base_step_size
                chunk_start = shift_size_first + (step_number - 1) * shift_size
                chunk_end = chunk_start + cur_chunk_size

                # Clamp the window when the requested chunk would run past the
                # available mel frames, keeping room for remaining sub-steps.
                offset = chunk_size - shift_size_first
                if chunk_end > mel_T - offset:
                    sub_steps_remaining = num_sub_steps - 1 - sub_step
                    chunk_end = mel_T - offset - sub_steps_remaining * shift_size
                    chunk_start = chunk_end - cur_chunk_size

                main_chunk = processed_signal[:, :, chunk_start:chunk_end]

                # Real history precedes this chunk: prepend it as pre-encode cache,
                # zero-padding on the left if not enough history is available.
                cache_start = max(0, chunk_start - cur_pre_encode_cache_size)
                cache_mel = processed_signal[:, :, cache_start:chunk_start]

                if cache_mel.shape[-1] < cur_pre_encode_cache_size:
                    zeros_pad = torch.zeros(
                        (cache_mel.size(0), cache_mel.size(1), cur_pre_encode_cache_size - cache_mel.shape[-1]),
                        device=self.device,
                        dtype=cache_mel.dtype,
                    )
                    cache_mel = torch.cat([zeros_pad, cache_mel], dim=-1)

                mel_chunk = torch.cat([cache_mel, main_chunk], dim=-1)

            chunk_lengths = torch.tensor([mel_chunk.shape[-1]], dtype=torch.long, device=self.device)

            if self.use_cudagraph and self.cudagraph_state is not None and self.cudagraph_state.is_captured():
                # Graph replay path: copy inputs into the static buffers, replay,
                # then clone the static outputs (replay of the next step would
                # otherwise overwrite them in place).
                graph_state = self.cudagraph_state

                if is_first_sub_step:
                    graph_state.static_mel_first.copy_(mel_chunk)
                else:
                    graph_state.static_mel_subsequent.copy_(mel_chunk)

                if graph_state.static_cache_channel_in is not None and cache_last_channel is not None:
                    graph_state.static_cache_channel_in.copy_(cache_last_channel)
                if graph_state.static_cache_time_in is not None and cache_last_time is not None:
                    graph_state.static_cache_time_in.copy_(cache_last_time)
                if graph_state.static_cache_channel_len_in is not None and cache_last_channel_len is not None:
                    graph_state.static_cache_channel_len_in.copy_(cache_last_channel_len)

                if is_first_sub_step:
                    graph_state.graph_first.replay()
                    encoded_chunk = graph_state.static_encoded_first.clone()
                    cache_last_channel = graph_state.static_cache_channel_out_first.clone() if graph_state.static_cache_channel_out_first is not None else None
                    cache_last_time = graph_state.static_cache_time_out_first.clone() if graph_state.static_cache_time_out_first is not None else None
                    cache_last_channel_len = graph_state.static_cache_channel_len_out_first.clone() if graph_state.static_cache_channel_len_out_first is not None else None
                else:
                    graph_state.graph_subsequent.replay()
                    encoded_chunk = graph_state.static_encoded_subsequent.clone()
                    cache_last_channel = graph_state.static_cache_channel_out_subsequent.clone() if graph_state.static_cache_channel_out_subsequent is not None else None
                    cache_last_time = graph_state.static_cache_time_out_subsequent.clone() if graph_state.static_cache_time_out_subsequent is not None else None
                    cache_last_channel_len = graph_state.static_cache_channel_len_out_subsequent.clone() if graph_state.static_cache_channel_len_out_subsequent is not None else None

            else:
                # Eager path: run the streaming encoder, then modality adapter + proj.
                (
                    encoded,
                    encoded_len,
                    cache_last_channel,
                    cache_last_time,
                    cache_last_channel_len,
                ) = encoder.cache_aware_stream_step(
                    processed_signal=mel_chunk,
                    processed_signal_length=chunk_lengths,
                    cache_last_channel=cache_last_channel,
                    cache_last_time=cache_last_time,
                    cache_last_channel_len=cache_last_channel_len,
                    keep_all_outputs=True,
                    drop_extra_pre_encoded=drop_extra_pre_encoded,
                )

                modality_adapter = perception.modality_adapter
                encoded_adapted, _ = modality_adapter(audio_signal=encoded, length=encoded_len)

                encoded_chunk = perception.proj(encoded_adapted.transpose(1, 2))

            # NOTE(review): synchronize kept for timing accuracy of the debug log;
            # it could be removed if per-sub-step timing is not needed.
            torch.cuda.synchronize()
            logging.debug(f"  Sub-step {sub_step}/{num_sub_steps} (sub_frame_idx={sub_frame_idx}, first={is_first_sub_step}): {time.time() - sub_step_start_time:.4f}s")
            encoded_chunks.append(encoded_chunk)

        if len(encoded_chunks) > 1:
            encoded_chunk = torch.cat(encoded_chunks, dim=1)
        else:
            encoded_chunk = encoded_chunks[0]

        torch.cuda.synchronize()
        logging.debug(f"Time taken for encoder ({num_sub_steps} sub-steps): {time.time() - start_time}")

        new_perception_cache = PerceptionCacheState(
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
        )

        return encoded_chunk, new_perception_cache
from typing import Any, Dict


class S2SPipelineInterface:
    """Base class for all streaming S2S pipelines.

    This class is intentionally kept minimal and mirrors the behaviour of
    ``BasePipeline`` that is used for streaming ASR pipelines. It provides a
    small in-memory *state pool* that stores per-stream objects (cache,
    running buffers, etc.) required by a concrete pipeline implementation.
    Sub-classes are expected to implement :py:meth:`create_state` to
    construct a fresh state object.
    """

    def __init__(self) -> None:
        # Per-stream state objects, keyed by ``stream_id``.
        self._state_pool: Dict[int, Any] = {}

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------
    def get_state(self, stream_id: int):
        """Return the state object for *stream_id*, or *None* when absent."""
        return self._state_pool.get(stream_id)

    def delete_state(self, stream_id: int) -> None:
        """Remove the state bound to *stream_id*; silently ignore a miss."""
        self._state_pool.pop(stream_id, None)

    def create_state(self):  # noqa: D401 (keep same signature as recognizers)
        """Create and return a *new*, *empty* state object.

        Must be implemented by concrete pipelines.
        """
        raise NotImplementedError("`create_state()` must be implemented in a subclass.")

    def get_or_create_state(self, stream_id: int):
        """Return existing state for *stream_id*, building one on first use."""
        try:
            return self._state_pool[stream_id]
        except KeyError:
            state = self.create_state()
            self._state_pool[stream_id] = state
            return state

    # ------------------------------------------------------------------
    # Session helpers – identical to *BasePipeline*
    # ------------------------------------------------------------------
    def reset_session(self) -> None:
        """Clear the internal *state pool* – effectively resetting the pipeline."""
        self._state_pool.clear()

    def open_session(self) -> None:
        """Alias for :py:meth:`reset_session` to start a fresh streaming session."""
        self.reset_session()

    def close_session(self) -> None:
        """Alias for :py:meth:`reset_session` to end the current streaming session."""
        self.reset_session()
import os
import time

import torch
import librosa
from typing import List, Optional
from torch import Tensor
import soundfile as sf
from omegaconf import DictConfig
import math

from nemo.collections.asr.inference.streaming.framing.request import Frame
from nemo.collections.asr.inference.utils.enums import RequestType
from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedFrameStreamer
from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import BatchedAudioBufferer
from nemo.collections.asr.inference.utils.progressbar import ProgressBar
from nemo.collections.speechlm2.inference.pipelines.s2s_pipeline_interface import S2SPipelineInterface
from nemo.collections.speechlm2.inference.streaming.state.s2s_state import S2SStreamingState
from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import NemotronVoicechatInferenceWrapper, tokens_to_str_raw
from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str
from nemo.collections.speechlm2.inference.streaming.state.s2s_context_manager import S2SContextManager
from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions
from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput
from nemo.utils import logging


class StreamingS2SPipeline(S2SPipelineInterface):
    """
    Streaming S2S (speech-to-speech) pipeline.

    Drives a :class:`NemotronVoicechatInferenceWrapper` over 80 ms audio
    frames: audio is buffered per stream, passed through the model one chunk
    at a time, and the generated audio/text is accumulated in per-stream
    state objects (see :py:meth:`create_state`).
    """

    def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper):
        # ------------------------------------------------------------------
        # Model & device
        # ------------------------------------------------------------------
        self.s2s_model = s2s_model
        self.device = self.s2s_model.device

        # ------------------------------------------------------------------
        # Streaming configuration
        # NOTE(review): getattr-with-default works for both an OmegaConf
        # DictConfig (attribute access) and the plain-dict `{}` fallback
        # (where every getattr returns the default).
        # ------------------------------------------------------------------
        self.streaming_cfg = cfg.get("streaming", {})
        self.input_sample_rate = getattr(self.streaming_cfg, "input_sample_rate", 16000)
        self.output_sample_rate = getattr(self.streaming_cfg, "output_sample_rate", 22050)
        self.batch_size = getattr(self.streaming_cfg, "batch_size", 1)
        self.max_len = getattr(self.streaming_cfg, "max_len", 200)

        # ------------------------------------------------------------------
        # Chunk & buffer sizes
        # Terminology: "frame" = 80ms audio unit, "chunk" = 1 or more frames
        # A chunk is the amount of audio that is processed per inference step.
        # ------------------------------------------------------------------
        self.chunk_size_in_secs = getattr(self.streaming_cfg, "chunk_size_in_secs", 0.08)
        # Check if self.chunk_size_in_secs is a multiple of 0.08.
        # Because of quirks of floating point arithmetic, the remainder could be either ~0 or ~0.08,
        # so we check for both cases.
        remainder = self.chunk_size_in_secs % 0.08
        if not (math.isclose(remainder, 0, abs_tol=1e-9) or math.isclose(remainder, 0.08, abs_tol=1e-9)):
            raise ValueError(f"Chunk size must be a multiple of 0.08s, but got {self.chunk_size_in_secs}")

        self.num_frames_per_chunk = int(self.chunk_size_in_secs / 0.08)

        # Buffer size determines how much audio is passed to the perception encoder
        # Default: 5.68 seconds (71 * 0.08). This is the minimum valid buffer size without the perception cache.
        # i.e. att_context_size[0] + att_context_size[1] + 1 frames = 70+0+1 = 71 frames = 5.68 seconds
        self.buffer_size_in_secs = getattr(self.streaming_cfg, "buffer_size_in_secs", 71 * 0.08)

        self.att_context_size = getattr(self.streaming_cfg, "att_context_size", [70,0])

        # ------------------------------------------------------------------
        # bufferer – reused from ASR utilities
        # ------------------------------------------------------------------
        self.bufferer = BatchedAudioBufferer(
            sample_rate=self.input_sample_rate,
            buffer_size_in_secs=self.buffer_size_in_secs,
        )

        # ------------------------------------------------------------------
        # System prompt configuration
        # ------------------------------------------------------------------
        s2s_cfg = cfg.get("s2s", {})
        self.system_prompt: Optional[str] = getattr(s2s_cfg, "system_prompt", None)
        if self.system_prompt:
            logging.info(f"System prompt configured: {self.system_prompt[:100]}{'...' if len(self.system_prompt) > 100 else ''}")

        # Context manager: owns per-stream KV caches / token buffers in slots.
        self.context_manager = S2SContextManager(
            s2s_model=self.s2s_model,
            num_slots=self.batch_size,
            max_len=self.max_len,
        )

        # Output directory for generated files
        self.output_dir = getattr(cfg, "output_dir", "./generated")

        # Parse and validate request type early, with a safe default
        req_type_cfg = getattr(self.streaming_cfg, "request_type", "frame")

        # Parse and validate the request type; only 'frame' is supported for s2s.
        self.request_type = RequestType.from_str(req_type_cfg)
        if self.request_type is not RequestType.FRAME:
            raise ValueError(f"Request type {self.request_type} is not supported for s2s.")

        # True when the current stream's system prompt was already prefilled
        # via a zero-length prefill frame (see generate_step).
        self._stream_has_prompt: bool = False

        # ------------------------------------------------------------------
        # Input audio padding (silence appended after real audio)
        # At most one of the three modes may be set (mutually exclusive).
        # ------------------------------------------------------------------
        self.pad_audio_to_sec: float | None = cfg.get("pad_audio_to_sec", None)
        self.pad_silence_ratio: float | None = cfg.get("pad_silence_ratio", None)
        self.pad_audio_by_sec: float | None = cfg.get("pad_audio_by_sec", None)
        if sum(x is not None for x in [self.pad_audio_to_sec, self.pad_silence_ratio, self.pad_audio_by_sec]) > 1:
            raise ValueError("Set at most one of: pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec")

        super().__init__()

    # -------------------------------- ----------------------------------
    # State helpers
    # ------------------------------------------------------------------
    def create_state(self) -> S2SStreamingState:
        """Create new empty state.

        Pulls codebook count and compute dtype from the wrapped model
        (with safe fallbacks) so the state tensors match the model.
        """
        num_audio_codebooks = getattr(self.s2s_model.model, "_num_codebooks", 1)
        dtype = getattr(self.s2s_model, "compute_dtype", torch.float32)
        state = S2SStreamingState(
            device=self.device,
            dtype=dtype,
            max_len=self.max_len,
            num_audio_codebooks=num_audio_codebooks,
            output_sample_rate=self.output_sample_rate,
        )
        return state


    # ------------------------------------------------------------------
    # Output helpers
    # ------------------------------------------------------------------
    def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[bool], text_pieces: List[str], asr_text_pieces: Optional[List[str]] = None):
        """Append generated audio waveform and text to per-stream state.

        Args:
            frames: Input frames of the current step, one per batch entry.
            audio_wave: Generated audio for the batch; indexed per entry.
                NOTE(review): assumed shape [B, S] per the comment below —
                confirm against `infer_one_step`'s `decoded_audio_new`.
            ready_feats: Per-entry flag; entries that are not ready are skipped.
            text_pieces: Per-entry generated text snippets (may be empty strings).
            asr_text_pieces: Optional per-entry ASR text snippets.
        """
        for idx, frame in enumerate(frames):
            if not ready_feats[idx]:
                continue
            state = self.get_or_create_state(frame.stream_id)
            # audio_wave is [B, S]; take sample idx
            sample_audio = audio_wave[idx:idx+1, ...]
            # Determine text piece for this index
            piece = None
            if text_pieces and idx < len(text_pieces):
                candidate = text_pieces[idx]
                if isinstance(candidate, str) and candidate:
                    piece = candidate

            # Determine ASR text piece
            asr_piece = None
            if asr_text_pieces and idx < len(asr_text_pieces):
                candidate = asr_text_pieces[idx]
                if isinstance(candidate, str) and candidate:
                    asr_piece = candidate

            state.update_state(sample_audio, output_text=piece, output_asr_text=asr_piece)


    def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_paddings: List[int], ready_feats: List[bool]):
        """Generate speech for chunks in *batch* using a shared ContextManager.

        Runs one inference step for the (currently single-entry) batch:
        prepares the audio buffer, invokes the model, persists the updated
        context, finalizes any streams that ended this step, and records the
        produced audio/text into per-stream state.
        """
        if len(frames) == 0:
            return

        stream_ids = [f.stream_id for f in frames]
        eos_flags = [f.is_last for f in frames]
        bos_flags = [f.is_first for f in frames]

        logging.debug(f"stream_ids={stream_ids} bos_flags={bos_flags} eos_flags={eos_flags}")

        if len(frames) != 1:
            raise NotImplementedError("NemotronVoicechatInferenceWrapper currently supports batch_size == 1")

        # If this is the first audio frame and prefill was already done via a
        # zero-length prefill frame, skip context init -- it's already set up.
        # Otherwise (no system prompt), create a fresh context_manager.
        has_prompt = False
        if bos_flags[0]:
            if self._stream_has_prompt:
                logging.debug(f"Prefill already done for stream {stream_ids[0]}, skipping context init")
            else:
                logging.debug(f"No prefill for stream {stream_ids[0]}, creating fresh context_manager")
                self.context_manager = S2SContextManager(
                    s2s_model=self.s2s_model,
                    num_slots=self.batch_size,
                    max_len=self.max_len,
                )

            # Consume the prefill flag: it only applies to this stream's first frame.
            has_prompt = self._stream_has_prompt
            self._stream_has_prompt = False

        request_id = self._request_id_for_stream(stream_ids[0])

        context, _ = self.context_manager.get_context(stream_ids)

        # Normalize the buffer to [1, S] and move it onto the model device/dtype.
        audio_buffer = buffers[0]
        if audio_buffer.dim() == 1:
            audio_buffer = audio_buffer.unsqueeze(0)
        audio_buffer = audio_buffer.to(self.s2s_model.device, dtype=self.s2s_model.dtype)

        # Trim the buffer to exclude left padding (zeros at the beginning before buffer is filled)
        left_pad = left_paddings[0]
        if left_pad > 0:
            audio_buffer = audio_buffer[:, left_pad:]

        result = self.s2s_model.infer_one_step(
            audio_input=audio_buffer,
            num_frames_per_chunk=self.num_frames_per_chunk,
            frame_idx=context.frame_idx,
            gen_text=context.gen_text,
            audio_toks_buffer=context.audio_toks_buffer,
            input_embeds_history=context.input_embeds_history,
            dynamic_cache=context.dynamic_cache,
            past_key_values=context.past_key_values,
            code=context.code,
            subword_mask=context.subword_mask,
            gen_asr_text=context.gen_asr_text,
            gen_function_text=context.gen_function_text,
            request_id=request_id,
            perception_cache=context.perception_cache,
            has_prompt=has_prompt,
            codec_cache=context.codec_cache,
        )

        # Persist updated cache & clean finished streams
        self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk)

        # Save full token tensors to state before the context is destroyed,
        # so we can run tokens_to_str / tokens_to_str_raw post-hoc.
        for stream_id, eos_flag in zip(stream_ids, eos_flags):
            if eos_flag:
                ctx = self.context_manager.slot_contexts[
                    self.context_manager.streamidx2slotidx[stream_id]
                ]
                if ctx is not None:
                    state = self.get_or_create_state(stream_id)
                    state.save_token_tensors(ctx.gen_text, ctx.gen_asr_text, ctx.frame_idx,
                                             gen_function_text=ctx.gen_function_text)

        self.context_manager.reset_slots(stream_ids, eos_flags)

        # Explicitly clean up bufferer and state for finished streams
        for stream_id, eos_flag in zip(stream_ids, eos_flags):
            if eos_flag:
                logging.debug(f"Ending stream {stream_id} - cleaning up bufferer and context")
                self.bufferer.rm_bufferer(stream_id)
                self._abort_stream_request(stream_id)
                # Note: We keep the state in _state_pool until finalization to save audio
                # It will be cleaned up in close_session()

        # Log audio and attach text to state
        self.log_output(frames, result["decoded_audio_new"], ready_feats, result["predicted_text_strs"], result.get("asr_predicted_text_strs"))

    def prefill_for_new_stream(self, stream_id: int, system_prompt: str | None = None) -> bool:
        """Prepare the pipeline for a new stream by resetting context and prefilling the system prompt.

        This is the public API for prefill-only calls (e.g. from the Triton backend)
        that need to initialize TTS speaker embeddings and/or inject a system prompt
        into the LLM KV cache *without* processing any audio.

        Args:
            stream_id: Unique identifier for the new stream.
            system_prompt: System prompt text. If *None*, falls back to
                the YAML-configured ``self.system_prompt``.

        Returns:
            True if a system prompt was prefilled, False otherwise.
        """
        t0 = time.time()
        if system_prompt is None:
            system_prompt = self.system_prompt

        # Always start a new stream from a fresh context manager.
        self.context_manager = S2SContextManager(
            s2s_model=self.s2s_model,
            num_slots=self.batch_size,
            max_len=self.max_len,
        )
        t_ctx = time.time()

        with torch.no_grad(), torch.inference_mode():
            self._prefill_system_prompt(stream_id, system_prompt)
        t_prefill = time.time()

        self._stream_has_prompt = bool(system_prompt)
        logging.debug(f"prefill_for_new_stream: context_manager={1000*(t_ctx-t0):.1f}ms, "
                      f"_prefill_system_prompt={1000*(t_prefill-t_ctx):.1f}ms, "
                      f"total={1000*(t_prefill-t0):.1f}ms, has_prompt={self._stream_has_prompt}")
        return self._stream_has_prompt

    # Prompt used by warmup() when no system prompt is configured, so the
    # LLM prefill path is always exercised at least once.
    _WARMUP_FALLBACK_PROMPT = "Mock system prompt for warmup."

    def warmup(self, system_prompt: str | None = None) -> None:
        """Run a throwaway prefill cycle to warm up the inference engine.

        The very first prefill incurs one-time overhead (e.g. CUDA graph
        compilation, memory pool allocation, DynamicCache initialization).
        Calling this once during startup moves that cost out of the
        critical path so the first real client request is fast.

        The method performs a full prefill (TTS speaker embedding + LLM
        system prompt), then aborts the request and resets all pipeline
        state so the next real stream starts cleanly.

        Args:
            system_prompt: Prompt text to use for warmup. Falls back to
                the YAML-configured ``self.system_prompt``, then to a
                short fallback string so the LLM prefill path is always
                exercised.
        """
        prompt = system_prompt if system_prompt is not None else self.system_prompt
        if not prompt:
            prompt = self._WARMUP_FALLBACK_PROMPT
            logging.info(f"No system prompt configured — using fallback prompt for warmup: \"{prompt}\"")

        # Reserved stream id that cannot collide with real (non-negative) streams.
        warmup_stream_id = -1

        logging.info("Running pipeline warmup prefill...")
        t0 = time.time()

        self.prefill_for_new_stream(warmup_stream_id, prompt)

        # Tear down the warmup request so the engine is clean for real traffic
        self._abort_stream_request(warmup_stream_id)
        self.context_manager.reset()
        self._stream_has_prompt = False

        logging.info(f"Pipeline warmup complete in {time.time() - t0:.3f}s")

    def generate_step(self, frames: List[Frame]):
        """Main streaming API similar to *transcribe_step* in recognizers.

        If the batch contains a single zero-length first frame with a system
        prompt in ``options``, this is treated as a **prefill-only** request:
        the context manager and system prompt are initialized but no audio
        inference runs. This is the unified protocol used by both the CLI
        (``run()``) and the Triton backend.
        """
        # Detect prefill-only frame: is_first + zero-length audio
        if (len(frames) == 1
                and frames[0].is_first
                and frames[0].samples.numel() == 0):
            opts = frames[0].options
            prompt = None
            if opts is not None and hasattr(opts, "system_prompt"):
                prompt = opts.system_prompt
            self.prefill_for_new_stream(frames[0].stream_id, prompt)
            return

        buffers, left_paddings = self.bufferer.update(frames)
        ready_feats = [True] * len(frames)

        with torch.no_grad(), torch.inference_mode():
            self.inner_generate_step(frames, buffers, left_paddings, ready_feats)

    # ------------------------------------------------------------------
    # Finalization helpers
    # ------------------------------------------------------------------
    def _finalize_and_save_finished_streams(
        self,
        frames: List[Frame],
        audio_filepaths: List[str],
        saved_paths_by_stream: dict[int, str],
    ) -> None:
        """Finalize any streams that ended in this batch and save their audio.

        For every ``is_last`` frame: flush the stream's state, write the mono
        generated waveform, a stereo input/output file, and the accumulated
        text/ASR-text files under ``self.output_dir``, and record the saved
        wav path in *saved_paths_by_stream* (mutated in place).

        NOTE(review): ``audio_filepaths`` is indexed by ``stream_id`` — this
        assumes stream ids are assigned as 0..N-1 in input order (matches
        ``run()``'s usage); confirm for other callers.
        """
        for frame in frames:
            if frame.is_last:
                stream_id = frame.stream_id
                state = self.get_or_create_state(stream_id)

                # Flush remaining buffered samples and assemble waveform
                if hasattr(state, "finalize"):
                    state.finalize()
                # Concatenate emitted chunks and squeeze (B=1,C=1) to mono waveform
                generated_audio = torch.cat(state.speech_frames, dim=-1)
                # Ensure 1D mono waveform and float32 dtype for soundfile
                if generated_audio.dim() == 3 and generated_audio.size(0) == 1 and generated_audio.size(1) == 1:
                    generated_audio = generated_audio.squeeze(0).squeeze(0)
                elif generated_audio.dim() == 2 and generated_audio.size(0) == 1:
                    generated_audio = generated_audio.squeeze(0)
                generated_audio = generated_audio.to(torch.float32)

                # Build output paths in subdirectories under output_dir
                in_path = audio_filepaths[stream_id]
                base = os.path.splitext(os.path.basename(in_path))[0]

                wav_dir = os.path.join(self.output_dir, "wav")
                stereo_dir = os.path.join(self.output_dir, "stereo")
                txt_dir = os.path.join(self.output_dir, "txt")
                os.makedirs(wav_dir, exist_ok=True)
                os.makedirs(stereo_dir, exist_ok=True)
                os.makedirs(txt_dir, exist_ok=True)

                out_path = os.path.join(wav_dir, f"{base}.wav")

                # Write audio to disk
                if generated_audio.numel() > 0:
                    sf.write(out_path, generated_audio.detach().cpu().numpy(), self.output_sample_rate)

                # Also save a stereo file with input (ch0) and output (ch1)
                # Load input with librosa (handles mono conversion and resampling)
                input_np, _ = librosa.load(in_path, sr=self.output_sample_rate, mono=True)
                input_audio = torch.from_numpy(input_np).to(torch.float32)
                gen_cpu = generated_audio.detach().cpu().to(input_audio.dtype)

                # Prepend silence to output channel to account for
                # the one-chunk processing delay: the server can't
                # produce output until it has received a full input chunk.
                delay_samples = int(self.chunk_size_in_secs * self.output_sample_rate)
                silence = torch.zeros(delay_samples, dtype=gen_cpu.dtype)
                gen_cpu = torch.cat([silence, gen_cpu], dim=-1)

                # Zero-pad the shorter channel so both have equal length.
                gen_len = int(gen_cpu.shape[-1])
                in_len = int(input_audio.shape[-1])
                max_len = max(gen_len, in_len)
                if in_len < max_len:
                    input_audio = torch.cat([input_audio, torch.zeros(max_len - in_len, dtype=input_audio.dtype)], dim=-1)
                if gen_len < max_len:
                    gen_cpu = torch.cat([gen_cpu, torch.zeros(max_len - gen_len, dtype=gen_cpu.dtype)], dim=-1)
                stereo = torch.stack([input_audio, gen_cpu], dim=0).transpose(0, 1)
                stereo_path = os.path.join(stereo_dir, f"{base}_input_output.wav")
                sf.write(stereo_path, stereo.detach().cpu().numpy(), self.output_sample_rate)

                # Save accumulated text (best-effort: a failed write must not
                # abort finalization of the remaining streams)
                text_out = state.get_output_text() if hasattr(state, "get_output_text") else ""
                if isinstance(text_out, str):
                    try:
                        with open(os.path.join(txt_dir, f"{base}.txt"), "w", encoding="utf-8") as f:
                            f.write(text_out)
                    except Exception:
                        pass

                # Save accumulated ASR text (same best-effort policy)
                asr_text_out = state.get_output_asr_text() if hasattr(state, "get_output_asr_text") else ""
                if isinstance(asr_text_out, str) and asr_text_out:
                    try:
                        with open(os.path.join(txt_dir, f"{base}_asr.txt"), "w", encoding="utf-8") as f:
                            f.write(asr_text_out)
                    except Exception:
                        pass

                saved_paths_by_stream[stream_id] = out_path

                # Keep state until outputs are assembled; will be cleared on close_session


    # ------------------------------------------------------------------
    # Session helpers (extend S2SPipelineInterface)
    # ------------------------------------------------------------------

    def reset_session(self) -> None:
        """Reset feature buffer and ContextManager together."""
        # Abort any in-flight engine requests before dropping their contexts.
        for stream_id in list(self.context_manager.streamidx2slotidx.keys()):
            self._abort_stream_request(stream_id)
        self.bufferer.reset()
        self.context_manager.reset()

        super().reset_session()  # clears state pool

    # ------------------------------------------------------------------
    # Orchestrator – mirrors recognizers' *run* method
    # ------------------------------------------------------------------
    def run(
        self,
        audio_filepaths: List[str],
        options: List[S2SRequestOptions] | None = None,
        progress_bar: Optional[ProgressBar] = None,
    ) -> PipelineOutput:
        """Stream all *audio_filepaths* through the pipeline and save outputs.

        Saves one generated ``.wav`` per input under ``self.output_dir`` and
        returns their paths in ``PipelineOutput.texts``.

        Args:
            audio_filepaths: Input audio files; stream ids follow input order.
            options: Optional per-file request options; defaults to the
                configured system prompt for every file.
            progress_bar: Optional :class:`ProgressBar` instance.

        Raises:
            ValueError: If *progress_bar* is not a ``ProgressBar`` instance.
        """
        if progress_bar and not isinstance(progress_bar, ProgressBar):
            raise ValueError("progress_bar must be an instance of ProgressBar.")

        if options is None:
            options = [S2SRequestOptions(system_prompt=self.system_prompt) for _ in audio_filepaths]

        streamer = ContinuousBatchedFrameStreamer(
            n_frames_per_stream=1,
            frame_size_in_secs=self.chunk_size_in_secs,
            sample_rate=self.input_sample_rate,
            batch_size=self.batch_size,
            pad_last_frame=True,
        )

        streamer.set_audio_filepaths(audio_filepaths, options)
        streamer.set_progress_bar(progress_bar)

        # Ensure output directory exists
        os.makedirs(self.output_dir, exist_ok=True)

        # Track saved paths by stream id to preserve input order
        saved_paths_by_stream: dict[int, str] = {}
        chunk_samples = int(self.chunk_size_in_secs * self.input_sample_rate)

        self.open_session()
        for frames in streamer:
            # Unified prefill protocol: if the first frame of a new stream
            # carries a system prompt, emit a zero-length prefill frame first.
            if (len(frames) == 1
                    and frames[0].is_first
                    and frames[0].options is not None
                    and hasattr(frames[0].options, "system_prompt")
                    and frames[0].options.system_prompt):
                prefill_frame = Frame(
                    samples=torch.empty(0),
                    stream_id=frames[0].stream_id,
                    is_first=True,
                    is_last=False,
                    options=frames[0].options,
                )
                self.generate_step([prefill_frame])

            # If padding is configured, intercept last frames so the
            # bufferer/context stay alive for the silence-padding phase.
            # Padding is generated immediately (same iteration) to avoid
            # the next stream's setup destroying this stream's context.
            pad_targets: dict[int, float] = {}
            if self.pad_audio_to_sec or self.pad_silence_ratio or self.pad_audio_by_sec:
                processed_frames = []
                for frame in frames:
                    if frame.is_last:
                        elapsed = streamer.elapsed_durations[frame.stream_id]
                        remaining = self._padding_remaining_secs(elapsed)
                        if remaining > 0:
                            # Re-emit the last frame as non-final; the real
                            # end-of-stream comes after the silence padding.
                            processed_frames.append(Frame(
                                samples=frame.samples,
                                stream_id=frame.stream_id,
                                is_first=frame.is_first,
                                is_last=False,
                                length=frame.length,
                                options=frame.options,
                            ))
                            pad_targets[frame.stream_id] = remaining
                            continue
                    processed_frames.append(frame)
                frames = processed_frames

            self.generate_step(frames)
            self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream)

            # Generate silence padding before the next iteration adds a new stream
            for stream_id, remaining_secs in pad_targets.items():
                num_pad_frames = max(1, round(remaining_secs / self.chunk_size_in_secs))
                for i in range(num_pad_frames):
                    is_last = (i == num_pad_frames - 1)
                    silence_frame = Frame(
                        samples=torch.zeros(chunk_samples),
                        stream_id=stream_id,
                        is_first=False,
                        is_last=is_last,
                        length=chunk_samples,
                    )
                    self.generate_step([silence_frame])
                    if is_last:
                        self._finalize_and_save_finished_streams(
                            [silence_frame], audio_filepaths, saved_paths_by_stream
                        )
        # Build outputs before closing the session
        texts = []
        words = []
        asr_texts = []
        texts_with_timestamps = []
        asr_texts_with_timestamps = []
        raw_texts = []
        raw_asr_texts = []

        tokenizer = self.s2s_model.tokenizer
        pad_id = self.s2s_model.model.stt_model.text_pad_id

        for idx in range(len(audio_filepaths)):
            state = self.get_or_create_state(idx)
            text_value = state.get_output_text() if hasattr(state, "get_output_text") else ""
            if not text_value:
                # Fall back to the saved wav path when no text was produced.
                text_value = saved_paths_by_stream.get(idx, "")
            texts.append(text_value)
            per_stream_words = state.get_output_words() if hasattr(state, "get_output_words") else []
            words.append(per_stream_words)

            asr_text_value = state.get_output_asr_text() if hasattr(state, "get_output_asr_text") else ""
            asr_texts.append(asr_text_value)

            # Post-hoc decoding of the saved token tensors (see
            # inner_generate_step's save_token_tensors call).
            token_data = state.get_token_tensors()
            if token_data is not None:
                gen_text, gen_asr_text, total_frames, gen_function_text = token_data
                lengths = torch.tensor([total_frames], dtype=torch.long)
                texts_with_timestamps.append(
                    tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0]
                )
                asr_texts_with_timestamps.append(
                    tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0]
                )
                raw_texts.append(
                    tokens_to_str_raw(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0]
                )
                raw_asr_texts.append(
                    tokens_to_str_raw(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0]
                )
                if gen_function_text is not None:
                    fc_text = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=False)[0]
                    fc_text_raw = tokens_to_str_raw(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0]
                    logging.info(f"Function calling channel: {fc_text}")
            else:
                texts_with_timestamps.append("")
                asr_texts_with_timestamps.append("")
                raw_texts.append("")
                raw_asr_texts.append("")

        self.close_session()

        return PipelineOutput(
            texts=texts,
            words=words,
            asr_texts=asr_texts,
            texts_with_timestamps=texts_with_timestamps,
            asr_texts_with_timestamps=asr_texts_with_timestamps,
            raw_texts=raw_texts,
            raw_asr_texts=raw_asr_texts,
        )

    def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = None) -> Optional[torch.Tensor]:
        """Prefill the system prompt for a new stream.

        This prepares the system prompt embeddings and processes them through
        the LLM to update the KV cache before audio streaming begins.
        Also prefills the TTS model with speaker embeddings when using vLLM EarTTS.

        Args:
            stream_id: The stream identifier.
            system_prompt: The system prompt text for this stream. If *None*,
                TTS prefill still runs (for vLLM EarTTS) but no LLM prompt
                is injected.

        Note on TTS prefill codes:
            The TTS prefill generates output codes, but these should NOT be used
            to initialize context.code for inference. The batch approach uses
            first_tts_code_input (INPUT codes from speaker reference) instead.
            Using prefill OUTPUT codes causes audio quality issues (mumbling).

        Returns:
            Optional[torch.Tensor]: The TTS prefill output codes if vLLM EarTTS prefill
            happened, None otherwise. These are returned for logging/debugging but
            should NOT be used to update context.code.
        """
        request_id = self._request_id_for_stream(stream_id)
        engine_type = getattr(self.s2s_model, "engine_type", "native")
        tts_output_code = None

        # Prefill TTS with speaker embedding when using vLLM EarTTS
        # This initializes the vLLM TTS engine with the speaker context via prompt_token_ids
        use_vllm_eartts = "vllm_eartts" in engine_type.lower()
        if use_vllm_eartts:
            tts_init_inputs = getattr(self.s2s_model, "tts_init_inputs", None)
            tts_prompt_token_ids = getattr(self.s2s_model, "tts_prompt_token_ids", None)
            if tts_init_inputs is not None and tts_prompt_token_ids is not None:
                logging.info(f"Prefilling TTS speaker embedding for stream {stream_id}...")
                start_tts_prefill = time.time()
                with torch.no_grad():
                    # Clone tts_init_inputs to avoid any tensor sharing issues
                    import copy
                    tts_inputs_copy = copy.deepcopy(tts_init_inputs)
                    tts_result = self.s2s_model.model.tts_model.tts_model(
                        tts_inputs_copy,
                        request_id=request_id,
                        prompt_token_ids=tts_prompt_token_ids
                    )
                    # Capture the generated codes to sync context with vLLM state
                    if hasattr(tts_result, 'codes') and tts_result.codes is not None:
                        tts_output_code = tts_result.codes.detach().clone()
                        logging.debug(f"TTS prefill generated codes shape: {tts_output_code.shape}")
                logging.info(f"TTS speaker embedding prefilled in {time.time() - start_tts_prefill:.3f}s")
            else:
                logging.warning("TTS init inputs not available, skipping TTS prefill")

        if not system_prompt:
            return tts_output_code

        logging.info(f"Prefilling system prompt for stream {stream_id}...")
        start_get_prompt_embeddings = time.time()
        prompt_embedded, prompt_len = self.s2s_model._prepare_system_prompt_embeddings(system_prompt)
        logging.debug(f"Time taken to get prompt embeddings: {time.time() - start_get_prompt_embeddings:.3f}s")

        if prompt_embedded is None:
            logging.warning("System prompt embedding returned None, skipping prefill")
            return tts_output_code

        # Check if using vLLM for LLM (matches vllm_llm, vllm_llm_vllm_eartts, etc.)
        use_vllm_llm = "vllm_llm" in engine_type.lower()

        if use_vllm_llm:
            # For vLLM LLM: prefill all prompt embeddings in one shot
            # (decode_steps=0 triggers a single bulk prefill in the vLLM engine)
            logging.info(f"Prefilling {prompt_len} prompt embeddings for vLLM LLM...")
            start_prefill = time.time()
            with torch.no_grad():
                _ = self.s2s_model.model_llm_interface(
                    prompt_embedded,
                    request_id=request_id,
                    decode_steps=0,
                    prompt_token_ids=None,
                )
            logging.info(f"System prompt prefilled ({prompt_len} tokens) in {time.time() - start_prefill:.3f}s")

        else:
            context, _ = self.context_manager.get_context([stream_id])
            if context.dynamic_cache is not None:
                # Native cache mode: process prompt through LLM to update KV cache
                with torch.no_grad():
                    llm_cache = context.dynamic_cache
                    ans = self.s2s_model.model_llm_interface(
                        prompt_embedded,
                        cache=llm_cache,
                        generated_tokens=None,
                        current_step=0
                    )
                    context.dynamic_cache = ans.get("cache", llm_cache)
                logging.info(f"System prompt processed, cache updated ({prompt_len} tokens)")
            else:
                # No-cache mode (e.g. Nemotron): add prompt embeddings to history
                for t in range(prompt_len):
                    context.input_embeds_history.append(prompt_embedded[:, t:t+1, :])
                logging.info(f"Added {prompt_len} prompt embeddings to input_embeds_history")

        return tts_output_code

    def _padding_remaining_secs(self, elapsed_secs: float) -> float:
        """Return how many seconds of silence padding are still needed.

        Exactly one of the three padding modes can be active (enforced in
        ``__init__``); returns 0.0 when padding is disabled.
        """
        if self.pad_audio_to_sec is not None:
            return max(0.0, self.pad_audio_to_sec - elapsed_secs)
        if self.pad_silence_ratio is not None:
            return elapsed_secs * self.pad_silence_ratio
        if self.pad_audio_by_sec is not None:
            return self.pad_audio_by_sec
        return 0.0

    def _request_id_for_stream(self, stream_id: int) -> str:
        """Map a stream id to the engine request id (plain string form)."""
        return str(stream_id)

    def _abort_stream_request(self, stream_id: int) -> None:
        """Best-effort abort of the engine request backing *stream_id*.

        A missing ``abort_request`` on the model (e.g. native engine) is a
        no-op; failures are logged but never propagated.
        """
        request_id = self._request_id_for_stream(stream_id)
        abort_fn = getattr(self.s2s_model, "abort_request", None)
        if callable(abort_fn):
            try:
                abort_fn(request_id)
            except Exception as exc:
                logging.warning(f"Failed to abort request {request_id} for stream {stream_id}: {exc}")
diff --git a/nemo/collections/speechlm2/inference/streaming/framing/__init__.py b/nemo/collections/speechlm2/inference/streaming/framing/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/framing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py new file mode 100644 index 000000000000..4bbb222b1149 --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from dataclasses import dataclass
+
+
+@dataclass(slots=True)
+class S2SRequestOptions:
+    """Per-stream options for S2S inference.
+
+    Attached to the first ``Frame`` of each stream via the ``options``
+    field so that the pipeline can read per-stream configuration at the
+    start of every new audio file / Triton sequence.
+    """
+
+    system_prompt: str | None = None  # None disables LLM prompt injection for this stream
diff --git a/nemo/collections/speechlm2/inference/streaming/state/__init__.py b/nemo/collections/speechlm2/inference/streaming/state/__init__.py
new file mode 100644
index 000000000000..9e3fb699d9f6
--- /dev/null
+++ b/nemo/collections/speechlm2/inference/streaming/state/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py
new file mode 100644
index 000000000000..997975e54373
--- /dev/null
+++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py
@@ -0,0 +1,292 @@
+# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from queue import Queue
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
+
+import torch
+from transformers import DynamicCache
+
+from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache
+from nemo.utils import logging
+
+if TYPE_CHECKING:
+    from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import PerceptionCacheState
+
+
+@dataclass
+class StreamingRealtimeContext:
+    # Per-stream mutable state carried across streaming steps.
+    frame_idx: int  # number of frames already written into gen_text/gen_asr_text
+    gen_text: torch.Tensor  # [1, max_len] generated text token ids, pre-filled with text_pad_id
+    gen_asr_text: torch.Tensor  # [1, max_len] ASR text token ids, pre-filled with text_pad_id
+    gen_function_text: Optional[torch.Tensor]  # [1, max_len] function-call tokens; None when model has no function_head
+    audio_toks_buffer: Optional[torch.Tensor]  # sliding-window codec tokens; only used when codec_cache is disabled
+    input_embeds_history: List[torch.Tensor]  # per-step LLM input embeddings (no-KV-cache mode)
+    dynamic_cache: Optional[DynamicCache]  # LLM KV cache; None when cache_disabled
+    past_key_values: Any  # TTS backbone KV cache (cloned from first_tts_past_key_values_input)
+    code: Optional[torch.Tensor]  # last TTS codec code (seeded from first_tts_code_input)
+    subword_mask: Optional[torch.Tensor]  # [1, max_len] bool mask; set True for frames written so far
+    perception_cache: Optional["PerceptionCacheState"] = None  # streaming perception encoder state
+    codec_cache: Optional[CausalConv1dCache] = None  # incremental codec decoder state
+
+
+class S2SContextManager:
+    # Maps stream ids to a fixed pool of `num_slots` slots, each holding a
+    # StreamingRealtimeContext; slots are recycled through a bounded Queue.
+
+    def __init__(
+        self,
+        s2s_model,
+        num_slots: int,
+        max_len: int,
+    ):
+        self.s2s_model = s2s_model
+        self.num_slots = num_slots
+
+        # Detect Nemotron models and disable DynamicCache
+        # (they require NemotronHHybridDynamicCache which isn't supported yet)
+        self.cache_disabled = False
+        stt_model = getattr(self.s2s_model.model, "stt_model", None)
+        if stt_model is not None:
+            pretrained_llm = stt_model.cfg.get("pretrained_llm", "")
+            if "Nemotron" in pretrained_llm:
+                logging.warning(
+                    f"Detected Nemotron model ({pretrained_llm}). "
+                    "Disabling DynamicCache (Nemotron requires NemotronHHybridDynamicCache which is not yet supported)."
+                )
+                self.cache_disabled = True
+        self.max_len = max_len
+        # All of these fall back to safe defaults when the wrapper does not expose them.
+        self.device = getattr(self.s2s_model, "device", torch.device("cpu"))
+        self.dtype = getattr(self.s2s_model, "dtype", torch.float32)
+        self.text_pad_id = getattr(getattr(self.s2s_model, "model", None), "text_pad_id", 0)
+        self.codec_token_history_size = int(getattr(self.s2s_model, "codec_token_history_size", 0))
+        self.decode_audio = bool(getattr(self.s2s_model, "decode_audio", False))
+        self.use_perception_cache = bool(getattr(self.s2s_model, "use_perception_cache", False))
+        self.use_codec_cache = bool(getattr(self.s2s_model, "use_codec_cache", True))
+
+        self.reset()
+
+    def reset(self) -> None:
+        """Reset all bookkeeping for a new streaming session."""
+        self.streamidx2slotidx: Dict[int, int] = {}
+        self.slotidx2streamidx: Dict[int, int] = {}
+        self.free_slots = Queue(self.num_slots)  # bounded: holds at most num_slots slot indices
+        for i in range(self.num_slots):
+            self.free_slots.put(i)
+        self.slot_contexts: List[Optional[StreamingRealtimeContext]] = [None] * self.num_slots
+
+    def _create_context(self) -> StreamingRealtimeContext:
+        """Allocate a fresh context backed by the realtime inference model."""
+        gen_text = torch.full(
+            (1, self.max_len),
+            fill_value=self.text_pad_id,
+            device=self.device,
+            dtype=torch.long,
+        )
+
+        gen_asr_text = torch.full(
+            (1, self.max_len),
+            fill_value=self.text_pad_id,
+            device=self.device,
+            dtype=torch.long,
+        )
+
+        # Function-call token buffer is only allocated when the STT model has a function head.
+        stt_model = getattr(self.s2s_model.model, "stt_model", None)
+        has_function_head = stt_model is not None and getattr(stt_model, "function_head", None) is not None
+        gen_function_text = None
+        if has_function_head:
+            gen_function_text = torch.full(
+                (1, self.max_len),
+                fill_value=self.text_pad_id,
+                device=self.device,
+                dtype=torch.long,
+            )
+
+        dynamic_cache = None if self.cache_disabled else DynamicCache()
+        audio_toks_buffer: Optional[torch.Tensor] = None
+        past_key_values: Any = None
+        code: Optional[torch.Tensor] = None
+        subword_mask: Optional[torch.Tensor] = None
+        perception_cache = None
+        codec_cache = None
+
+        if self.decode_audio and hasattr(getattr(self.s2s_model, "model", None), "tts_model"):
+            tts_model = self.s2s_model.model.tts_model
+            if self.use_codec_cache:
+                # Incremental decode path: CausalConv1dCache maintains all codec
+                # context internally, so no audio_toks_buffer is needed and
+                # codec_token_history_size is irrelevant.
+                codec_cache = CausalConv1dCache()
+            elif self.codec_token_history_size > 0:
+                # Sliding-window fallback: allocate silence buffer of
+                # codec_token_history_size tokens that is re-decoded every step.
+                silence_tokens_base = tts_model.codec_silence_tokens.detach().clone()
+                silence_tokens = silence_tokens_base.view(1, 1, -1).expand(
+                    -1, self.codec_token_history_size, -1
+                ).contiguous()  # contiguous() ensures it's a real copy, not a view
+                audio_toks_buffer = silence_tokens.to(self.device).clone()
+            subword_mask = torch.ones((1, self.max_len), device=self.device, dtype=torch.bool)
+
+            # Seed TTS state from the speaker-reference inputs prepared by the wrapper, if present.
+            if getattr(self.s2s_model, "first_tts_past_key_values_input", None) is not None:
+                past_key_values = self.s2s_model._clone_cache(self.s2s_model.first_tts_past_key_values_input)
+            if getattr(self.s2s_model, "first_tts_code_input", None) is not None:
+                code = self.s2s_model.first_tts_code_input.detach().clone()
+
+        # Initialize perception cache if enabled
+        if self.use_perception_cache:
+            mgr = getattr(self.s2s_model, "perception_cache_mgr", None)
+            if mgr is not None:
+                perception_cache = mgr.get_initial_state(batch_size=1)
+
+        return StreamingRealtimeContext(
+            frame_idx=0,
+            gen_text=gen_text,
+            gen_asr_text=gen_asr_text,
+            gen_function_text=gen_function_text,
+            audio_toks_buffer=audio_toks_buffer,
+            input_embeds_history=[],
+            dynamic_cache=dynamic_cache,
+            past_key_values=past_key_values,
+            code=code,
+            subword_mask=subword_mask,
+            perception_cache=perception_cache,
+            codec_cache=codec_cache,
+        )
+
+    def _ensure_slot(self, stream_id: int) -> int:
+        # Return the slot index for stream_id, allocating one on first sight.
+        if stream_id not in self.streamidx2slotidx:
+            if self.free_slots.empty():
+                # Emergency cleanup: force-release all slots for a fresh start
+                # This handles cases where previous streams didn't end properly
+                # (e.g., exceptions, client disconnects, missing is_last=True)
+                # NOTE(review): this also drops contexts of streams that are still
+                # live — acceptable only if slot exhaustion implies orphans; confirm.
+                logging.warning(f"No free slots available - forcing cleanup of all {self.num_slots} slots")
+                orphaned_streams = list(self.slotidx2streamidx.values())
+                if orphaned_streams:
+                    logging.warning(f"Orphaned streams being cleaned up: {orphaned_streams}")
+                for slot_idx in range(self.num_slots):
+                    self.reset_slot(slot_idx)
+            slot_idx = self.free_slots.get()
+            # Ensure the slot is completely clean before assigning to new stream
+            if self.slot_contexts[slot_idx] is not None:
+                logging.warning(f"Slot {slot_idx} was not properly cleaned. Forcing cleanup.")
+                self.slot_contexts[slot_idx] = None
+            self.streamidx2slotidx[stream_id] = slot_idx
+            self.slotidx2streamidx[slot_idx] = stream_id
+        return self.streamidx2slotidx[stream_id]
+
+    def reset_slot(self, slot_idx: int) -> None:
+        """Release a slot back to the pool."""
+        # NOTE(review): free_slots is a bounded Queue; calling this on a slot
+        # that is already free re-enqueues its index and a later put() on the
+        # full queue would block — callers must only release allocated slots.
+        if slot_idx < 0 or slot_idx >= self.num_slots:
+            return
+        # Set to None to break reference and allow garbage collection
+        self.slot_contexts[slot_idx] = None
+        stream_id = self.slotidx2streamidx.get(slot_idx)
+        if stream_id is not None:
+            del self.slotidx2streamidx[slot_idx]
+            del self.streamidx2slotidx[stream_id]
+        self.free_slots.put(slot_idx)
+
+    def update_context(
+        self,
+        stream_ids: List[int],
+        step_result: Dict[str, Any],
+        num_frames: int,
+    ) -> None:
+        """Persist model outputs back into the cached context."""
+        if len(stream_ids) == 0:
+            return
+        if len(stream_ids) != 1:
+            raise NotImplementedError("update_context currently supports batch_size == 1")
+
+        stream_id = stream_ids[0]
+        slot_idx = self.streamidx2slotidx.get(stream_id)
+        if slot_idx is None:
+            raise RuntimeError(f"Stream {stream_id} is not registered in the context manager")
+
+        context = self.slot_contexts[slot_idx]
+        if context is None:
+            context = self._create_context()
+            self.slot_contexts[slot_idx] = context
+
+        start_idx = context.frame_idx
+        end_idx = start_idx + num_frames
+        if end_idx > context.gen_text.shape[1]:
+            raise RuntimeError(
+                "Context maximum length exceeded. Consider increasing `streaming.max_len` in the configuration."
+            )
+
+        # Normalize each token tensor to a [1, num_frames] slice before writing.
+        predicted_tokens = step_result.get("predicted_text_tokens")
+        if predicted_tokens is not None:
+            if predicted_tokens.dim() == 1:
+                token_slice = predicted_tokens.unsqueeze(0)
+            else:
+                token_slice = predicted_tokens[0:1]
+            context.gen_text[:, start_idx:end_idx] = token_slice.to(context.gen_text.device)
+
+        asr_predicted_tokens = step_result.get("asr_predicted_text_tokens")
+        if asr_predicted_tokens is not None:
+            if asr_predicted_tokens.dim() == 1:
+                asr_token_slice = asr_predicted_tokens.unsqueeze(0)
+            else:
+                asr_token_slice = asr_predicted_tokens[0:1]
+            context.gen_asr_text[:, start_idx:end_idx] = asr_token_slice.to(context.gen_asr_text.device)
+
+        func_predicted_tokens = step_result.get("function_predicted_text_tokens")
+        if func_predicted_tokens is not None and context.gen_function_text is not None:
+            if func_predicted_tokens.dim() == 1:
+                func_token_slice = func_predicted_tokens.unsqueeze(0)
+            else:
+                func_token_slice = func_predicted_tokens[0:1]
+            context.gen_function_text[:, start_idx:end_idx] = func_token_slice.to(context.gen_function_text.device)
+
+        context.frame_idx = end_idx
+
+        # Carry forward any state the model step returned; absent keys leave the
+        # previous value untouched.
+        if step_result.get("dynamic_cache") is not None:
+            context.dynamic_cache = step_result["dynamic_cache"]
+        if "audio_toks_buffer" in step_result:
+            context.audio_toks_buffer = step_result["audio_toks_buffer"]
+        if "input_embeds_history" in step_result:
+            context.input_embeds_history = step_result["input_embeds_history"]
+        if "past_key_values" in step_result:
+            context.past_key_values = step_result["past_key_values"]
+        if "code" in step_result:
+            context.code = step_result["code"]
+        if context.subword_mask is not None:
+            context.subword_mask[:, start_idx:end_idx] = True
+        if "perception_cache" in step_result and step_result["perception_cache"] is not None:
+            context.perception_cache = step_result["perception_cache"]
+        if "codec_cache" in step_result and step_result["codec_cache"] is not None:
+            context.codec_cache = step_result["codec_cache"]
+
+    def reset_slots(self, stream_ids: List[int], eos_flags: List[bool]) -> None:
+        """Release contexts for streams that signalled end-of-stream."""
+        if len(stream_ids) != len(eos_flags):
+            raise ValueError("stream_ids and eos_flags must have the same length")
+        for stream_id, eos_flag in zip(stream_ids, eos_flags):
+            if eos_flag and stream_id in self.streamidx2slotidx:
+                self.reset_slot(self.streamidx2slotidx[stream_id])
+
+    def get_context(self, stream_ids: List[int]) -> Tuple[StreamingRealtimeContext, Dict[int, int]]:
+        """Return the cached context associated with the provided stream ids."""
+        if len(stream_ids) == 0:
+            # NOTE(review): returns a throwaway context that is never stored in a
+            # slot — presumably only used as a neutral placeholder; confirm callers.
+            return self._create_context(), {}
+        if len(stream_ids) != 1:
+            raise NotImplementedError("get_context currently supports batch_size == 1")
+
+        stream_id = stream_ids[0]
+        slot_idx = self._ensure_slot(stream_id)
+
+        if self.slot_contexts[slot_idx] is None:
+            self.slot_contexts[slot_idx] = self._create_context()
+
+        return self.slot_contexts[slot_idx], {slot_idx: 0}
diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py
new file mode 100644
index 000000000000..83ec8f09b439
--- /dev/null
+++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import List, Any, Optional
+
+import torch
+from nemo.collections.asr.inference.utils.text_segment import Word
+
+
+@dataclass
+class S2SStreamingState:
+    """
+    State for streaming speech generation.
+
+    This dataclass stores streaming tensors and counters used during
+    incremental generation. It keeps initialization metadata so it can be
+    reset to a clean state on demand.
+    """
+    # Initialization metadata (required)
+    device: torch.device
+    dtype: torch.dtype
+    max_len: int
+    num_audio_codebooks: int
+    output_sample_rate: int
+
+    # Runtime tensors (initialized in __post_init__)
+    audio_buffer: torch.Tensor = field(init=False)
+
+    # Accumulated text output
+    output_text_str: str = ""
+    output_text_tokens: List[str] = field(default_factory=list)
+    # Accumulated ASR text output
+    output_asr_text_str: str = ""
+    output_asr_text_tokens: List[str] = field(default_factory=list)
+    # Accumulated words with timings
+    output_words: List[Word] = field(default_factory=list)
+    # Final token tensors saved from the context before it is destroyed.
+    # Used for post-hoc tokens_to_str / tokens_to_str_raw conversion.
+    final_gen_text: Optional[torch.Tensor] = None
+    final_gen_asr_text: Optional[torch.Tensor] = None
+    # Declared as a field (and cleared in reset()) so a stale tensor from a
+    # previous response cannot leak across sessions; previously this attribute
+    # was only created inside save_token_tensors().
+    final_gen_function_text: Optional[torch.Tensor] = None
+    final_total_frames: int = 0
+
+    def __post_init__(self) -> None:
+        """Allocate tensors lazily based on provided metadata."""
+        with torch.no_grad():
+            # Empty 2D buffer: shape (1, 0). Will be appended over time.
+            self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype)
+
+    def reset(self) -> None:
+        """Reset all tensors and counters to their initial state."""
+        with torch.no_grad():
+            self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype)
+        self.output_text_str = ""
+        self.output_text_tokens.clear()
+        self.output_asr_text_str = ""
+        self.output_asr_text_tokens.clear()
+        self.output_words.clear()
+        self.final_gen_text = None
+        self.final_gen_asr_text = None
+        self.final_gen_function_text = None  # FIX: previously left stale across resets
+        self.final_total_frames = 0
+
+    def update_state(self, processed_frames: torch.Tensor, output_text_tokens: Any = None, output_text: str | None = None, output_asr_text: str | None = None) -> None:
+        """Append new audio to the right of the buffer; token/text args are accepted for API compatibility.
+
+        Args:
+            processed_frames: New audio samples to append; reshaped to [1, T].
+            output_text_tokens: Unused; kept for API compatibility.
+            output_text: Optional text emitted alongside this audio chunk; a Word
+                with naive sample-based timing is recorded for it.
+            output_asr_text: Optional ASR text emitted alongside this chunk.
+        """
+        if processed_frames is None:
+            return
+        if not isinstance(processed_frames, torch.Tensor):
+            raise TypeError("processed_frames must be a torch.Tensor")
+        with torch.no_grad():
+            # Ensure 2D [1, T] layout by flattening extra dims
+            append_tensor = processed_frames
+            if append_tensor.dim() > 1:
+                append_tensor = append_tensor.reshape(1, -1)
+            elif append_tensor.dim() == 1:
+                append_tensor = append_tensor.unsqueeze(0)
+            prior_samples = int(self.audio_buffer.shape[-1])
+            appended_samples = int(append_tensor.shape[-1])
+            self.audio_buffer = torch.cat([self.audio_buffer, append_tensor.to(self.device, dtype=self.dtype)], dim=-1)
+
+        # Accumulate text output if provided and create a Word with naive timing
+        if isinstance(output_text, str) and output_text:
+            self.output_text_tokens.append(output_text)
+            # Directly concatenate - spacing is already handled by tokenizer (Ġ → space)
+            self.output_text_str += output_text
+            try:
+                if appended_samples > 0 and self.output_sample_rate > 0:
+                    start_t = float(prior_samples) / float(self.output_sample_rate)
+                    end_t = float(prior_samples + appended_samples) / float(self.output_sample_rate)
+                    self.output_words.append(Word(text=output_text, start=start_t, end=end_t, conf=1.0))
+            except Exception:
+                # Best-effort timing only: never fail audio/text accumulation over a Word record.
+                pass
+
+        if isinstance(output_asr_text, str) and output_asr_text:
+            self.output_asr_text_tokens.append(output_asr_text)
+            self.output_asr_text_str += output_asr_text
+
+    @property
+    def speech_frames(self) -> List[torch.Tensor]:
+        """Backward-compatible view for code expecting a list of chunks."""
+        return [self.audio_buffer]
+
+    def get_output_text(self) -> str:
+        """Return accumulated text as a single string."""
+        return self.output_text_str
+
+    def get_output_asr_text(self) -> str:
+        """Return accumulated ASR text as a single string."""
+        return self.output_asr_text_str
+
+    def get_output_words(self) -> List[Word]:
+        """Return accumulated words with timings."""
+        return list(self.output_words)
+
+    def save_token_tensors(self, gen_text: torch.Tensor, gen_asr_text: torch.Tensor, total_frames: int,
+                           gen_function_text: torch.Tensor = None) -> None:
+        """Snapshot the full token-ID tensors from the context before it is destroyed."""
+        with torch.no_grad():
+            self.final_gen_text = gen_text[:, :total_frames].clone().cpu()
+            self.final_gen_asr_text = gen_asr_text[:, :total_frames].clone().cpu()
+            self.final_total_frames = total_frames
+            self.final_gen_function_text = (
+                gen_function_text[:, :total_frames].clone().cpu()
+                if gen_function_text is not None else None
+            )
+
+    def get_token_tensors(self) -> Optional[tuple]:
+        """Return (gen_text, gen_asr_text, total_frames, gen_function_text) or None if not saved."""
+        if self.final_gen_text is None:
+            return None
+        # Field is now always declared, so no getattr fallback is needed.
+        return self.final_gen_text, self.final_gen_asr_text, self.final_total_frames, self.final_gen_function_text
+
+    def cleanup_after_response(self) -> None:
+        """Clear transient audio; keep token workspaces allocated."""
+        with torch.no_grad():
+            self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype)
diff --git a/nemo/collections/speechlm2/inference/utils/__init__.py b/nemo/collections/speechlm2/inference/utils/__init__.py
new file mode 100644
index 000000000000..9e3fb699d9f6
--- /dev/null
+++ b/nemo/collections/speechlm2/inference/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py
new file mode 100644
index 000000000000..7022c6783975
--- /dev/null
+++ b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import List, Optional
+
+from nemo.collections.asr.inference.utils.text_segment import Word
+
+
+def clean_pred_text(text: str) -> str:
+    """Clean prediction text by removing special markers, timestamps, punctuation, and lowercasing.
+
+    Useful for fair WER comparison between predicted and ground-truth text.
+    """
+    if not text:
+        return ""
+    text = text.lstrip('^')  # strip leading '^' marker characters
+    # NOTE(review): empty pattern below is a no-op — the intended marker regex
+    # appears to have been lost; restore the original pattern.
+    text = re.sub(r'', '', text)
+    text = re.sub(r'<\$[\d.]+\$>', '', text)  # timestamp markers like <$1.23$>
+    text = re.sub(r'<\|[\d.]+\|>', '', text)  # timestamp markers like <|1.23|>
+    # NOTE(review): empty pattern below is also a no-op — verify intended regex.
+    text = re.sub(r'', '', text)
+    text = text.replace('\u0120', ' ')  # BPE word-boundary marker 'Ġ' → space
+    text = text.lower()
+    text = re.sub(r'[^\w\s]', '', text)  # drop punctuation
+    return ' '.join(text.split())  # collapse whitespace
+
+
+class PipelineOutput:
+    """
+    Class to store the output of the S2S pipeline.
+    """
+
+    def __init__(
+        self,
+        texts: Optional[List[str]] = None,
+        words: Optional[List[List[Word]]] = None,
+        asr_texts: Optional[List[str]] = None,
+        texts_with_timestamps: Optional[List[str]] = None,
+        asr_texts_with_timestamps: Optional[List[str]] = None,
+        raw_texts: Optional[List[str]] = None,
+        raw_asr_texts: Optional[List[str]] = None,
+    ):
+        # At least one primary output form (plain text or timed words) is required.
+        if texts is None and words is None:
+            raise ValueError("At least one of the 'texts' or 'words' should be provided.")
+        self.texts = texts
+        self.words = words
+        self.asr_texts = asr_texts
+        self.texts_with_timestamps = texts_with_timestamps
+        self.asr_texts_with_timestamps = asr_texts_with_timestamps
+        self.raw_texts = raw_texts
+        self.raw_asr_texts = raw_asr_texts
diff --git a/nemo/collections/speechlm2/inference/vllm/__init__.py b/nemo/collections/speechlm2/inference/vllm/__init__.py
new file mode 100644
index 000000000000..9e3fb699d9f6
--- /dev/null
+++ b/nemo/collections/speechlm2/inference/vllm/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py b/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py new file mode 100644 index 000000000000..abb722531319 --- /dev/null +++ b/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py @@ -0,0 +1,256 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
import json
import argparse

import torch
from omegaconf import OmegaConf, DictConfig
from safetensors.torch import save_file, load_file
from transformers import AutoConfig
from nemo.utils import logging

from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS


def parse_args():
    """Parse CLI arguments: training config JSON, checkpoint path, output dir."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--outdir", type=str, required=True)
    return parser.parse_args()


def convert(outdir, config, model_path):
    """Convert a NeMo DuplexEARTTS checkpoint into a vLLM-loadable directory.

    Produces in ``outdir``:
      - model.safetensors             repacked weights under a ``model.`` prefix
      - model.safetensors.index.json  single-shard weight map
      - config.json                   flat EarTTS config for the vLLM model

    Args:
        outdir: Output directory (created if missing).
        config: Path to the training config JSON; only the
            ``model.speech_generation`` subtree is used.
        model_path: Path to the safetensors checkpoint file.
    """
    os.makedirs(outdir, exist_ok=True)

    # load config
    with open(config, "r") as f:
        config_dict = json.load(f)["model"]["speech_generation"]
    cfg = DictConfig(config_dict)
    # config modification that is needed to run inference
    # (note: 'use_unshifthed_prompt' is the actual config key spelling)
    cfg.model.tts_config.use_unshifthed_prompt = True
    cfg.data.add_audio_prompt_after_description = True
    cfg.model.subword_mask_exactly_as_eartts = False
    cfg.model.context_hidden_mask_exactly_as_eartts = False
    cfg.model.tts_config.disable_eos_prediction = True
    cfg.model.inference_force_speech_silence_on_eos = True
    cfg.model.use_word_sep_tokenizer = False
    cfg.model.num_delay_speech_tokens = 0
    cfg.data.source_sample_rate = 22050
    cfg.data.target_sample_rate = 22050
    cfg.model.pretrained_model = None

    # Compatibility fix: remove 'pretrained_tokenizer_name' from cas_config
    # (the new codebase's CharAwareSubwordEncoder no longer accepts this parameter;
    # NemotronVoiceChat.__init__ handles this, but we bypass it here)
    _pretrained_tokenizer_name = None
    if hasattr(cfg.model, "tts_config") and hasattr(cfg.model.tts_config, "cas_config"):
        _pretrained_tokenizer_name = cfg.model.tts_config.cas_config.get("pretrained_tokenizer_name", None)
        if _pretrained_tokenizer_name is not None:
            del cfg.model.tts_config.cas_config.pretrained_tokenizer_name

    model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True)).eval()
    # get subword encoder vocabs and config
    subword_id_to_char_ids = model.tts_model.embed_subword.subword_id_to_char_ids
    char_vocab = model.tts_model.embed_subword.char_vocab
    # create weights for the embedding layers that convert subword ids to char ids
    vocab_size = len(subword_id_to_char_ids)
    max_char_len = max(len(char_ids) for char_ids in subword_id_to_char_ids.values())
    hidden_size = cfg.model.tts_config.backbone_config.hidden_size

    # load checkpoint (safetensors format only)
    weights = load_file(model_path)
    # select tts model weights, strip off one nested layer
    weights = {k[len("tts_model."):]: v for k, v in weights.items() if "tts_model." in k}

    # duplicate weights for rvq embeddings and embed code
    rvq_embs_weight = weights["tts_model.rvq_embs"].clone()  # 31 x codebook_size x latent_size
    rvq_embs_weight_pad = torch.nn.functional.pad(rvq_embs_weight, [0, 0, 0, 1])  # 31 x (codebook_size + 1) x latent_size
    embed_code_weight = weights["tts_model.embed_code.weight"].clone()  # latent_size x hidden_size

    # ======================
    # embedding module weights
    bos_emb = weights["tts_model.bos_emb"]
    null_emb = weights["tts_model.null_emb"]
    embed_subwords_weight = torch.zeros(
        (vocab_size, max_char_len), dtype=bos_emb.dtype, device=bos_emb.device
    )
    embed_subwords_mask_weight = torch.zeros(
        (vocab_size, max_char_len), dtype=bos_emb.dtype, device=bos_emb.device
    )
    for subword_id_str, char_ids_lst in subword_id_to_char_ids.items():
        subword_id = int(subword_id_str)
        char_ids = torch.tensor(
            char_ids_lst, dtype=bos_emb.dtype, device=bos_emb.device
        )
        embed_subwords_weight[subword_id, : len(char_ids)] = char_ids
        embed_subwords_mask_weight[subword_id, : len(char_ids)] = 1

    # create weights for the embedding model that runs outside of the eartts
    embedding_module_weights = {}
    embedding_module_weights["bos_emb"] = bos_emb
    embedding_module_weights["null_emb"] = null_emb

    # embedding transformer has a lot of weights
    for key, weight in weights.items():
        if "tts_model.embed_subword" in key:
            key = key[len("tts_model.") :]
            # bos_eos_emb and subword_flag_emb are moved outside embed_subword
            if key.startswith("embed_subword.bos_eos_emb.") or key.startswith("embed_subword.subword_flag_emb."):
                key = key[len("embed_subword."):]
            embedding_module_weights[key] = weight
    for key, weight in weights.items():
        if "tts_model.gated_fusion_audio_text" in key:
            key = key[len("tts_model.") :]
            embedding_module_weights[key] = weight
    if "tts_model.audio_prompt_projection_W" in weights:
        embedding_module_weights["audio_prompt_projection_W"] = weights["tts_model.audio_prompt_projection_W"]
    embedding_module_weights["embed_subword.embed_subwords.weight"] = (
        embed_subwords_weight
    )
    embedding_module_weights["embed_subword.embed_subwords_mask.weight"] = (
        embed_subwords_mask_weight
    )
    for i in range(rvq_embs_weight_pad.shape[0]):
        embedding_module_weights[f"rvq_embs.{i}.weight"] = rvq_embs_weight_pad[i]
    embedding_module_weights["embed_code.weight"] = embed_code_weight
    embedding_module_weights = {
        f"total_emb.{k}": v for k, v in embedding_module_weights.items()
    }

    # ======================
    # gemma backbone weights
    backbone_module_weights = {k[len("tts_model."):]: v for k, v in weights.items() if k.startswith("tts_model.backbone.")}
    # vLLM requires an embed_tokens table; the real embeddings live in the
    # external embedding module, so a 1-row random placeholder is enough
    backbone_module_weights["backbone.embed_tokens.weight"] = torch.randn(1, hidden_size, dtype=bos_emb.dtype, device=bos_emb.device)

    # ======================
    # sampler weights
    used_keys = ["rvq_embs", "embed_code", "mog_head"]
    sampler_weights = {k[len("tts_model."):]: v for k, v in weights.items() if any(k.startswith(f"tts_model.{key}") for key in used_keys)}
    sampler_weights = {"sampler." + k: v for k, v in sampler_weights.items()}

    # combine embedding module and backbone module weights
    weights = {**embedding_module_weights, **backbone_module_weights, **sampler_weights}
    weights = {"model." + k: v for k, v in weights.items()}

    # save weights
    safetensors_path = os.path.join(outdir, "model.safetensors")
    save_file(weights, safetensors_path)
    logging.info("Saved weights for vllm model")
    weight_map = {name: "model.safetensors" for name in weights.keys()}
    index = {
        "metadata": {
            "total_size": sum(w.numel() * w.element_size() for w in weights.values())
        },
        "weight_map": weight_map,
    }
    index_path = os.path.join(outdir, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)
    logging.info("Saved model index")

    # save config.json
    flat_config = {"architectures": ["EarTTSForCausalLM"], "model_type": "eartts"}
    # not using vocab size of the backbone model
    flat_config["vocab_size"] = 1

    # Parse backbone config exactly as NeMo does to get all defaults from transformers
    backbone_type = cfg.model.tts_config.get("backbone_type", None)
    backbone_config_dict = OmegaConf.to_container(
        cfg.model.tts_config.backbone_config, resolve=True
    ) if cfg.model.tts_config.get("backbone_config") else {}

    # Create AutoConfig the same way NeMo does - this fills in all defaults
    parsed_backbone_config = AutoConfig.for_model(backbone_type, **backbone_config_dict)

    # Store the backbone type for vllm to use
    flat_config["backbone_type"] = backbone_type

    # Forward all backbone configs from the parsed AutoConfig (includes defaults)
    for key in [
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
        "max_position_embeddings",
        "rope_theta",
        "rope_local_base_freq",
        "sliding_window",
        "layer_types",
    ]:
        if hasattr(parsed_backbone_config, key):
            value = getattr(parsed_backbone_config, key)
            # convert to list if it's a tuple or other iterable (except str)
            if hasattr(value, '__iter__') and not isinstance(value, (str, dict)):
                value = list(value)
            flat_config[key] = value
    # forward overall configs
    for key in ["latent_size", "codebook_size", "num_quantizers", "exponent"]:
        flat_config[key] = cfg.model.tts_config[key]
    # forward mog head configs
    for key in ["num_layers", "low_rank", "num_predictions", "min_log_std", "eps"]:
        flat_config[f"mog_{key}"] = cfg.model.tts_config.mog_head_config[key]

    # forward inference configs (with name mapping for vLLM model)
    # num_iter is hardcoded to 8 in native model's _get_generation_config
    flat_config["num_iter"] = 8
    flat_config["noise_scale"] = cfg.model.get("inference_noise_scale", 0.8)
    flat_config["top_p_or_k"] = cfg.model.get("inference_top_p_or_k", 0.8)
    flat_config["guidance_scale"] = cfg.model.get("inference_guidance_scale", 0.5)

    # configuration of the embedding module
    flat_config["emb_backbone_config"] = OmegaConf.to_container(
        cfg.model.tts_config.cas_config.backbone_config, resolve=True
    )
    flat_config["emb_backbone_type"] = cfg.model.tts_config.cas_config.backbone_type
    flat_config["emb_vocab_size"] = vocab_size
    flat_config["emb_char_vocab_size"] = len(char_vocab)
    flat_config["max_char_len"] = max_char_len

    # configuration of flag embeddings
    flat_config["pretrained_tokenizer_name"] = _pretrained_tokenizer_name
    flat_config["use_subword_flag_emb"] = cfg.model.tts_config.use_subword_flag_emb
    flat_config["use_bos_eos_emb"] = cfg.model.tts_config.use_bos_eos_emb
    flat_config["use_gated_fusion_for_text_audio"] = cfg.model.tts_config.use_gated_fusion_for_text_audio
    flat_config["use_audio_prompt_frozen_projection"] = cfg.model.tts_config.use_audio_prompt_frozen_projection
    # hardcode enabling guidance so emb is created and application
    # of cfg is captured into a cuda graph
    flat_config["enable_guidance"] = True

    # configuring custom inputs/outputs
    flat_config["custom_input_specs"] = [
        {
            "name": "acoustic_tokens",
            "dim": flat_config["num_quantizers"],
            "dtype": "int32",
        },
        {"name": "text_tokens", "dtype": "int32"},
        {"name": "text_mask"},
        {"name": "bos_mask"},
    ]
    flat_config["custom_outputs"] = ["acoustic_tokens"]

    with open(os.path.join(outdir, "config.json"), "w") as f:
        json.dump(flat_config, f, indent=2)
    logging.info("Saved vllm config")


if __name__ == "__main__":
    args = parse_args()
    convert(args.outdir, args.config, args.model)
"""
Convert NeMo STT checkpoint to HuggingFace format compatible with vLLM.

This script extracts weights from a NeMo checkpoint that has the structure:
- stt_model.llm.layers.*
- stt_model.lm_head.*
- stt_model.asr_head.*
- stt_model.embed_asr_tokens.*
- stt_model.embed_tokens.*

And converts them to HuggingFace format that can be loaded by vLLM with the
custom WeightsMapper defined in nemotron_h.py.
"""

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Optional

import torch
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoTokenizer
from nemo.utils import logging

# Short dtype aliases accepted on the CLI, mapped to the canonical torch
# dtype names expected downstream (the streaming engine resolves the spec
# dtype via getattr(torch, dtype), and torch has no attribute "fp16").
_DTYPE_ALIASES = {"fp32": "float32", "fp16": "float16", "bf16": "bfloat16"}


def load_checkpoint(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """
    Load checkpoint from safetensors or PyTorch format.

    Args:
        checkpoint_path: Path to the checkpoint file (.safetensors or .pt/.pth)

    Returns:
        Dictionary of tensor names to tensors
    """
    # A directory is assumed to contain a single-shard safetensors file
    if os.path.isdir(checkpoint_path):
        checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")

    if checkpoint_path.endswith('.safetensors'):
        logging.info(f"Loading safetensors from {checkpoint_path}")
        return load_file(checkpoint_path)
    else:
        logging.info(f"Loading PyTorch checkpoint from {checkpoint_path}")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints here.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        # Handle different checkpoint formats
        if 'state_dict' in ckpt:
            return ckpt['state_dict']
        elif 'model' in ckpt:
            return ckpt['model']
        else:
            return ckpt


def filter_tensors(
    state_dict: Dict[str, torch.Tensor],
    prefixes_to_keep: List[str]
) -> Dict[str, torch.Tensor]:
    """
    Filter tensors to keep only those with specified prefixes.

    Args:
        state_dict: Full state dictionary
        prefixes_to_keep: List of prefixes to keep (e.g., ["stt_model.llm", "stt_model.asr_head"])

    Returns:
        Filtered state dictionary
    """
    filtered_dict = {}
    for name, tensor in state_dict.items():
        if any(name.startswith(prefix) for prefix in prefixes_to_keep):
            filtered_dict[name] = tensor
            logging.debug(f"Keeping: {name} with shape {tensor.shape}")
        else:
            logging.debug(f"Skipping: {name}")

    logging.info(f"Total tensors kept: {len(filtered_dict)}")
    return filtered_dict


def convert_nemo_to_hf_format(
    checkpoint_path: str,
    output_dir: str,
    config_path: Optional[str] = None,
    pretrained_llm: Optional[str] = None,
    tensors_to_keep: Optional[List[str]] = None,
    dtype: str = "float32",
) -> None:
    """
    Convert NeMo STT checkpoint to HuggingFace format.

    Args:
        checkpoint_path: Path to the NeMo checkpoint (.safetensors or .pt)
        output_dir: Directory to save the converted checkpoint
        config_path: Path to config.json (if None, will look in same dir as checkpoint)
        pretrained_llm: HuggingFace model name to get base config from
        tensors_to_keep: List of tensor prefixes to keep (default: all stt_model.* tensors)
        dtype: Dtype name recorded in the custom input spec
            ("float32", "float16", "bfloat16", or the fp32/fp16/bf16 aliases).
            Note: tensors themselves are saved unchanged.
    """
    # Normalize CLI aliases so downstream getattr(torch, dtype) succeeds
    dtype = _DTYPE_ALIASES.get(dtype, dtype)

    # Default prefixes to keep
    if tensors_to_keep is None:
        tensors_to_keep = [
            "stt_model.llm",
            "stt_model.lm_head",
            "stt_model.asr_head",
            "stt_model.embed_asr_tokens",
            "stt_model.embed_tokens",
        ]

    # Load config to get pretrained_llm if not provided
    if config_path is None:
        ckpt_dir = checkpoint_path if os.path.isdir(checkpoint_path) else os.path.dirname(checkpoint_path)
        config_path = os.path.join(ckpt_dir, "config.json")

    if os.path.exists(config_path):
        logging.info(f"Loading config from {config_path}")
        with open(config_path, "r") as f:
            config = json.load(f)

        try:
            pretrained_llm = config["model"]["stt"]["model"]["pretrained_llm"]
            logging.info(f"Found pretrained_llm in config: {pretrained_llm}")
        except KeyError:
            if pretrained_llm is None:
                raise ValueError(
                    "Could not find pretrained_llm in config and none provided via argument"
                )
    else:
        if pretrained_llm is None:
            raise ValueError(
                f"Config file not found at {config_path} and pretrained_llm not provided"
            )

    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Load base config from pretrained model
    logging.info(f"Loading base config from {pretrained_llm}")
    base_config = AutoConfig.from_pretrained(pretrained_llm, trust_remote_code=True)

    # Modify config for custom inputs/outputs
    custom_config = {
        "custom_input_specs": [
            {
                "name": "combined_embeds",
                "dtype": dtype,
                "dim": base_config.hidden_size
            }
        ],
        "custom_outputs": ["text_logits", "asr_tokens", "asr_logits"]
    }
    base_config.update(custom_config)

    # Load tokenizer from pretrained model
    logging.info(f"Loading tokenizer from {pretrained_llm}")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_llm, trust_remote_code=True)

    # Load checkpoint
    logging.info(f"Loading checkpoint from {checkpoint_path}")
    state_dict = load_checkpoint(checkpoint_path)

    # Filter tensors
    logging.info(f"Filtering tensors to keep prefixes: {tensors_to_keep}")
    filtered_state_dict = filter_tensors(state_dict, tensors_to_keep)

    if len(filtered_state_dict) == 0:
        raise ValueError(
            f"No tensors found with prefixes {tensors_to_keep}. "
            f"Available prefixes: {set(k.split('.')[0] for k in state_dict.keys())}"
        )

    # Save tensors
    output_model_path = output_path / "model.safetensors"
    logging.info(f"Saving tensors to {output_model_path}")
    save_file(filtered_state_dict, str(output_model_path))

    # Save config
    output_config_path = output_path / "config.json"
    logging.info(f"Saving config to {output_config_path}")
    base_config.save_pretrained(str(output_path))

    # Save tokenizer
    logging.info(f"Saving tokenizer to {output_path}")
    tokenizer.save_pretrained(str(output_path))

    logging.info(f"Conversion completed successfully! Output saved to: {output_path}")


def main():
    """CLI entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert NeMo STT checkpoint to HuggingFace format for vLLM"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to NeMo checkpoint file (.safetensors or .pt/.pth)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save converted checkpoint",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to config.json (default: same directory as checkpoint)",
    )
    parser.add_argument(
        "--pretrained-llm",
        type=str,
        default=None,
        help="HuggingFace model name to use as base (default: read from config)",
    )
    parser.add_argument(
        "--tensors-to-keep",
        type=str,
        nargs="+",
        default=None,
        help="Tensor prefixes to keep (default: all stt_model.* backbone llm related tensors)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "bfloat16", "fp32", "fp16", "bf16"],
        help="Target dtype for tensors (default: float32)",
    )

    args = parser.parse_args()

    convert_nemo_to_hf_format(
        checkpoint_path=args.checkpoint,
        output_dir=args.output_dir,
        config_path=args.config,
        pretrained_llm=args.pretrained_llm,
        tensors_to_keep=args.tensors_to_keep,
        dtype=args.dtype,
    )


if __name__ == "__main__":
    main()
"""
Speech Streaming Engine Wrapper
A clean wrapper for streaming speech-to-speech generation with custom embeddings.
"""

import os
import json
import torch
import asyncio
from typing import Optional, Dict, Any, AsyncGenerator, Tuple
from dataclasses import dataclass
from enum import Enum

from vllm.v1.engine.async_llm import AsyncLLM
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.config.model import CustomInputSpec
from vllm.attention.selector import _cached_get_attn_backend

from nemo.utils import logging


class StreamStatus(Enum):
    """Lifecycle state of a streaming generation request."""
    IDLE = "idle"
    ACTIVE = "active"
    ABORTED = "aborted"
    FINISHED = "finished"


@dataclass
class GenerationResult:
    """Result of generating a single token."""
    token_id: int
    is_finished: bool
    custom_outputs: Optional[Dict[str, torch.Tensor]] = None
    finish_reason: Optional[str] = None
    total_tokens: int = 0


@dataclass
class RequestState:
    """State for a single generation request."""
    request_id: str
    status: StreamStatus
    generated_tokens: list
    generation_iterator: Optional[AsyncGenerator] = None


class LLMStreamingEngine:
    """
    A wrapper for vLLM AsyncLLM engine that enables:
    - Easy initialization with speech model configuration
    - Start/stop streaming with custom embeddings
    - Generate one token at a time
    - Abort ongoing generation
    """

    def __init__(
        self,
        model_path: str = "/ws/ckpt/converted",
        max_model_len: int = 10240,
        gpu_memory_utilization: float = 0.8,
        trust_remote_code: bool = True,
        dtype: str = "bfloat16",
        skip_tokenizer_init: bool = False,
        **sampling_kwargs
    ):
        """
        Initialize the Speech Streaming Engine.

        Args:
            model_path: Path to the speech model (default: "/ws/ckpt/converted")
            max_model_len: Maximum sequence length (default: 10240)
            gpu_memory_utilization: GPU memory utilization ratio (default: 0.8)
            trust_remote_code: Whether to trust remote code (default: True)
            dtype: Data type for embeddings (default: "bfloat16")
            skip_tokenizer_init: Whether to skip tokenizer initialization (default: False)
            **sampling_kwargs: Additional sampling parameters (max_tokens, temperature, top_p, top_k, seed, stop, stop_token_ids, ignore_eos)
        """
        self.model_path = model_path
        self.max_model_len = max_model_len
        self.gpu_memory_utilization = gpu_memory_utilization
        self.trust_remote_code = trust_remote_code
        self.dtype = dtype
        self.skip_tokenizer_init = skip_tokenizer_init

        # Engine state
        self.engine: Optional[AsyncLLM] = None

        # Request state tracking - supports multiple concurrent requests
        self.requests: Dict[str, RequestState] = {}

        # Default sampling parameters
        default_sampling = {
            "max_tokens": 100000,  # Set very high to prevent stopping - use abort to stop explicitly
            "temperature": 0.0,
            "top_p": 0.9,
            "top_k": 50,
            "seed": None,
            "stop": [],
            "stop_token_ids": [],
            "ignore_eos": True,
        }
        default_sampling.update(sampling_kwargs)
        self.sampling_params = SamplingParams(**default_sampling)

        logging.info(f"LLMStreamingEngine initialized for model: {model_path}")

    async def initialize(self):
        """Initialize the vLLM engine with custom input specifications."""
        if self.engine is not None:
            logging.info("Engine already initialized!")
            return

        logging.info("Initializing vLLM engine...")

        # Create engine arguments
        engine_args = AsyncEngineArgs(
            model=self.model_path,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=768,
            gpu_memory_utilization=self.gpu_memory_utilization,
            trust_remote_code=self.trust_remote_code,
            mamba_ssm_cache_dtype="float32",
            dtype=self.dtype,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prefix_caching=False
        )

        # Custom input/output specs must be declared in the model config file.
        # Create engine config and read the custom input specs from it.
        vllm_config = engine_args.create_engine_config()
        self.custom_input_specs = vllm_config.model_config.custom_input_specs

        # Initialize the engine
        self.engine = AsyncLLM.from_vllm_config(vllm_config)

        logging.info("Engine initialized with custom input specs:")
        for spec in self.custom_input_specs:
            logging.info(f"  - {spec}")

    def _get_safe_prompt_tokens(self, length: int = 10) -> list[int]:
        """Generate safe prompt tokens that won't cause immediate EOS."""
        if self.engine and hasattr(self.engine, 'tokenizer') and self.engine.tokenizer:
            bos_id = getattr(self.engine.tokenizer, 'bos_token_id', 1)
        else:
            bos_id = 1

        # Mix of BOS + safe alphanumeric tokens
        safe_tokens = [bos_id] + list(range(50, 59))  # tokens 50-58 are usually safe
        return (safe_tokens * ((length // len(safe_tokens)) + 1))[:length]

    async def start_generation(
        self,
        request_id: str = "speech_stream"
    ) -> bool:
        """
        Start a new streaming generation session.

        Args:
            request_id: Unique identifier for this generation request

        Returns:
            bool: True if generation started successfully
        """
        if self.engine is None:
            raise RuntimeError("Engine not initialized! Call initialize() first.")

        # Check if request already exists
        if request_id in self.requests:
            existing_state = self.requests[request_id]
            if existing_state.status == StreamStatus.ACTIVE:
                logging.warning(f"Request {request_id} is already active. Aborting it first.")
                await self.abort_generation(request_id)

        logging.info(f"Starting generation session with request_id: {request_id}")

        # Create new request state
        self.requests[request_id] = RequestState(
            request_id=request_id,
            status=StreamStatus.ACTIVE,
            generated_tokens=[],
            generation_iterator=None  # Will be created on first generate_next_token call
        )
        return True

    async def generate_next_token(self, input_tensors: list[torch.Tensor],
                                  prompt_token_ids: Optional[list[int]] = None,
                                  request_id: str = "speech_stream") -> Optional[GenerationResult]:
        """
        Generate the next token using the provided input embedding.

        Args:
            input_tensors: List of tensors for generating the next token
            prompt_token_ids: Optional list of token IDs for the system prompt
            request_id: Unique identifier for this generation request

        Returns:
            GenerationResult or None if generation is finished/aborted
        """
        if request_id not in self.requests:
            raise RuntimeError(f"Request {request_id} not found. Call start_generation() first.")

        request_state = self.requests[request_id]

        if request_state.status != StreamStatus.ACTIVE:
            logging.warning(f"Generation not active for request {request_id} (status: {request_state.status})")
            return None

        assert len(input_tensors) == len(self.custom_input_specs), f"Expected {len(self.custom_input_specs)} input tensors, got {len(input_tensors)}"

        if self.engine is None:
            raise RuntimeError("Engine not initialized")

        custom_inputs = {}
        max_length = 1
        for i, spec in enumerate(self.custom_input_specs):
            input_dtype = spec.dtype
            if input_dtype is None:
                input_dtype = "float32"  # Default dtype
            # use identity comparison with None (was: spec.dim != None)
            if spec.dim is not None and spec.dim != input_tensors[i].shape[-1]:
                raise ValueError(f"Input tensor dimension mismatch for {spec.name}: expected {spec.dim}, got {input_tensors[i].shape[-1]}")
            custom_inputs[spec.name] = input_tensors[i].to(dtype=getattr(torch, input_dtype)).cpu()
            max_length = max(max_length, input_tensors[i].shape[0])

        try:
            # If this is the first call, initialize the generation
            if request_state.generation_iterator is None:
                # Create initial inputs with a single safe prompt token
                # this will not be used for generation, just to initialize the model state
                prompt_tokens = prompt_token_ids if prompt_token_ids is not None else self._get_safe_prompt_tokens(max_length)
                assert len(prompt_tokens) == max_length, f"Prompt tokens length {len(prompt_tokens)} does not match input length {max_length}"
                inputs = {
                    "prompt_token_ids": prompt_tokens,
                    "custom_inputs": custom_inputs
                }

                logging.info(f"Initializing generation for request {request_id} with first embedding")

                # Start generation
                request_state.generation_iterator = self.engine.generate(
                    inputs,
                    self.sampling_params,
                    request_id=request_id
                )

            # If this is not the first call, append the current embedding for processing
            elif request_state.generation_iterator is not None and len(request_state.generated_tokens) > 0:
                try:
                    await self.engine.append_request(
                        request_id=request_id,
                        custom_inputs=custom_inputs
                    )
                except ValueError as e:
                    if "not found" in str(e):
                        logging.warning(f"Request {request_id} was removed from vLLM engine. Marking as finished.")
                        request_state.status = StreamStatus.FINISHED
                        return None
                    else:
                        # preserve the exception chain for debugging
                        raise RuntimeError(f"Error appending to request {request_id}: {e}") from e

            # Get next output from the generation
            output = await request_state.generation_iterator.__anext__()

            # Extract new tokens
            current_tokens = output.outputs[0].token_ids
            if len(current_tokens) > len(request_state.generated_tokens):
                new_tokens = current_tokens[len(request_state.generated_tokens):]
                assert len(new_tokens) == 1, f"Expected exactly one new token, got {len(new_tokens)}"
                request_state.generated_tokens.extend(new_tokens)

                # Get the latest token
                latest_token = new_tokens[-1]

                # Check if finished
                if output.finished:
                    request_state.status = StreamStatus.FINISHED
                    finish_reason = output.outputs[0].finish_reason
                    logging.warning(f"Request {request_id} finished after {len(request_state.generated_tokens)} tokens. Reason: {finish_reason}")
                    return GenerationResult(
                        token_id=latest_token,
                        custom_outputs=output.outputs[0].custom_outputs if hasattr(output.outputs[0], 'custom_outputs') else None,
                        is_finished=True,
                        finish_reason=finish_reason,
                        total_tokens=len(request_state.generated_tokens)
                    )
                else:
                    return GenerationResult(
                        token_id=latest_token,
                        custom_outputs=output.outputs[0].custom_outputs if hasattr(output.outputs[0], 'custom_outputs') else None,
                        is_finished=False,
                        total_tokens=len(request_state.generated_tokens)
                    )
            else:
                # No new tokens generated
                if output.finished:
                    logging.warning("No new tokens but finished!")
                    logging.warning(output.outputs[0].finish_reason)
                    request_state.status = StreamStatus.FINISHED
                return None

        except StopAsyncIteration:
            # Generation ended
            request_state.status = StreamStatus.FINISHED
            return None
        except Exception as e:
            logging.error(f"Error in generate_next_token for request {request_id}: {e}")
            request_state.status = StreamStatus.FINISHED
            return None

    async def abort_generation(self, request_id: str = "speech_stream") -> bool:
        """
        Abort a specific generation request.

        Args:
            request_id: Unique identifier for the generation request to abort

        Returns:
            bool: True if abort was successful
        """
        if request_id not in self.requests:
            logging.warning(f"Request {request_id} not found")
            return False

        request_state = self.requests[request_id]

        if request_state.status != StreamStatus.ACTIVE:
            logging.info(f"Request {request_id} is {request_state.status.value}, cleaning up state")
            # Just remove the state, no need to abort
            del self.requests[request_id]
            return True

        try:
            await self.engine.abort(request_id)
            request_state.status = StreamStatus.ABORTED
            del self.requests[request_id]
            logging.info(f"Aborted generation for request: {request_id}")
            return True
        except Exception as e:
            logging.error(f"Error aborting generation for request {request_id}: {e}")
            return False

    async def shutdown(self):
        """Shutdown the engine and cleanup resources."""
        # Abort all active requests
        for request_id, request_state in list(self.requests.items()):
            if request_state.status == StreamStatus.ACTIVE:
                await self.abort_generation(request_id)

        if self.engine is not None:
            logging.info("Shutting down engine...")
            self.engine.shutdown()
            self.engine = None
            logging.info("Engine shutdown complete.")

        # Clear all request states
        self.requests.clear()

    def get_status(self, request_id: Optional[str] = None) -> Dict[str, Any]:
        """Get current status information.

        Args:
            request_id: If provided, return status for specific request.
                       If None, return status for all requests.

        Returns:
            Status information dictionary
        """
        if request_id is not None:
            if request_id not in self.requests:
                return {"error": f"Request {request_id} not found"}

            request_state = self.requests[request_id]
            return {
                "request_id": request_id,
                "status": request_state.status.value,
                "tokens_generated": len(request_state.generated_tokens),
                "latest_tokens": request_state.generated_tokens[-5:] if request_state.generated_tokens else []
            }
        else:
            # Return summary of all requests
            return {
                "total_requests": len(self.requests),
                "requests": {
                    rid: {
                        "status": state.status.value,
                        "tokens_generated": len(state.generated_tokens)
                    }
                    for rid, state in self.requests.items()
                }
            }

    async def __aenter__(self):
        """Async context manager entry."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.shutdown()


class EARTTSStreamingEngine(LLMStreamingEngine):
    """
    A specialized streaming engine for EARTTS models.
    Inherits from LLMStreamingEngine and sets EARTTS-specific configurations.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        guidance_scale = self._read_guidance_scale_from_config()
        default_sampling = {
            "max_tokens": 100000,  # Set very high to prevent stopping - use abort to stop explicitly
            "temperature": 0.0,
            "skip_sampling": True,
            "ignore_eos": True,
            "guidance_scale": guidance_scale,
        }
        self.sampling_params = SamplingParams(**default_sampling)
        logging.info(f"EARTTSStreamingEngine initialized (guidance_scale={guidance_scale}).")

    def _read_guidance_scale_from_config(self) -> float:
        """Read guidance_scale from the converted vLLM model's config.json."""
        config_path = os.path.join(self.model_path, "config.json")
        if os.path.isfile(config_path):
            with open(config_path, "r") as f:
                cfg = json.load(f)
            value = cfg.get("guidance_scale", None)
            if value is not None:
                logging.info(f"Read guidance_scale={value} from {config_path}")
                return float(value)
        logging.warning(
            f"guidance_scale not found in {config_path}, using default 0.5. "
        )
        return 0.5

    async def initialize(self):
        # Force TRITON_ATTN backend for EarTTS
        os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"
        # TF32 matmul precision to match TTS training ("medium").
        # torch.set_float32_matmul_precision is process-local and does NOT
        # propagate to vLLM's spawned worker processes; this CUDA-level env
        # var is inherited by child processes.
        os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
        _cached_get_attn_backend.cache_clear()
        await super().initialize()
        os.environ.pop("VLLM_ATTENTION_BACKEND", None)
        os.environ.pop("NVIDIA_TF32_OVERRIDE", None)
        _cached_get_attn_backend.cache_clear()


def create_engine(engine_type: str = "llm", **kwargs) -> LLMStreamingEngine:
    """
    Factory function to create a streaming engine instance.

    Args:
        engine_type: Type of the engine ("eartts" or "llm", default: "llm")
        **kwargs: Additional arguments for engine initialization (model_path, max_model_len, gpu_memory_utilization, trust_remote_code, dtype, and sampling parameters)
    Returns:
        An instance of LLMStreamingEngine or its subclass
    """
    if engine_type == "eartts":
        return EARTTSStreamingEngine(**kwargs)
    elif engine_type == "llm":
        return LLMStreamingEngine(**kwargs)
    else:
        raise ValueError(f"Unsupported engine_type: {engine_type}")
import torch

@torch.no_grad()
def patched_infer_codes_one_step(
    self,
    current_subword_id,
    prev_subword_id,
    current_subword_mask,
    prev_audio_tokens,
    past_key_values,
    guidance_enabled=True,
    generation_config=None,
    ignore_eos_flag_stop=True,
    request_id=None,  # change signature to include request_id
):
    """Run a single TTS decoding step, threading request_id through to the model.

    Returns the (codes, past_key_values) pair produced by ``self.tts_model``.
    """
    # The context hidden state lags current_subword_id by one step; on the
    # very first step it comes from the last warmup step. When the model has
    # no context hidden channel, pass None instead.
    ctx_hidden = None
    if self.cfg.tts_config.context_hidden_size is not None:
        ctx_hidden = self.embed_tokens(prev_subword_id)

    # Optionally substitute the codec silence codes wherever the current
    # subword is the text EOS token, forcing silence as the next audio frame.
    if self.cfg.get('inference_force_speech_silence_on_eos', True):
        eos_positions = current_subword_id.unsqueeze(-1) == self.text_eos_id
        silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(prev_audio_tokens.shape)
        prev_audio_tokens = torch.where(eos_positions, silence_codes, prev_audio_tokens)

    step_outputs = self.tts_model(
        code=prev_audio_tokens,
        context_hidden_state=ctx_hidden,
        subword_ids=current_subword_id,
        subword_mask=current_subword_mask,
        past_key_values=past_key_values,
        use_cache=True,
        guidance_enabled=guidance_enabled,
        generation_config=generation_config,
        ignore_eos_flag_stop=ignore_eos_flag_stop,
        request_id=request_id,  # pass request_id to the model
    )
    return step_outputs["codes"], step_outputs["past_key_values"]
insertions(+), 15 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml index bded9501448f..f066b989f5f2 100644 --- a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -28,7 +28,8 @@ pipeline_type: s2s_streaming s2s: model_path: ??? llm_checkpoint_path: ??? - speaker_reference: ??? + speaker_reference: null + speaker_name: null engine_type: ??? # Engine type: 'native' or 'vllm_llm' or 'vllm_eartts' or 'vllm_llm_vllm_eartts' vllm_llm_config: model_path: ${s2s.model_path} # Inherits from s2s.model_path diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 6aabcaa41a15..d7ab874b3c70 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -44,7 +44,7 @@ from nemo.collections.speechlm2.models.nemotron_voicechat import NemotronVoiceChat -from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str +from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.speechlm2.inference.model_wrappers.model_factory import create_model @@ -190,8 +190,9 @@ def __init__(self, model_cfg: DictConfig): ) self.speaker_reference = model_cfg.get("speaker_reference") - if self.decode_audio and not self.speaker_reference: - raise ValueError("`model_cfg.speaker_reference` must be provided when decode_audio is enabled.") + self.speaker_name = model_cfg.get("speaker_name", None) + if self.decode_audio and not 
self.speaker_reference and not self.speaker_name: + raise ValueError("`model_cfg.speaker_reference` or `model_cfg.speaker_name` must be provided when decode_audio is enabled.") self.tts_system_prompt = model_cfg.get("tts_system_prompt", None) logging.info(f"TTS system prompt: {self.tts_system_prompt}") @@ -572,7 +573,7 @@ def _initialize_model(self): stt = self.model.stt_model asr_boost_map = { "inference_user_pad_boost": stt.text_pad_id, - "inference_user_bos_boost": stt.user_bos_id, + "inference_user_bos_boost": stt.text_bos_id, "inference_user_eos_boost": stt.text_eos_id, } for cfg_key, token_id in asr_boost_map.items(): @@ -741,18 +742,22 @@ def _prepare_tts_initial_state(self): logging.info("Preparing TTS warmup state...") - with fp32_precision(): - speaker_audio, speaker_sr = torchaudio.load(self.speaker_reference) - speaker_audio = resample(speaker_audio, speaker_sr, self.model.tts_model.target_sample_rate) - - speaker_audio = speaker_audio.to(self.device) - speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long() + if self.speaker_name is not None: + logging.info(f"Using registered speaker name: {self.speaker_name}") + speaker_audio = None + speaker_audio_lens = None + else: + with fp32_precision(): + speaker_audio, speaker_sr = torchaudio.load(self.speaker_reference) + speaker_audio = resample(speaker_audio, speaker_sr, self.model.tts_model.target_sample_rate) + speaker_audio = speaker_audio.to(self.device) + speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long() - # init tts_model self.model.tts_model.set_init_inputs( speaker_audio=speaker_audio, speaker_audio_lens=speaker_audio_lens, system_prompt=self.tts_system_prompt, + speaker_name=self.speaker_name, ) init_inputs = self.model.tts_model.get_init_inputs(B=1) @@ -1211,7 +1216,7 @@ def _maybe_apply_forced_turn_taking(self, t, gen_text, gen_asr): # Require that the pad window starts after a non-pad token if has_pad_window and pad_lookback_start 
> 0: token_before_window = gen_asr[batch_idx, pad_lookback_start - 1] - has_pad_window = (token_before_window != self.model.stt_model.text_pad_id) and (token_before_window != self.model.stt_model.user_bos_id) + has_pad_window = (token_before_window != self.model.stt_model.text_pad_id) and (token_before_window != self.model.stt_model.text_bos_id) elif has_pad_window and pad_lookback_start == 0: # If the pad window starts at position 0, it doesn't meet the requirement has_pad_window = False @@ -1222,7 +1227,7 @@ def _maybe_apply_forced_turn_taking(self, t, gen_text, gen_asr): logging.info(f"Forced turn-taking at frame {t}: inserted agent BOS (reason: pad window)") # ASR BOS → insert agent EOS if not present in window - elif current_asr_token == self.model.stt_model.user_bos_id: + elif current_asr_token == self.model.stt_model.text_bos_id: if not (agent_text_window == self.model.stt_model.text_eos_id).any(): gen_text[batch_idx, t] = self.model.stt_model.text_eos_id logging.info(f"Forced turn-taking at frame {t}: inserted agent EOS (reason: user started speaking)") diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 39535e8d118d..8c0c4d8c3cef 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -31,7 +31,7 @@ from nemo.collections.speechlm2.inference.pipelines.s2s_pipeline_interface import S2SPipelineInterface from nemo.collections.speechlm2.inference.streaming.state.s2s_state import S2SStreamingState from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import NemotronVoicechatInferenceWrapper, tokens_to_str_raw -from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str +from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from 
nemo.collections.speechlm2.inference.streaming.state.s2s_context_manager import S2SContextManager from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput diff --git a/nemo/collections/speechlm2/models/duplex_stt_model.py b/nemo/collections/speechlm2/models/duplex_stt_model.py index d15becb2cbc8..9d49a5e6b254 100644 --- a/nemo/collections/speechlm2/models/duplex_stt_model.py +++ b/nemo/collections/speechlm2/models/duplex_stt_model.py @@ -109,6 +109,13 @@ def __init__(self, cfg: dict) -> None: self.asr_head = copy.deepcopy(self.lm_head) self.embed_asr_tokens = copy.deepcopy(self.embed_tokens) + self.use_function_head = self.cfg.get("use_function_head", False) + if self.use_function_head: + self.function_head = copy.deepcopy(self.lm_head) + logging.info("[Function Calling] Initialized function_head (deep copy of lm_head)") + else: + self.function_head = None + maybe_install_lora(self) # Load the pretrained streaming ASR model @@ -191,9 +198,22 @@ def forward( if self.cfg.get("inference_eos_boost", None): text_logits[:, :, self.text_eos_id] += self.cfg.inference_eos_boost + if self.predict_user_text: + if self.cfg.get("inference_user_pad_boost", None): + asr_logits[:, :, self.text_pad_id] += self.cfg.inference_user_pad_boost + if self.cfg.get("inference_user_bos_boost", None): + asr_logits[:, :, self.text_bos_id] += self.cfg.inference_user_bos_boost + if self.cfg.get("inference_user_eos_boost", None): + asr_logits[:, :, self.text_eos_id] += self.cfg.inference_user_eos_boost + ans = {"text_logits": text_logits} if self.predict_user_text: ans["asr_logits"] = asr_logits + if self.function_head is not None: + function_in = out['last_hidden_state'] + if function_in.dtype != self.function_head.weight.dtype: + function_in = function_in.to(self.function_head.weight.dtype) + ans["function_logits"] = self.function_head(function_in) if cache is not 
None: if 'Nemotron' in self.cfg.pretrained_llm: From 52813dea8185e09aa200d37618ce1cc4964b3f53 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Thu, 12 Mar 2026 23:40:32 +0000 Subject: [PATCH 03/40] add use_llm_cache option, will use HybridMambaAttentionDynamicCache, with patches Signed-off-by: Elena Rastorgueva --- .../conf/s2s_streaming.yaml | 1 + .../inference/model_wrappers/model_factory.py | 7 +- .../nemotron_voicechat_inference_wrapper.py | 39 ++++- .../pipelines/streaming_s2s_pipeline.py | 7 +- .../streaming/state/s2s_context_manager.py | 43 +++-- .../speechlm2/models/duplex_stt_model.py | 150 +++++++++++++++++- 6 files changed, 219 insertions(+), 28 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml index f066b989f5f2..9a8ad1bbe643 100644 --- a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -64,6 +64,7 @@ s2s: use_perception_cudagraph: true # Enable CUDA graph-accelerated perception encoder use_codec_cache: true # Incremental codec decode to remove clicking sounds and wasted computation # (when true, codec_token_history_size is unused) + use_llm_cache: true # Use KV cache for the STT LLM (DynamicCache or HybridMambaAttentionDynamicCache) # Deterministic inference (native engine only). Ensures identical results across # runs by disabling FlashAttention and forcing deterministic CUDA algorithms. 
diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index 505b60a57e56..fc082fb35ced 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -836,6 +836,7 @@ def __call__( self, input_embeds: torch.Tensor, cache: Optional[Any] = None, + cache_position: Optional[torch.Tensor] = None, generated_tokens: Optional[torch.Tensor] = None, current_step: int = 0, **kwargs @@ -845,7 +846,8 @@ def __call__( Args: input_embeds: Input embeddings [batch, seq_len, hidden_dim] - cache: Optional DynamicCache for transformers + cache: Optional DynamicCache or HybridMambaAttentionDynamicCache + cache_position: Optional position tensor for Nemotron models generated_tokens: Previously generated tokens [batch, num_generated]. Required for repetition_penalty. If None, creates empty tensor. current_step: Current decoding step. Used for repetition penalty. 
@@ -854,8 +856,7 @@ def __call__( Returns: Dictionary with 'predicted_token', 'asr_predicted_token', and 'cache' """ - # Call the underlying model - result = self.model.stt_model(input_embeds, cache=cache, **kwargs) + result = self.model.stt_model(input_embeds, cache=cache, cache_position=cache_position, **kwargs) # Ensure consistent return format if not isinstance(result, dict): diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index d7ab874b3c70..bb9da19ff321 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -255,6 +255,11 @@ def __init__(self, model_cfg: DictConfig): f"will be ignored (context is maintained incrementally by the codec cache)." ) + # LLM KV cache: when enabled, uses DynamicCache (standard) or + # HybridMambaAttentionDynamicCache (Nemotron) for incremental decoding. + # When disabled, falls back to full-history reprocessing each step. 
+ self.use_llm_cache = bool(model_cfg.get("use_llm_cache", True)) + # Perception cache configuration self.use_perception_cache = bool(model_cfg.get("use_perception_cache", False)) use_perception_cudagraph = bool(model_cfg.get("use_perception_cudagraph", False)) @@ -852,7 +857,8 @@ def infer_one_step(self, request_id: Optional[str] = None, perception_cache: Optional[PerceptionCacheState] = None, has_prompt: bool = False, - codec_cache=None): + codec_cache=None, + cache_position_offset: int = 0): # Set up effective request ID for vLLM streaming effective_request_id = request_id or self.request_id @@ -964,9 +970,13 @@ def infer_one_step(self, current_step=current_frame_idx ) else: + cache_pos = torch.tensor( + [cache_position_offset + frame_offset], device=self.device + ) ans = self.model_llm_interface( current_input_emb, cache=dynamic_cache, + cache_position=cache_pos, generated_tokens=gen_text, current_step=current_frame_idx ) @@ -1150,6 +1160,7 @@ def infer_one_step(self, 'code': code, 'perception_cache': perception_cache, 'codec_cache': codec_cache, + 'cache_position_offset': cache_position_offset + num_frames_per_chunk if use_cache else cache_position_offset, } if self.model.stt_model.function_head is not None: result['function_predicted_text_tokens'] = function_predicted_tokens @@ -1343,10 +1354,10 @@ def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: in # convert audio signal to tensor audio_signal_tensor = torch.tensor(audio_signal, dtype=self.dtype, device=self.device).unsqueeze(0) - # Check if Nemotron (no cache support) - use_cache = 'Nemotron' not in self.model.stt_model.cfg.pretrained_llm + use_cache = self.use_llm_cache + is_nemotron = 'Nemotron' in self.model.stt_model.cfg.pretrained_llm logging.info(f"Model: {self.model.stt_model.cfg.pretrained_llm}") - logging.info(f" Use cache: {use_cache}") + logging.info(f" Use LLM cache: {use_cache}, is_nemotron: {is_nemotron}") # Initialize buffer and state audio_buffer = torch.zeros(1, 
buffer_size_samples, dtype=self.dtype, device=self.device) @@ -1354,10 +1365,14 @@ def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: in # Initialize LLM cache if use_cache: - llm_cache = DynamicCache() + if is_nemotron: + llm_cache = self.model.stt_model._create_nemotron_cache(batch_size=1) + else: + llm_cache = DynamicCache() else: llm_cache = None - input_embeds_history = [] # For no-cache mode + input_embeds_history = [] + cache_position_offset = 0 # Process system prompt if provided (before streaming audio) prompt_embedded = None @@ -1399,9 +1414,11 @@ def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: in elif prompt_embedded is not None and use_cache: # For cache mode: process prompt through LLM to update cache with torch.no_grad(): - ans = self.model.stt_model(prompt_embedded, cache=llm_cache) + cache_pos = torch.arange(prompt_len, device=self.device) + ans = self.model.stt_model(prompt_embedded, cache=llm_cache, cache_position=cache_pos) llm_cache = ans.get("cache", llm_cache) - logging.info(f" System prompt processed, cache updated") + cache_position_offset = prompt_len + logging.info(f" System prompt processed, cache updated (offset={cache_position_offset})") # Initialize TTS code = None @@ -1499,6 +1516,7 @@ def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: in perception_cache=perception_cache, has_prompt=(prompt_len > 0), codec_cache=codec_cache, + cache_position_offset=cache_position_offset, ) # handle results from infer_one_step @@ -1507,6 +1525,7 @@ def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: in gen_function_text[:, frame_idx + fi] = result['function_predicted_text_tokens'][:, fi] input_embeds_history = result['input_embeds_history'] llm_cache = result['dynamic_cache'] + cache_position_offset = result.get('cache_position_offset', cache_position_offset) if self.use_perception_cache: perception_cache = result.get('perception_cache', 
perception_cache) if self.decode_audio: @@ -1635,6 +1654,9 @@ def main(): help="Enable cache-aware streaming for perception encoder") parser.add_argument("--use_perception_cudagraph", action="store_true", help="Use CUDA graphs for perception encoder (requires --use_perception_cache)") + # LLM KV cache argument + parser.add_argument("--use_llm_cache", action="store_true", + help="Use KV cache for the STT LLM (DynamicCache or HybridMambaAttentionDynamicCache for Nemotron)") # Codec streaming cache argument parser.add_argument("--use_codec_cache", action="store_true", help="Enable incremental codec decode to remove clicking sounds and wasted inference computation (recommended)") @@ -1719,6 +1741,7 @@ def main(): "deterministic": bool(args.deterministic), "use_perception_cache": bool(args.use_perception_cache), "use_perception_cudagraph": bool(args.use_perception_cudagraph), + "use_llm_cache": bool(args.use_llm_cache), "use_codec_cache": bool(args.use_codec_cache), "top_p": args.top_p, "repetition_penalty": args.repetition_penalty, diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 8c0c4d8c3cef..ed49815246a5 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -237,6 +237,7 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_p perception_cache=context.perception_cache, has_prompt=has_prompt, codec_cache=context.codec_cache, + cache_position_offset=context.cache_position_offset, ) # Persist updated cache & clean finished streams @@ -723,17 +724,19 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non if context.dynamic_cache is not None: # Native cache mode: process prompt through LLM to update KV cache with torch.no_grad(): + cache_pos = torch.arange(prompt_len, 
device=self.s2s_model.device) llm_cache = context.dynamic_cache ans = self.s2s_model.model_llm_interface( prompt_embedded, cache=llm_cache, + cache_position=cache_pos, generated_tokens=None, current_step=0 ) context.dynamic_cache = ans.get("cache", llm_cache) - logging.info(f"System prompt processed, cache updated ({prompt_len} tokens)") + context.cache_position_offset = prompt_len + logging.info(f"System prompt processed, cache updated ({prompt_len} tokens, offset={prompt_len})") else: - # No-cache mode (e.g. Nemotron): add prompt embeddings to history for t in range(prompt_len): context.input_embeds_history.append(prompt_embedded[:, t:t+1, :]) logging.info(f"Added {prompt_len} prompt embeddings to input_embeds_history") diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py index 997975e54373..beaef0990a75 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -34,12 +34,13 @@ class StreamingRealtimeContext: gen_function_text: Optional[torch.Tensor] audio_toks_buffer: Optional[torch.Tensor] input_embeds_history: List[torch.Tensor] - dynamic_cache: Optional[DynamicCache] + dynamic_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache past_key_values: Any code: Optional[torch.Tensor] subword_mask: Optional[torch.Tensor] perception_cache: Optional["PerceptionCacheState"] = None codec_cache: Optional[CausalConv1dCache] = None + cache_position_offset: int = 0 class S2SContextManager: @@ -53,18 +54,6 @@ def __init__( self.s2s_model = s2s_model self.num_slots = num_slots - # Detect Nemotron models and disable DynamicCache - # (they require NemotronHHybridDynamicCache which isn't supported yet) - self.cache_disabled = False - stt_model = getattr(self.s2s_model.model, "stt_model", None) - if stt_model is not None: - pretrained_llm = 
stt_model.cfg.get("pretrained_llm", "") - if "Nemotron" in pretrained_llm: - logging.warning( - f"Detected Nemotron model ({pretrained_llm}). " - "Disabling DynamicCache (Nemotron requires NemotronHHybridDynamicCache which is not yet supported)." - ) - self.cache_disabled = True self.max_len = max_len self.device = getattr(self.s2s_model, "device", torch.device("cpu")) self.dtype = getattr(self.s2s_model, "dtype", torch.float32) @@ -73,6 +62,24 @@ def __init__( self.decode_audio = bool(getattr(self.s2s_model, "decode_audio", False)) self.use_perception_cache = bool(getattr(self.s2s_model, "use_perception_cache", False)) self.use_codec_cache = bool(getattr(self.s2s_model, "use_codec_cache", True)) + self.use_llm_cache = bool(getattr(self.s2s_model, "use_llm_cache", True)) + + self.is_nemotron = False + stt_model = getattr(self.s2s_model.model, "stt_model", None) + if stt_model is not None: + pretrained_llm = stt_model.cfg.get("pretrained_llm", "") + if "Nemotron" in pretrained_llm: + self.is_nemotron = True + if self.use_llm_cache: + logging.info( + f"Detected Nemotron model ({pretrained_llm}). " + "Will use HybridMambaAttentionDynamicCache for KV caching." + ) + else: + logging.info( + f"Detected Nemotron model ({pretrained_llm}). " + "LLM cache is disabled (use_llm_cache=False)." 
+ ) self.reset() @@ -112,7 +119,13 @@ def _create_context(self) -> StreamingRealtimeContext: dtype=torch.long, ) - dynamic_cache = None if self.cache_disabled else DynamicCache() + if not self.use_llm_cache: + dynamic_cache = None + elif self.is_nemotron: + stt_model = getattr(self.s2s_model.model, "stt_model", None) + dynamic_cache = stt_model._create_nemotron_cache(batch_size=1) + else: + dynamic_cache = DynamicCache() audio_toks_buffer: Optional[torch.Tensor] = None past_key_values: Any = None code: Optional[torch.Tensor] = None @@ -267,6 +280,8 @@ def update_context( context.perception_cache = step_result["perception_cache"] if "codec_cache" in step_result and step_result["codec_cache"] is not None: context.codec_cache = step_result["codec_cache"] + if "cache_position_offset" in step_result: + context.cache_position_offset = step_result["cache_position_offset"] def reset_slots(self, stream_ids: List[int], eos_flags: List[bool]) -> None: """Release contexts for streams that signalled end-of-stream.""" diff --git a/nemo/collections/speechlm2/models/duplex_stt_model.py b/nemo/collections/speechlm2/models/duplex_stt_model.py index 9d49a5e6b254..f6f0919e2749 100644 --- a/nemo/collections/speechlm2/models/duplex_stt_model.py +++ b/nemo/collections/speechlm2/models/duplex_stt_model.py @@ -59,7 +59,7 @@ def maybe_rename_llm_kwargs_for_nemotron(kwargs: dict, model_cfg) -> dict: return kwargs cache = kwargs.pop("past_key_values") if cache is not None: - cache_key = model_cfg.get("cache_key", "past_key_values") + cache_key = model_cfg.get("cache_key", "cache_params") kwargs[cache_key] = cache return kwargs @@ -174,12 +174,15 @@ def forward( self, input_embeds: Tensor, cache=None, + cache_position=None, ) -> dict[str, Tensor]: """ Text prediction only (audio_loss_weight=0). 
""" kwargs = dict(inputs_embeds=input_embeds, past_key_values=cache, use_cache=cache is not None, return_dict=True) kwargs = maybe_rename_llm_kwargs_for_nemotron(kwargs, self.cfg) + if cache_position is not None: + kwargs["cache_position"] = cache_position out = self.llm(**kwargs) B, T = input_embeds.shape[:2] @@ -748,3 +751,148 @@ def load_state_dict(self, state_dict, strict: bool = True): logging.info("Error loading model state_dict !! Retrying with partial initialization!") model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) return super().load_state_dict(model_dict, strict=False) + + def _create_nemotron_cache(self, batch_size: int = 1): + """Create a HybridMambaAttentionDynamicCache for Nemotron hybrid Mamba2/Attention models.""" + import importlib + cache_cls = None + llm = self.llm + if hasattr(llm, '_orig_mod'): + llm = llm._orig_mod + model_module = type(llm).__module__ + if model_module: + mod = importlib.import_module(model_module) + cache_cls = getattr(mod, "HybridMambaAttentionDynamicCache", None) + if cache_cls is None: + raise RuntimeError( + "Could not find HybridMambaAttentionDynamicCache in the Nemotron model's module. " + "Ensure the model was loaded with trust_remote_code=True." + ) + + # Newer transformers defines key_cache/value_cache as read-only + # properties on DynamicCache, but the Nemotron __init__ tries to set + # them as regular attributes. Create a patched subclass that shadows + # the properties so the assignment succeeds. 
+ needs_patch = any( + isinstance(cls.__dict__.get(attr), property) + for cls in cache_cls.__mro__ + for attr in ('key_cache', 'value_cache') + ) + if needs_patch: + patched = type(cache_cls.__name__, (cache_cls,), { + 'key_cache': None, + 'value_cache': None, + }) + else: + patched = cache_cls + + config = llm.config + cache = patched( + config=config, + batch_size=batch_size, + dtype=llm.dtype if hasattr(llm, 'dtype') else torch.float32, + device=self.device, + ) + if not hasattr(cache, 'conv_kernel_size'): + cache.conv_kernel_size = config.conv_kernel + + intermediate_size = config.mamba_num_heads * config.mamba_head_dim + conv_dim = intermediate_size + 2 * config.n_groups * config.ssm_state_size + if conv_dim != intermediate_size: + conv_kernel = config.conv_kernel + dtype = llm.dtype if hasattr(llm, 'dtype') else torch.float32 + for i, pattern in enumerate(config.hybrid_override_pattern): + if pattern == "M" and i < len(cache.conv_states): + cache.conv_states[i] = torch.zeros( + batch_size, conv_dim, conv_kernel, device=self.device, dtype=dtype, + ) + + self._patch_nemotron_cache_bugs(cache) + self._patch_nemotron_block_forward(llm) + return cache + + @staticmethod + def _patch_nemotron_block_forward(llm): + """Patch NemotronHBlock.forward to pass cache and mask to attention layers. + + The upstream HF code only passes cache_position to the attention mixer, + omitting past_key_value and attention_mask. Without this fix, attention + layers never see the KV cache and only attend to the current token. 
+ """ + import types + + if hasattr(llm, '_orig_mod'): + llm = llm._orig_mod + + layers = getattr(llm, 'layers', None) + if layers is None: + return + + def _patched_block_forward( + self, + hidden_states, + cache_params=None, + cache_position=None, + attention_mask=None, + ): + with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + if self.block_type == "mamba": + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position + ) + elif self.block_type == "attention": + hidden_states = self.mixer( + hidden_states, + past_key_value=cache_params, + attention_mask=attention_mask, + cache_position=cache_position, + ) + hidden_states = hidden_states[0] + elif self.block_type == "mlp": + hidden_states = self.mixer(hidden_states) + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + hidden_states = residual + hidden_states + return hidden_states + + patched_count = 0 + for layer in layers: + block_type = getattr(layer, 'block_type', None) + if block_type == "attention": + layer.forward = types.MethodType(_patched_block_forward, layer) + patched_count += 1 + + if patched_count > 0: + logging.info(f"Patched {patched_count} NemotronHBlock attention layers to pass KV cache") + + @staticmethod + def _patch_nemotron_cache_bugs(cache): + """Patch bugs in HybridMambaAttentionDynamicCache from the HF model code. + + The upstream code references self.conv_states.device and self.ssm_states.device, + but these are Python lists, not tensors. We monkey-patch the affected methods. 
+ """ + import types + + def update_conv_state(self, layer_idx, new_conv_state, cache_init=False): + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to( + self.conv_states[layer_idx].device + ) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx, new_ssm_state): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + cache.update_conv_state = types.MethodType(update_conv_state, cache) + cache.update_ssm_state = types.MethodType(update_ssm_state, cache) From a65916a18020221fb8774cfe39aa02f0656c0e4b Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 13 Mar 2026 02:57:07 +0000 Subject: [PATCH 04/40] add tts inference speedups: vectorize depthsum, precompute rvq schedule, optional torch.compile & subword cache Signed-off-by: Elena Rastorgueva --- .../conf/s2s_streaming.yaml | 4 ++ .../nemotron_voicechat_inference_wrapper.py | 29 +++++++++++ .../speechlm2/modules/ear_tts_model.py | 51 ++++++++++++++----- 3 files changed, 71 insertions(+), 13 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml index 9a8ad1bbe643..1fdaf5998637 100644 --- a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -66,6 +66,10 @@ s2s: # (when true, codec_token_history_size is unused) use_llm_cache: true # Use KV cache for the STT LLM (DynamicCache or HybridMambaAttentionDynamicCache) + # TTS speedup flags (default to false; enable to speed up native inference) + use_tts_torch_compile: false # Compile TTS backbone with torch.compile (mode='default') + use_tts_subword_cache: false # Cache 
CharAwareSubwordEncoder embeddings (skip backbone for repeated tokens) + # Deterministic inference (native engine only). Ensures identical results across # runs by disabling FlashAttention and forcing deterministic CUDA algorithms. # Trade-offs: slower inference, might be worse results than non-deterministic mode, since diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index bb9da19ff321..b41c17a1a366 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -534,6 +534,25 @@ def _initialize_model(self): # and anyway - when sticking to "native", saw no difference in output # with and without this call #self.model.on_train_epoch_start() + + # torch.compile for native TTS backbone + use_tts_torch_compile = bool(self.model_cfg.get("use_tts_torch_compile", False)) + if use_tts_torch_compile and not self.use_vllm_eartts and hasattr(self.model, 'tts_model'): + tts_backbone = getattr(self.model.tts_model, 'tts_model', None) + if tts_backbone is not None and hasattr(tts_backbone, 'backbone'): + logging.info("Compiling TTS backbone with torch.compile(mode='default')...") + tts_backbone.backbone = torch.compile(tts_backbone.backbone, mode="default") + logging.info(" TTS backbone compiled") + + # Inject TTS speedup flags into the TTS model config so ear_tts_model.py can read them + tts_inner = getattr(self.model.tts_model, 'tts_model', None) if hasattr(self.model, 'tts_model') else None + if tts_inner is not None and hasattr(tts_inner, 'config'): + if bool(self.model_cfg.get("use_tts_subword_cache", False)): + OmegaConf.update(tts_inner.config, "use_tts_subword_cache", True) + logging.info("TTS speedup enabled: use_tts_subword_cache") + if hasattr(tts_inner, 'embed_subword') and 
tts_inner.embed_subword is not None and hasattr(tts_inner.embed_subword, 'use_tts_subword_cache'): + tts_inner.embed_subword.use_tts_subword_cache = True + self.tokenizer = self.model.stt_model.tokenizer @@ -1661,6 +1680,14 @@ def main(): parser.add_argument("--use_codec_cache", action="store_true", help="Enable incremental codec decode to remove clicking sounds and wasted inference computation (recommended)") + # torch.compile for native inference + parser.add_argument("--use_tts_torch_compile", action="store_true", + help="Compile TTS backbone with torch.compile for faster native inference (mode='default')") + + # TTS model speedup flags (applied inside ear_tts_model.py) + parser.add_argument("--use_tts_subword_cache", action="store_true", + help="Cache CharAwareSubwordEncoder embeddings at inference time (skip backbone for repeated tokens)") + # vLLM arguments parser.add_argument("--engine_type", type=str, default="native", choices=["native", "vllm_llm", "vllm_eartts", "vllm_llm_vllm_eartts"], help="Engine type for inference (default: native)") @@ -1743,6 +1770,8 @@ def main(): "use_perception_cudagraph": bool(args.use_perception_cudagraph), "use_llm_cache": bool(args.use_llm_cache), "use_codec_cache": bool(args.use_codec_cache), + "use_tts_torch_compile": bool(args.use_tts_torch_compile), + "use_tts_subword_cache": bool(args.use_tts_subword_cache), "top_p": args.top_p, "repetition_penalty": args.repetition_penalty, "temperature": args.temperature, diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 879416915cc7..4be39de8889c 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -903,6 +903,8 @@ def __init__( if self.use_bos_eos_emb: self.bos_eos_emb = BOSEOSEmbedding(tokenizer, self.hidden_size) + self.use_tts_subword_cache = False + def prepare_inputs(self, subword_ids: Tensor, padding_mask: Tensor) -> tuple[Tensor, 
Tensor]: """ Converts a batch of subword IDs into a padded batch of character IDs. @@ -937,6 +939,10 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te """ Performs the forward pass to get character-aware subword embeddings. + When use_tts_subword_cache is True and the module is in eval mode, a + per-subword-ID cache skips the expensive char encoding + backbone + + pooling path for previously seen tokens. + Args: subword_ids (Tensor): A tensor of subword IDs. Shape: `[batch, seq_len]`. subword_mask (Tensor | None): A boolean mask for padding. Defaults to None. @@ -947,6 +953,19 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te if subword_mask is None: subword_mask = torch.ones_like(subword_ids, dtype=torch.bool) + # Inference cache: return cached embeddings if all valid IDs have been seen + if not self.training and self.use_tts_subword_cache: + if not hasattr(self, '_inference_cache'): + self._inference_cache = {} + valid_ids = torch.masked_select(subword_ids, subword_mask).tolist() + if all(sid in self._inference_cache for sid in valid_ids): + cached = torch.stack([self._inference_cache[sid] for sid in valid_ids]) + out = torch.zeros( + subword_ids.shape + (cached.size(-1),), device=subword_ids.device, dtype=cached.dtype + ) + out[subword_mask] = cached + return out + # 1. 
Convert subword IDs to character IDs char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) @@ -978,6 +997,12 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te if self.use_bos_eos_emb: subword_embeds = self.bos_eos_emb(subword_embeds, subword_ids) + # Cache results for future lookups + if not self.training and self.use_tts_subword_cache: + valid_embeds = subword_embeds[subword_mask].detach() + for idx, sid in enumerate(valid_ids): + self._inference_cache[sid] = valid_embeds[idx] + return subword_embeds @@ -1146,15 +1171,13 @@ def depthsum_embedding(self, code: Tensor) -> Tensor: ret: [b, t, h] """ b, t, d = code.size() - _, v, h = self.rvq_embs.size() - device = code.device - - ret = torch.zeros((b, t, h), device=device) embs = F.pad(self.rvq_embs, [0, 0, 0, 1]) - for i in range(d): - emb = embs[i] - ret = ret + F.embedding(code[..., i], emb) - return ret + v_padded = embs.shape[1] + offsets = torch.arange(d, device=code.device).view(1, 1, d) * v_padded + flat_indices = (code + offsets).reshape(b * t * d) + flat_embs = embs.reshape(d * v_padded, -1) + gathered = F.embedding(flat_indices, flat_embs).reshape(b, t, d, -1) + return gathered.sum(dim=2) def prepare_training_inputs(self, code: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Prepares masked and dropped-out versions of the code for training.""" @@ -1638,11 +1661,13 @@ def generate_step( num_maskings = torch.ceil(masking_rates * self.config.num_quantizers).long() ks = num_maskings - F.pad(num_maskings[1:], [0, 0, 0, 1]) + ks_list = ks.squeeze(-1).tolist() # 4. 
Iteratively unmask the continuous part of the code cnt = 0 - for i, k in enumerate(ks): - if torch.all(k == 0): + for i, k_val in enumerate(ks_list): + k_val = int(k_val) + if k_val == 0: continue # Prepare input for the MoG head @@ -1652,7 +1677,7 @@ def generate_step( mog_input_embeds = self.embed_code(self.depthsum_embedding(code)) if self.config.random_target_masking: - mog_input_embeds += self.embed_target_mask(cnt + k - 1) + mog_input_embeds += self.embed_target_mask(cnt + k_val - 1) if guidance_scale_i > 0.0: mog_input_embeds = torch.cat( [mog_input_embeds + hidden_states, mog_input_embeds + uncond_hidden_states], 0 @@ -1666,8 +1691,8 @@ def generate_step( top_p_or_k=top_p_or_k_i, ) z = mog_mu + torch.exp(mog_logs) * torch.randn_like(mog_mu) * noise_scale_i - code = depthsum_encoding_step(self.rvq_embs, z, code, cnt, k[0].item()) - cnt += k[0].item() + code = depthsum_encoding_step(self.rvq_embs, z, code, cnt, k_val) + cnt += k_val return code, lm_logits, eos_flag def load_state_dict(self, state_dict, strict: bool = True): From 2b217535f38fb861c5c39f124e5926e3f9dc35f6 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 13 Mar 2026 21:04:12 +0000 Subject: [PATCH 05/40] allow using speaker_latent with vllm (need to update vllm eartts.py) Signed-off-by: Elena Rastorgueva --- .../inference/model_wrappers/model_factory.py | 27 +++++++++++++ .../nemotron_voicechat_inference_wrapper.py | 16 ++++---- .../streaming/state/s2s_context_manager.py | 3 +- .../vllm/scripts/convert_eartts_checkpoint.py | 40 ++++++++++++++----- 4 files changed, 68 insertions(+), 18 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index fc082fb35ced..a9af52b984c9 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -637,6 +637,7 @@ def __init__(self, **kwargs): 
**kwargs: Arguments passed to the VllmLLMModel constructor """ super().__init__(**kwargs) + self._speaker_latent_dim = None logging.info("VllmEARTTSModel initialized with EARTTS-specific settings.") def _convert_ckpt(self, save_path: str): @@ -760,6 +761,32 @@ async def _process_inputs_to_outputs( bos_mask = torch.full_like(current_subword_id, 1e-20, dtype=getattr(torch, self._dtype)) input_tensors.append(bos_mask) + # Pass speaker_latent: the pre-extracted speaker embedding. + # During prefill with speaker_name: audio_prompt_lantent is [1, T, hidden_size] + # During decode or speaker_reference: pass zeros so the model falls back + # to computing the latent from acoustic tokens. + if "audio_prompt_lantent" in inputs and inputs["audio_prompt_lantent"] is not None: + speaker_latent = inputs["audio_prompt_lantent"].squeeze(0) # T x hidden_size + self._speaker_latent_dim = speaker_latent.shape[-1] + input_tensors.append(speaker_latent.to(dtype=getattr(torch, self._dtype))) + else: + if self._speaker_latent_dim is None: + # Read hidden_size from the converted model config + import json as _json + dir_name = os.path.basename(os.path.normpath(self.model_path)) + converted_config_path = os.path.join("/tmp", dir_name + "_vllm_converted_eartts", "config.json") + if os.path.exists(converted_config_path): + with open(converted_config_path) as _f: + self._speaker_latent_dim = _json.load(_f)["hidden_size"] + else: + raise RuntimeError( + f"Cannot determine speaker_latent_dim: converted config not found at {converted_config_path}. " + "Run a prefill with audio_prompt_lantent first, or ensure the converted checkpoint exists." 
+ ) + num_tokens = codes.shape[0] + speaker_latent = torch.zeros(num_tokens, self._speaker_latent_dim, dtype=getattr(torch, self._dtype)) + input_tensors.append(speaker_latent) + result = await self.engine.generate_next_token(input_tensors, prompt_token_ids=prompt_token_ids, request_id=request_id) acoustic_tokens = result.custom_outputs["acoustic_tokens"] # T x 31 step_acoustic_tokens = acoustic_tokens[-1:] # 1 x 31 diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index b41c17a1a366..76206348b0df 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -1382,15 +1382,15 @@ def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: in audio_buffer = torch.zeros(1, buffer_size_samples, dtype=self.dtype, device=self.device) buffer_fill_level = 0 # How many samples currently in buffer - # Initialize LLM cache - if use_cache: - if is_nemotron: - llm_cache = self.model.stt_model._create_nemotron_cache(batch_size=1) - else: - llm_cache = DynamicCache() - else: + # Initialize LLM cache (skip for vLLM -- it manages its own KV cache) + if not use_cache or self.use_vllm_llm: llm_cache = None - input_embeds_history = [] + if not use_cache: + input_embeds_history = [] + elif is_nemotron: + llm_cache = self.model.stt_model._create_nemotron_cache(batch_size=1) + else: + llm_cache = DynamicCache() cache_position_offset = 0 # Process system prompt if provided (before streaming audio) diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py index beaef0990a75..9d0282eb7951 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py 
+++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -119,7 +119,8 @@ def _create_context(self) -> StreamingRealtimeContext: dtype=torch.long, ) - if not self.use_llm_cache: + use_vllm_llm = bool(getattr(self.s2s_model, "use_vllm_llm", False)) + if not self.use_llm_cache or use_vllm_llm: dynamic_cache = None elif self.is_nemotron: stt_model = getattr(self.s2s_model.model, "stt_model", None) diff --git a/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py b/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py index abb722531319..620f6bf20d66 100644 --- a/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py +++ b/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py @@ -38,7 +38,8 @@ def convert(outdir, config, model_path): # load config with open(config, "r") as f: - config_dict = json.load(f)["model"]["speech_generation"] + full_config = json.load(f) + config_dict = full_config["model"]["speech_generation"] cfg = DictConfig(config_dict) # config modification that is needed to run inference cfg.model.tts_config.use_unshifthed_prompt = True @@ -54,14 +55,15 @@ def convert(outdir, config, model_path): cfg.data.target_sample_rate = 22050 cfg.model.pretrained_model = None - # Compatibility fix: remove 'pretrained_tokenizer_name' from cas_config - # (the new codebase's CharAwareSubwordEncoder no longer accepts this parameter; - # NemotronVoiceChat.__init__ handles this, but we bypass it here) - _pretrained_tokenizer_name = None - if hasattr(cfg.model, "tts_config") and hasattr(cfg.model.tts_config, "cas_config"): - _pretrained_tokenizer_name = cfg.model.tts_config.cas_config.get("pretrained_tokenizer_name", None) - if _pretrained_tokenizer_name is not None: - del cfg.model.tts_config.cas_config.pretrained_tokenizer_name + # Resolve tokenizer name from the STT config (same tokenizer is shared by TTS). 
+ _pretrained_tokenizer_name = ( + full_config.get("model", {}).get("stt", {}).get("model", {}).get("pretrained_llm", None) + ) + if _pretrained_tokenizer_name is None: + raise ValueError( + "Cannot determine tokenizer: 'pretrained_llm' not found in " + "config.json -> model -> stt -> model. Check the checkpoint." + ) model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True)).eval() # get subword encoder vocabs and config @@ -243,6 +245,7 @@ def convert(outdir, config, model_path): {"name": "text_tokens", "dtype": "int32"}, {"name": "text_mask"}, {"name": "bos_mask"}, + {"name": "speaker_latent", "dim": flat_config["hidden_size"]}, ] flat_config["custom_outputs"] = ["acoustic_tokens"] @@ -250,6 +253,25 @@ def convert(outdir, config, model_path): json.dump(flat_config, f, indent=2) logging.info("Saved vllm config") + # Extract and save pre-computed speaker latents (audio_prompt_latents.*) + # from the NeMo checkpoint so they can be used at inference time. + speaker_latents_dir = os.path.join(outdir, "speaker_latents") + all_weights = load_file(model_path) + found_latents = False + for key, tensor in all_weights.items(): + if "audio_prompt_latents." in key: + speaker_name = key.split("audio_prompt_latents.")[-1] + os.makedirs(speaker_latents_dir, exist_ok=True) + latent_path = os.path.join(speaker_latents_dir, f"{speaker_name}.pt") + torch.save(tensor, latent_path) + logging.info(f"Saved speaker latent '{speaker_name}' to {latent_path} (shape={tensor.shape})") + found_latents = True + if not found_latents: + logging.warning( + "No audio_prompt_latents found in checkpoint. " + "speaker_name will not work unless latents are added." 
+ ) + if __name__ == "__main__": args = parse_args() From 11abaa0f708c364cf41dd6fc548e43543e7c2ea2 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 13 Mar 2026 23:04:46 +0000 Subject: [PATCH 06/40] add flag for speaker_name if doing standalone inference Signed-off-by: Elena Rastorgueva --- .../model_wrappers/nemotron_voicechat_inference_wrapper.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 76206348b0df..632b1b11103f 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -1651,8 +1651,10 @@ def main(): help="Append silence equal to this ratio of the original audio duration (e.g. 0.2 = 20%% extra)") parser.add_argument("--pad_audio_by_sec", type=float, default=None, help="Append this many seconds of extra silence after the audio") - parser.add_argument("--speaker_reference", type=str, required=True, + parser.add_argument("--speaker_reference", type=str, default=None, help="Path to speaker reference audio file") + parser.add_argument("--speaker_name", type=str, default=None, + help="Name of a registered speaker whose latent is cached in the checkpoint") parser.add_argument("--buffer_size_frames", type=int, default=DEFAULT_BUFFER_SIZE_FRAMES, help=f"Size of audio buffer in frames (each frame = 80ms, default: {DEFAULT_BUFFER_SIZE_FRAMES})") parser.add_argument("--num_frames_per_chunk", type=int, default=DEFAULT_NUM_FRAMES_PER_CHUNK, @@ -1751,6 +1753,8 @@ def main(): if sum(x is not None for x in [args.pad_audio_to_sec, args.pad_silence_ratio, args.pad_audio_by_sec]) > 1: raise ValueError("Set at most one of: --pad_audio_to_sec, --pad_silence_ratio, --pad_audio_by_sec") + if 
args.speaker_reference is None and args.speaker_name is None: + parser.error("At least one of --speaker_reference or --speaker_name must be provided") if not math.isfinite(args.temperature) or args.temperature < 0.0: parser.error(f"--temperature must be a finite value >= 0.0, got {args.temperature}") @@ -1762,6 +1766,7 @@ def main(): "model_path": args.model_path, "llm_checkpoint_path": args.llm_checkpoint_path, "speaker_reference": args.speaker_reference, + "speaker_name": args.speaker_name, "buffer_size_frames": args.buffer_size_frames, "decode_audio": bool(args.decode_audio), "engine_type": args.engine_type, From 0b506b21c1f0d08e6af8600b49ac71126a1a6247 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Thu, 19 Mar 2026 03:57:11 +0000 Subject: [PATCH 07/40] remove standalone code path; add parity check for offline vs streaming - adjusted infer_one_step code so operations will match offline Signed-off-by: Elena Rastorgueva --- .../s2s_streaming_infer.py | 7 +- .../voicechat/1/infer_streaming.py | 5 +- .../nemotron_voicechat_parity_harness.py | 746 ++++++++++++++ .../inference/model_wrappers/model_factory.py | 6 + .../nemotron_voicechat_inference_wrapper.py | 975 +++--------------- .../pipelines/streaming_s2s_pipeline.py | 39 + .../streaming/state/s2s_context_manager.py | 134 +-- .../inference/utils/pipeline_utils.py | 14 + 8 files changed, 958 insertions(+), 968 deletions(-) create mode 100644 examples/speechlm2/nemotron_voicechat_parity_harness.py diff --git a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py index 44a1979a1123..693e3ac6c5ad 100644 --- a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -142,11 +142,10 @@ def dump_output( """ Dump inference results to output_processed.json and output_raw.json. 
- output_processed.json uses the same schema as the standalone wrapper's - output_results_processed.json (timestamps in pred_text via <|t|> / <$t$>). + output_processed.json uses the canonical S2S processed-output schema + (timestamps in pred_text via <|t|> / <$t$>). - output_raw.json preserves all tokens including (pad tokens), - matching the standalone wrapper's output_results_raw.json. + output_raw.json preserves all tokens including (pad tokens). CTM files are still written for per-word audio-sample-based timing. diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py index bf59203366ae..7fdb449f0bec 100644 --- a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py +++ b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py @@ -38,10 +38,10 @@ class TritonPythonModel: def _resolve_env_overrides(self, cfg): """Resolve ??? placeholders in the config from environment variables. - + This allows start_triton.sh to control model paths and settings via env vars, while sharing the same s2s_streaming.yaml used by the CLI. - + Env var mapping (cfg key -> env var, default): s2s.model_path -> S2S_MODEL_PATH (required) s2s.llm_checkpoint_path -> S2S_LLM_CHECKPOINT_PATH (required) @@ -69,7 +69,6 @@ def _resolve_env_overrides(self, cfg): for cfg_key, (env_var, default) in env_overrides.items(): val = os.environ.get(env_var) if val is not None: - # Cast to match the default's type (e.g. 
"0.08" -> float) if default is not None and isinstance(default, bool): val = val.lower() in ("true", "1", "yes") elif default is not None and isinstance(default, float): diff --git a/examples/speechlm2/nemotron_voicechat_parity_harness.py b/examples/speechlm2/nemotron_voicechat_parity_harness.py new file mode 100644 index 000000000000..582a3be8b722 --- /dev/null +++ b/examples/speechlm2/nemotron_voicechat_parity_harness.py @@ -0,0 +1,746 @@ +from __future__ import annotations + +import argparse +import json +import math +import os +from pathlib import Path +import tempfile +from typing import Any + +import librosa +import soundfile as sf +import torch +from omegaconf import MISSING, OmegaConf + +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import ( + FRAME_SIZE_SAMPLES, + NemotronVoicechatInferenceWrapper, +) +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions +def _bool_arg(parser: argparse.ArgumentParser, name: str, help_text: str) -> None: + parser.add_argument(name, action=argparse.BooleanOptionalAction, default=None, help=help_text) + + +def _default_s2s_streaming_config_path() -> str: + repo_root = Path(__file__).resolve().parents[2] + return str(repo_root / "examples" / "speechlm2" / "nemo_inference_pipelines" / "conf" / "s2s_streaming.yaml") + + +def _load_s2s_inference_config(config_path: str | None = None): + path = config_path or _default_s2s_streaming_config_path() + cfg = OmegaConf.load(path) + for key, value in { + "audio_file": "", + "output_dir": "./generated", + "s2s.model_path": None, + "s2s.llm_checkpoint_path": None, + "s2s.decode_audio": True, + "s2s.engine_type": "native", + "s2s.system_prompt": None, + "streaming.chunk_size_in_secs": FRAME_SIZE_SAMPLES / 16000.0, + "streaming.buffer_size_in_secs": 71 * (FRAME_SIZE_SAMPLES / 16000.0), + 
}.items(): + if OmegaConf.select(cfg, key, default=MISSING) is MISSING: + OmegaConf.update(cfg, key, value, force_add=True) + return cfg + + +def _apply_inference_overrides(cfg, overrides: dict[str, Any]): + for key, value in overrides.items(): + if value is not None: + OmegaConf.update(cfg, key, value, force_add=True) + return cfg + + +def _load_audio_tensor(audio_path: str, sample_rate: int, device: torch.device, dtype: torch.dtype): + audio_np, _ = librosa.load(audio_path, sr=sample_rate) + audio = torch.tensor(audio_np, device=device, dtype=dtype).unsqueeze(0) + audio_lens = torch.tensor([audio.shape[1]], device=device, dtype=torch.long) + return audio, audio_lens + + +def _build_prompt_token_ids(tokenizer, system_prompt: str | None) -> list[int]: + if not system_prompt or not system_prompt.strip(): + return [] + return [tokenizer.bos_id] + tokenizer.text_to_ids(system_prompt) + [tokenizer.eos_id] + + +def _resolve_num_frames_per_chunk(args, total_frames: int) -> int: + if args.num_frames_per_chunk is not None: + value = int(args.num_frames_per_chunk) + elif args.chunk_size_in_secs is not None: + value = int(round(float(args.chunk_size_in_secs) / (FRAME_SIZE_SAMPLES / 16000.0))) + else: + value = total_frames + + if value < 1: + raise ValueError(f"num_frames_per_chunk must be >= 1, got {value}") + return value + + +def _apply_deterministic_runtime_settings(enabled: bool) -> None: + if not enabled: + return + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("medium") + torch.use_deterministic_algorithms(True, warn_only=False) + + + +def _compute_min_buffer_frames(wrapper, num_frames_per_chunk: int) -> int: + att_context_size = 
wrapper.model.stt_model.perception.encoder._cfg.att_context_size + if wrapper.use_perception_cache: + return num_frames_per_chunk * (att_context_size[1] + 1) + 2 + return att_context_size[0] + att_context_size[1] + 1 + + +def _compute_min_buffer_frames_from_cfg(cfg, num_frames_per_chunk: int) -> int: + att_context_size = cfg.streaming.get("att_context_size", [70, 0]) + if cfg.s2s.get("use_perception_cache", False): + return num_frames_per_chunk * (att_context_size[1] + 1) + 2 + return att_context_size[0] + att_context_size[1] + 1 + + +def _first_diff(a: torch.Tensor, b: torch.Tensor) -> int | None: + a = a.detach().cpu() + b = b.detach().cpu() + if a.shape != b.shape: + return 0 + diff = (a != b).flatten() + if not diff.any(): + return None + return int(diff.nonzero(as_tuple=False)[0].item()) + + +def _prefix_compare(a: torch.Tensor, b: torch.Tensor) -> tuple[int | None, bool | None, int | None]: + if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor): + return None, None, None + a = a.detach().cpu() + b = b.detach().cpu() + if a.dim() != b.dim(): + return None, None, None + + prefix_len = min(a.shape[-1], b.shape[-1]) + if prefix_len == 0: + return 0, True, None + + a_prefix = a[..., :prefix_len] + b_prefix = b[..., :prefix_len] + if torch.equal(a_prefix, b_prefix): + return prefix_len, True, None + + diff = (a_prefix != b_prefix).flatten() + first_diff = int(diff.nonzero(as_tuple=False)[0].item()) + return prefix_len, False, first_diff + + +def _prefix_tensor_diff(a: torch.Tensor | None, b: torch.Tensor | None) -> dict[str, Any] | None: + if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor): + return None + a = a.detach().cpu() + b = b.detach().cpu() + if a.dim() != b.dim(): + return None + prefix_len = min(a.shape[1], b.shape[1]) if a.dim() >= 2 else min(a.shape[0], b.shape[0]) + if prefix_len <= 0: + return {"prefix_len": 0, "match": True, "max_abs_diff": 0.0, "mean_abs_diff": 0.0} + if a.dim() == 2: + a_prefix = a[:, 
:prefix_len] + b_prefix = b[:, :prefix_len] + else: + a_prefix = a[:, :prefix_len, ...] + b_prefix = b[:, :prefix_len, ...] + diff = (a_prefix - b_prefix).abs() + reduce_dims = tuple(i for i in range(diff.dim()) if i != 1) + if reduce_dims: + per_step_max = diff.amax(dim=reduce_dims) + else: + per_step_max = diff + first_step_diff_index = None + differing_steps = (per_step_max > 0).nonzero(as_tuple=False) + if differing_steps.numel() > 0: + first_step_diff_index = int(differing_steps[0].item()) + return { + "prefix_len": prefix_len, + "match": bool(torch.equal(a_prefix, b_prefix)), + "max_abs_diff": float(diff.max().item()), + "mean_abs_diff": float(diff.mean().item()), + "first_step_diff_index": first_step_diff_index, + } + + +def _dtype_name(value) -> str | None: + if value is None: + return None + if isinstance(value, torch.Tensor): + return str(value.dtype) + if isinstance(value, torch.dtype): + return str(value) + return str(value) + + +def _module_param_dtype(module) -> str | None: + if module is None: + return None + try: + return str(next(module.parameters()).dtype) + except StopIteration: + return None + except Exception: + return None + + +def _collect_model_dtypes(wrapper: NemotronVoicechatInferenceWrapper) -> dict[str, Any]: + stt_model = wrapper.model.stt_model + return { + "wrapper_dtype": _dtype_name(wrapper.dtype), + "llm_dtype": _module_param_dtype(getattr(stt_model, "llm", None)), + "lm_head_dtype": _module_param_dtype(getattr(stt_model, "lm_head", None)), + "asr_head_dtype": _module_param_dtype(getattr(stt_model, "asr_head", None)), + "embed_tokens_dtype": _module_param_dtype(getattr(stt_model, "embed_tokens", None)), + "embed_asr_tokens_dtype": _module_param_dtype(getattr(stt_model, "embed_asr_tokens", None)), + "perception_dtype": _module_param_dtype(getattr(stt_model, "perception", None)), + "tts_dtype": _module_param_dtype(getattr(wrapper.model, "tts_model", None)), + } + + +def _tensor_summary_diff(a: torch.Tensor | None, b: torch.Tensor | 
None) -> dict[str, Any] | None: + if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor): + return None + if a.shape != b.shape: + return {"shape_a": list(a.shape), "shape_b": list(b.shape), "match": False} + diff = (a - b).abs() + return { + "shape": list(a.shape), + "match": bool(torch.equal(a, b)), + "max_abs_diff": float(diff.max().item()), + "mean_abs_diff": float(diff.mean().item()), + } + + +def _step_component_diagnostics( + wrapper: NemotronVoicechatInferenceWrapper, + offline_debug: dict[str, Any], + incremental_debug: dict[str, Any], +) -> dict[str, Any] | None: + input_embed_diff = _prefix_tensor_diff(offline_debug.get("input_embeds"), incremental_debug.get("input_embeds")) + if input_embed_diff is None: + return {"status": "missing_input_embed_diff"} + step_idx = input_embed_diff.get("first_step_diff_index") + if step_idx is None: + return {"status": "no_input_embed_drift"} + if step_idx == 0: + return {"first_step_diff_index": 0, "note": "Drift starts at step 0; component breakdown not specialized."} + + offline_source = offline_debug.get("source_encoded") + incremental_source = incremental_debug.get("source_encoded") + selected_indices = incremental_debug.get("selected_frame_indices") or [] + offline_tokens = offline_debug.get("gen_text") + offline_asr = offline_debug.get("gen_asr") + incremental_tokens = incremental_debug.get("gen_text") + incremental_asr = incremental_debug.get("gen_asr") + required = { + "offline_source_encoded": offline_source, + "incremental_source_encoded": incremental_source, + "offline_gen_text": offline_tokens, + "offline_gen_asr": offline_asr, + "incremental_gen_text": incremental_tokens, + "incremental_gen_asr": incremental_asr, + } + missing = [name for name, value in required.items() if not isinstance(value, torch.Tensor)] + if missing: + return { + "status": "missing_tensors", + "first_step_diff_index": step_idx, + "missing": missing, + } + if step_idx >= len(selected_indices): + return { + "status": 
"selected_index_out_of_range", + "first_step_diff_index": step_idx, + "selected_frame_indices_len": len(selected_indices), + } + + stt_model = wrapper.model.stt_model + source_frame_offline = offline_source[:, step_idx : step_idx + 1, :] + source_frame_incremental = incremental_source[:, selected_indices[step_idx] : selected_indices[step_idx] + 1, :] + prev_offline_text = offline_tokens[:, step_idx - 1] + prev_incremental_text = incremental_tokens[:, step_idx - 1] + prev_offline_asr = offline_asr[:, step_idx - 1] + prev_incremental_asr = incremental_asr[:, step_idx - 1] + + offline_text_emb = stt_model.embed_tokens(prev_offline_text.to(wrapper.device)).detach().cpu() + incremental_text_emb = stt_model.embed_tokens(prev_incremental_text.to(wrapper.device)).detach().cpu() + offline_asr_emb = stt_model.embed_asr_tokens(prev_offline_asr.to(wrapper.device)).detach().cpu() + incremental_asr_emb = stt_model.embed_asr_tokens(prev_incremental_asr.to(wrapper.device)).detach().cpu() + + text_weight = stt_model.cfg.get("duplex_text_channel_weight", 1.0) + asr_weight = stt_model.cfg.get("duplex_asr_text_weight", 1.0) + offline_last_emb = offline_text_emb * text_weight + offline_asr_emb * asr_weight + incremental_last_emb = incremental_text_emb * text_weight + incremental_asr_emb * asr_weight + + offline_input = offline_debug["input_embeds"][:, step_idx : step_idx + 1, :] + incremental_input = incremental_debug["input_embeds"][:, step_idx : step_idx + 1, :] + + offline_style = source_frame_incremental.detach().cpu().clone() + offline_style += incremental_last_emb.unsqueeze(1) + incremental_style = source_frame_incremental.detach().cpu().clone() + incremental_style += (incremental_text_emb * text_weight).unsqueeze(1) + incremental_style += (incremental_asr_emb * asr_weight).unsqueeze(1) + + return { + "status": "ok", + "first_step_diff_index": step_idx, + "selected_frame_index": selected_indices[step_idx], + "prev_text_token_equal": bool(torch.equal(prev_offline_text.cpu(), 
prev_incremental_text.cpu())), + "prev_asr_token_equal": bool(torch.equal(prev_offline_asr.cpu(), prev_incremental_asr.cpu())), + "source_frame_diff": _tensor_summary_diff(source_frame_offline.cpu(), source_frame_incremental.cpu()), + "text_embedding_diff": _tensor_summary_diff(offline_text_emb, incremental_text_emb), + "asr_embedding_diff": _tensor_summary_diff(offline_asr_emb, incremental_asr_emb), + "last_emb_diff": _tensor_summary_diff(offline_last_emb, incremental_last_emb), + "offline_input_vs_incremental_input": _tensor_summary_diff(offline_input.cpu(), incremental_input.cpu()), + "offline_input_vs_offline_style_rebuild": _tensor_summary_diff(offline_input.cpu(), offline_style), + "incremental_input_vs_offline_style_rebuild": _tensor_summary_diff(incremental_input.cpu(), offline_style), + "offline_input_vs_incremental_style_rebuild": _tensor_summary_diff(offline_input.cpu(), incremental_style), + "incremental_input_vs_incremental_style_rebuild": _tensor_summary_diff(incremental_input.cpu(), incremental_style), + } + + +def _compare_debug_outputs(offline_debug: dict[str, Any] | None, incremental_debug: dict[str, Any] | None) -> dict[str, Any] | None: + if offline_debug is None or incremental_debug is None: + return None + + offline_encoder = offline_debug.get("source_encoded") + incremental_encoder = incremental_debug.get("source_encoded") + selected_indices = incremental_debug.get("selected_frame_indices") or [] + selected_incremental = None + selected_prefix = None + if isinstance(incremental_encoder, torch.Tensor) and selected_indices: + selected_incremental = incremental_encoder[:, selected_indices, :] + if isinstance(offline_encoder, torch.Tensor) and selected_incremental is not None: + prefix_len = min(offline_encoder.shape[1], selected_incremental.shape[1]) + selected_prefix = _prefix_tensor_diff( + offline_encoder[:, :prefix_len, :], + selected_incremental[:, :prefix_len, :], + ) + + report = { + "offline_tensor_dtypes": { + "source_encoded": 
_dtype_name(offline_debug.get("source_encoded")), + "input_embeds": _dtype_name(offline_debug.get("input_embeds")), + "text_logits": _dtype_name(offline_debug.get("text_logits")), + "asr_logits": _dtype_name(offline_debug.get("asr_logits")), + }, + "incremental_tensor_dtypes": { + "source_encoded": _dtype_name(incremental_debug.get("source_encoded")), + "input_embeds": _dtype_name(incremental_debug.get("input_embeds")), + "text_logits": _dtype_name(incremental_debug.get("text_logits")), + "asr_logits": _dtype_name(incremental_debug.get("asr_logits")), + }, + "offline_source_encoded_shape": list(offline_encoder.shape) if isinstance(offline_encoder, torch.Tensor) else None, + "incremental_source_encoded_shape": list(incremental_encoder.shape) if isinstance(incremental_encoder, torch.Tensor) else None, + "incremental_selected_frame_indices": selected_indices, + "selected_encoder_prefix": selected_prefix, + "offline_input_embeds_shape": list(offline_debug["input_embeds"].shape) + if isinstance(offline_debug.get("input_embeds"), torch.Tensor) + else None, + "incremental_input_embeds_shape": list(incremental_debug["input_embeds"].shape) + if isinstance(incremental_debug.get("input_embeds"), torch.Tensor) + else None, + "input_embeds_prefix": _prefix_tensor_diff(offline_debug.get("input_embeds"), incremental_debug.get("input_embeds")), + "offline_text_logits_shape": list(offline_debug["text_logits"].shape) + if isinstance(offline_debug.get("text_logits"), torch.Tensor) + else None, + "incremental_text_logits_shape": list(incremental_debug["text_logits"].shape) + if isinstance(incremental_debug.get("text_logits"), torch.Tensor) + else None, + "text_logits_prefix": _prefix_tensor_diff(offline_debug.get("text_logits"), incremental_debug.get("text_logits")), + "offline_asr_logits_shape": list(offline_debug["asr_logits"].shape) + if isinstance(offline_debug.get("asr_logits"), torch.Tensor) + else None, + "incremental_asr_logits_shape": 
list(incremental_debug["asr_logits"].shape) + if isinstance(incremental_debug.get("asr_logits"), torch.Tensor) + else None, + "asr_logits_prefix": _prefix_tensor_diff(offline_debug.get("asr_logits"), incremental_debug.get("asr_logits")), + } + report["step_component_diagnostics"] = None + return report + + +def _compare_outputs(offline: dict[str, Any], incremental: dict[str, Any]) -> dict[str, Any]: + offline_tokens = offline.get("tokens_text") + incremental_tokens = incremental.get("tokens_text") + offline_asr = offline.get("tokens_text_src") + incremental_asr = incremental.get("asr_tokens") + token_prefix_len, token_prefix_match, token_prefix_first_diff = _prefix_compare(offline_tokens, incremental_tokens) + asr_prefix_len, asr_prefix_match, asr_prefix_first_diff = _prefix_compare(offline_asr, incremental_asr) + + token_match = ( + isinstance(offline_tokens, torch.Tensor) + and isinstance(incremental_tokens, torch.Tensor) + and offline_tokens.shape == incremental_tokens.shape + and torch.equal(offline_tokens.detach().cpu(), incremental_tokens.detach().cpu()) + ) + asr_token_match = None + if isinstance(offline_asr, torch.Tensor) and isinstance(incremental_asr, torch.Tensor): + asr_token_match = offline_asr.shape == incremental_asr.shape and torch.equal( + offline_asr.detach().cpu(), incremental_asr.detach().cpu() + ) + + offline_audio_len = offline.get("audio_len") + incremental_audio = incremental.get("audio") + audio_sample_count_equal = None + if offline_audio_len is not None and incremental_audio is not None: + expected = int(offline_audio_len[0].item()) + got = int(incremental_audio.shape[-1]) + audio_sample_count_equal = expected == got + + report = { + "offline_text": offline.get("text", [""])[0], + "incremental_text": incremental.get("text", [""])[0], + "text_equal": offline.get("text", [""])[0] == incremental.get("text", [""])[0], + "offline_asr_text": (offline.get("src_text") or [""])[0] if offline.get("src_text") is not None else None, + 
"incremental_asr_text": (incremental.get("asr_text") or [""])[0] if incremental.get("asr_text") is not None else None, + "asr_text_equal": ( + offline.get("src_text") is not None + and incremental.get("asr_text") is not None + and offline["src_text"][0] == incremental["asr_text"][0] + ), + "offline_token_shape": list(offline_tokens.shape) if isinstance(offline_tokens, torch.Tensor) else None, + "incremental_token_shape": list(incremental_tokens.shape) if isinstance(incremental_tokens, torch.Tensor) else None, + "token_match": token_match, + "token_first_diff_index": _first_diff(offline_tokens, incremental_tokens) + if isinstance(offline_tokens, torch.Tensor) and isinstance(incremental_tokens, torch.Tensor) + else None, + "token_prefix_len": token_prefix_len, + "token_prefix_match": token_prefix_match, + "token_prefix_first_diff_index": token_prefix_first_diff, + "offline_asr_token_shape": list(offline_asr.shape) if isinstance(offline_asr, torch.Tensor) else None, + "incremental_asr_token_shape": list(incremental_asr.shape) if isinstance(incremental_asr, torch.Tensor) else None, + "asr_token_match": asr_token_match, + "asr_token_first_diff_index": _first_diff(offline_asr, incremental_asr) + if isinstance(offline_asr, torch.Tensor) and isinstance(incremental_asr, torch.Tensor) + else None, + "asr_token_prefix_len": asr_prefix_len, + "asr_token_prefix_match": asr_prefix_match, + "asr_token_prefix_first_diff_index": asr_prefix_first_diff, + "audio_sample_count_equal": audio_sample_count_equal, + } + return report + + +def _merge_incremental_debug_steps(steps: list[dict[str, Any]]) -> dict[str, Any]: + """Merge per-step debug dicts from the pipeline into a single dict matching offline debug format.""" + if not steps: + return {} + all_source_encoded = [s["source_encoded"] for s in steps if s.get("source_encoded") is not None] + all_input_embeds = [s["input_embeds"] for s in steps if s.get("input_embeds") is not None] + all_text_logits = [s["text_logits"] for s in steps 
if s.get("text_logits") is not None] + all_asr_logits = [s["asr_logits"] for s in steps if s.get("asr_logits") is not None] + all_gen_text = [s["gen_text"] for s in steps if s.get("gen_text") is not None] + all_gen_asr = [s["gen_asr"] for s in steps if s.get("gen_asr") is not None] + selected_frame_indices = [] + for s in steps: + selected_frame_indices.extend(s.get("selected_frame_indices", [])) + return { + "source_encoded": all_source_encoded[-1] if all_source_encoded else None, + "input_embeds": torch.cat(all_input_embeds, dim=1) if all_input_embeds else None, + "gen_text": all_gen_text[-1] if all_gen_text else None, + "gen_asr": all_gen_asr[-1] if all_gen_asr else None, + "text_logits": torch.cat(all_text_logits, dim=1) if all_text_logits else None, + "asr_logits": torch.cat(all_asr_logits, dim=1) if all_asr_logits else None, + "selected_frame_indices": selected_frame_indices, + } + + +def _collect_offline_debug( + wrapper: NemotronVoicechatInferenceWrapper, + audio: torch.Tensor, + audio_lens: torch.Tensor, + prompt_tokens: torch.Tensor | None, + prompt_token_lens: torch.Tensor | None, +) -> dict[str, Any]: + buffer_len = audio_lens.to(device=wrapper.device, dtype=torch.long) + source_encoded, _, _ = wrapper.model.stt_model.perception( + input_signal=audio, + input_signal_length=buffer_len, + return_encoder_emb=True, + ) + source_encoded = source_encoded.to(wrapper.dtype) + + inference_state = wrapper.model.stt_model.streaming_inference._init_inference( + audio, + audio_lens, + 0, + prompt_tokens, + prompt_token_lens, + ) + ans, inference_state = wrapper.model.stt_model.streaming_inference._step_zero(inference_state) + text_logits = [ans["text_logits"][:, -1].detach().cpu()] + asr_logits = [ans["asr_logits"][:, -1].detach().cpu()] if "asr_logits" in ans else [] + T = inference_state["T"] + for t in range(1, T): + ans = wrapper.model.stt_model.streaming_inference._step_inference(t, inference_state, ans) + text_logits.append(ans["text_logits"][:, 
-1].detach().cpu()) + if "asr_logits" in ans: + asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) + + return { + "source_encoded": source_encoded.detach().cpu(), + "input_embeds": inference_state["input_embeds"].detach().cpu(), + "gen_text": inference_state["gen_text"].detach().cpu(), + "gen_asr": inference_state["gen_asr"].detach().cpu() if inference_state.get("gen_asr") is not None else None, + "text_logits": torch.stack(text_logits, dim=1), + "asr_logits": torch.stack(asr_logits, dim=1) if asr_logits else None, + } + + +def run_parity_harness(args) -> dict[str, Any]: + inference_cfg = _load_s2s_inference_config(args.config_path) + + if args.strict_runtime_parity and args.tts_system_prompt: + raise ValueError( + "Strict offline/incremental parity does not currently support `tts_system_prompt`, " + "because offline_inference has no equivalent string prompt API for TTS conditioning." + ) + + overrides = { + "s2s.model_path": args.model_path, + "s2s.llm_checkpoint_path": args.llm_checkpoint_path, + "s2s.speaker_reference": args.speaker_reference, + "s2s.speaker_name": args.speaker_name, + "s2s.compute_dtype": args.compute_dtype, + "s2s.decode_audio": args.decode_audio, + "s2s.system_prompt": args.system_prompt, + "s2s.tts_system_prompt": args.tts_system_prompt, + "s2s.engine_type": args.engine_type, + "s2s.use_perception_cache": args.use_perception_cache, + "s2s.use_perception_cudagraph": args.use_perception_cudagraph, + "s2s.use_llm_cache": args.use_llm_cache, + "s2s.use_codec_cache": args.use_codec_cache, + "s2s.deterministic": args.deterministic, + "s2s.top_p": args.top_p, + "s2s.repetition_penalty": args.repetition_penalty, + "s2s.temperature": args.temperature, + } + + if args.strict_runtime_parity: + strict_defaults = { + "s2s.engine_type": "native", + "s2s.compute_dtype": "float32", + "s2s.use_perception_cache": False, + "s2s.use_perception_cudagraph": False, + "s2s.use_llm_cache": False, + "s2s.use_codec_cache": False, + "s2s.deterministic": True, + 
"s2s.top_p": 1.0, + "s2s.repetition_penalty": 1.0, + "s2s.temperature": 0.0, + } + for key, value in strict_defaults.items(): + overrides[key] = value if overrides.get(key) is None else overrides[key] + + inference_cfg = _apply_inference_overrides(inference_cfg, overrides) + _apply_deterministic_runtime_settings(bool(inference_cfg.s2s.get("deterministic", False))) + + input_sample_rate = int(inference_cfg.streaming.get("input_sample_rate", 16000)) + audio_np, _ = librosa.load(args.audio_path, sr=input_sample_rate) + total_samples = len(audio_np) + total_frames = int(math.ceil(total_samples / FRAME_SIZE_SAMPLES)) + num_frames_per_chunk = _resolve_num_frames_per_chunk(args, total_frames) + chunk_size_in_secs = num_frames_per_chunk * (FRAME_SIZE_SAMPLES / float(input_sample_rate)) + buffer_size_frames = max(num_frames_per_chunk, _compute_min_buffer_frames_from_cfg(inference_cfg, num_frames_per_chunk)) + + with tempfile.TemporaryDirectory(prefix="voicechat-parity-") as tmpdir: + inference_cfg = _apply_inference_overrides( + inference_cfg, + { + "output_dir": tmpdir, + "streaming.chunk_size_in_secs": chunk_size_in_secs, + "streaming.buffer_size_in_secs": buffer_size_frames * (FRAME_SIZE_SAMPLES / float(input_sample_rate)), + }, + ) + pipeline = S2SPipelineBuilder.build_pipeline(inference_cfg) + do_collect_debug = args.collect_debug if args.collect_debug is not None else bool(args.strict_runtime_parity) + pipeline.collect_debug = do_collect_debug + wrapper = pipeline.s2s_model + + audio, audio_lens = _load_audio_tensor( + args.audio_path, + sample_rate=wrapper.model.source_sample_rate, + device=wrapper.device, + dtype=wrapper.dtype, + ) + + prompt_tokens = None + prompt_token_lens = None + if inference_cfg.s2s.get("system_prompt"): + prompt_token_ids = _build_prompt_token_ids(wrapper.tokenizer, inference_cfg.s2s.system_prompt) + prompt_tokens = torch.tensor(prompt_token_ids, device=wrapper.device, dtype=torch.long).unsqueeze(0) + prompt_token_lens = 
torch.tensor([len(prompt_token_ids)], device=wrapper.device, dtype=torch.long) + + offline = wrapper.model.offline_inference( + input_signal=audio, + input_signal_lens=audio_lens, + prompt_tokens=prompt_tokens, + prompt_token_lens=prompt_token_lens, + decode_audio=bool(inference_cfg.s2s.get("decode_audio", True)), + ) + offline_debug = _collect_offline_debug( + wrapper, + audio=audio, + audio_lens=audio_lens, + prompt_tokens=prompt_tokens, + prompt_token_lens=prompt_token_lens, + ) + pipeline_output = pipeline.run( + [args.audio_path], + options=[S2SRequestOptions(system_prompt=inference_cfg.s2s.get("system_prompt"))], + ) + + incremental_audio = None + incremental_audio_path = None + audio_sample_count_equal = None + if getattr(pipeline_output, "audio_filepaths", None): + incremental_audio_path = pipeline_output.audio_filepaths[0] + if incremental_audio_path: + incremental_audio, _ = sf.read(incremental_audio_path) + incremental_audio = torch.tensor(incremental_audio).reshape(1, -1) + incremental = { + "text": [pipeline_output.texts_with_timestamps[0] if pipeline_output.texts_with_timestamps else pipeline_output.texts[0]], + "asr_text": [pipeline_output.asr_texts_with_timestamps[0] if pipeline_output.asr_texts_with_timestamps else pipeline_output.asr_texts[0]], + "tokens_text": pipeline_output.token_texts[0] if pipeline_output.token_texts else None, + "asr_tokens": pipeline_output.token_asr_texts[0] if pipeline_output.token_asr_texts else None, + "audio": incremental_audio, + } + + incremental_debug = None + if pipeline_output.debug_data and pipeline_output.debug_data[0]: + incremental_debug = _merge_incremental_debug_steps(pipeline_output.debug_data[0]) + + debug_comparison = _compare_debug_outputs(offline_debug, incremental_debug) if incremental_debug else None + + report = { + "audio_path": args.audio_path, + "total_samples": int(total_samples), + "total_frames": total_frames, + "num_frames_per_chunk": num_frames_per_chunk, + "buffer_size_frames": 
buffer_size_frames, + "strict_runtime_parity": bool(args.strict_runtime_parity), + "engine_type": inference_cfg.s2s.get("engine_type"), + "use_perception_cache": bool(inference_cfg.s2s.get("use_perception_cache", False)), + "use_llm_cache": bool(inference_cfg.s2s.get("use_llm_cache", False)), + "use_codec_cache": bool(inference_cfg.s2s.get("use_codec_cache", False)), + "deterministic": bool(inference_cfg.s2s.get("deterministic", False)), + "model_dtypes": _collect_model_dtypes(wrapper), + "comparison": _compare_outputs(offline, incremental), + "debug_comparison": debug_comparison, + "debug": { + "incremental_mode": "pipeline", + "incremental_audio_filepath": incremental_audio_path, + "offline_debug": offline_debug, + "incremental_debug": incremental_debug, + }, + } + + if args.output_json: + output_path = Path(args.output_json) + output_path.parent.mkdir(parents=True, exist_ok=True) + serializable_report = {k: v for k, v in report.items() if k != "debug"} + with output_path.open("w", encoding="utf-8") as f: + json.dump(serializable_report, f, indent=2, ensure_ascii=False) + + if args.strict_runtime_parity: + comparison = report["comparison"] + failed = [] + if not comparison["text_equal"]: + failed.append("text") + if comparison["token_match"] is False: + failed.append("tokens") + if comparison["asr_token_match"] is False: + failed.append("asr_tokens") + if comparison["audio_sample_count_equal"] is False: + failed.append("audio_length") + if failed: + raise AssertionError( + "Offline/incremental parity failed for: " + + ", ".join(failed) + + f". Report: {json.dumps({k: v for k, v in report.items() if k != 'debug'}, ensure_ascii=False)}" + ) + + return report + + +def build_argparser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Compare Nemotron VoiceChat offline inference against incremental decoding with one full-audio chunk." 
+ ) + parser.add_argument("--model_path", type=str, required=True, help="Path to S2S/TTS checkpoint directory.") + parser.add_argument("--llm_checkpoint_path", type=str, required=True, help="Path to LLM/perception checkpoint directory.") + parser.add_argument("--audio_path", type=str, required=True, help="Audio file to compare across both paths.") + parser.add_argument("--speaker_reference", type=str, default=None, help="Speaker reference audio path.") + parser.add_argument("--speaker_name", type=str, default=None, help="Registered speaker name.") + parser.add_argument("--config_path", type=str, default=None, help="Optional path to s2s_streaming.yaml.") + parser.add_argument("--system_prompt", type=str, default=None, help="Optional system prompt.") + parser.add_argument("--tts_system_prompt", type=str, default=None, help="Optional TTS system prompt.") + parser.add_argument( + "--num_frames_per_chunk", + type=int, + default=None, + help="Override incremental chunk size in 80ms frames. If unset, defaults to full audio length.", + ) + parser.add_argument( + "--chunk_size_in_secs", + type=float, + default=None, + help="Override incremental chunk size in seconds. If set, converted to 80ms frames. 
If unset, defaults to full audio length.", + ) + parser.add_argument("--engine_type", type=str, default=None, help="Override engine type.") + parser.add_argument("--compute_dtype", type=str, default=None, help="Override compute dtype (for example: float32, bfloat16).") + _bool_arg(parser, "--decode_audio", "Whether to decode waveform outputs.") + _bool_arg(parser, "--use_perception_cache", "Override perception cache usage.") + _bool_arg(parser, "--use_perception_cudagraph", "Override perception CUDA-graph usage.") + _bool_arg(parser, "--use_llm_cache", "Override LLM cache usage.") + _bool_arg(parser, "--use_codec_cache", "Override codec cache usage.") + _bool_arg(parser, "--deterministic", "Override deterministic mode.") + parser.add_argument("--top_p", type=float, default=None, help="Override top-p.") + parser.add_argument("--repetition_penalty", type=float, default=None, help="Override repetition penalty.") + parser.add_argument("--temperature", type=float, default=None, help="Override temperature.") + parser.add_argument("--output_json", type=str, default=None, help="Optional JSON report path.") + parser.add_argument( + "--strict_runtime_parity", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "When enabled, force a strict native/deterministic parity profile and raise if text/token/audio-length " + "comparisons differ." + ), + ) + _bool_arg( + parser, + "--collect_debug", + "Collect per-step encoder outputs and logits for comparison. 
" + "Stores tensors on CPU each step; disable for long audio to avoid OOM.", + ) + return parser + + +def main() -> int: + args = build_argparser().parse_args() + report = run_parity_harness(args) + printable = {k: v for k, v in report.items() if k != "debug"} + print(json.dumps(printable, indent=2, ensure_ascii=False)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index a9af52b984c9..cdf9d2ab613a 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -866,6 +866,7 @@ def __call__( cache_position: Optional[torch.Tensor] = None, generated_tokens: Optional[torch.Tensor] = None, current_step: int = 0, + return_logits: bool = False, **kwargs ) -> Dict[str, Any]: """ @@ -916,6 +917,11 @@ def __call__( "asr_predicted_token": asr_predicted_token, "cache": result.get("cache", None), } + if return_logits: + ans["text_logits"] = result["text_logits"] + ans["asr_logits"] = result.get("asr_logits") + if "function_logits" in result: + ans["function_logits"] = result["function_logits"] if "function_logits" in result: ans["function_predicted_token"] = result["function_logits"][:, -1].argmax(dim=-1) return ans diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 632b1b11103f..7de38b990586 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -13,27 +13,22 @@ # limitations under the License. 
import torch -import yaml from omegaconf import OmegaConf, DictConfig -import numpy as np -import librosa import time -from transformers import DynamicCache import re import os import sys -import argparse -import math import torchaudio import functools from dataclasses import dataclass from typing import Optional, Tuple from nemo.utils import logging -from jiwer import wer import gc import types +from transformers import DynamicCache + # Set environment variables (use existing env vars if set, otherwise use defaults) _default_cache = "/tmp/cache" @@ -47,6 +42,7 @@ from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache from nemo.collections.speechlm2.inference.model_wrappers.model_factory import create_model from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import ( PerceptionCacheState, @@ -118,7 +114,7 @@ def __init__(self, model_cfg: DictConfig): Initialize the model for realtime streaming inference. Args: - model_cfg (DictConfig): Configuration describing the model paths and runtime parameters. + model_cfg (DictConfig): Configuration describing the model paths and inference parameters. """ if model_cfg is None: raise ValueError("model_cfg must be provided") @@ -678,66 +674,32 @@ def _prepare_system_prompt_embeddings( self, system_prompt: str, ) -> Tuple[Optional[torch.Tensor], int]: - """ - Prepare system prompt embeddings consistent with offline_inference. 
- - In offline_inference, prompt embeddings are structured as: - - Position 0: prompt_token_emb + bos_emb + asr_bos - - Position t > 0: prompt_token_emb + pad_emb + pad_asr - - Args: - system_prompt: The system prompt text - - Returns: - Tuple of (prompt_embedded [1, prompt_len, H], prompt_length) - Returns (None, 0) if system_prompt is empty - """ - if not system_prompt or not system_prompt.strip(): return None, 0 - logging.info(f"Preparing system prompt: {system_prompt[:100]}...") - - # Step 1: Tokenize the prompt - # Format: [bos] + text_tokens + [eos] (consistent with collate_system_prompt) - prompt_token_ids = ( - [self.tokenizer.bos_id] + - self.tokenizer.text_to_ids(system_prompt) + - [self.tokenizer.eos_id] - ) - prompt_tokens = torch.tensor(prompt_token_ids, dtype=torch.long, device=self.device).unsqueeze(0) # [1, prompt_len] + prompt_token_ids = self._build_prompt_token_ids(system_prompt) + prompt_tokens = torch.tensor(prompt_token_ids, dtype=torch.long, device=self.device).unsqueeze(0) + prompt_embedded = self.model.stt_model.embed_tokens(prompt_tokens).to(dtype=self.dtype) prompt_len = prompt_tokens.shape[1] - logging.info(f" Prompt length: {prompt_len} tokens") - - # Step 2: Embed the prompt tokens (this acts as the "audio channel" for prompt positions) - prompt_embedded = self.model.stt_model.embed_tokens(prompt_tokens) # [1, prompt_len, H] - prompt_embedded = prompt_embedded.to(dtype=self.dtype) - - # Step 3: Add pad embeddings for text and ASR channels (for positions t > 0) - # In offline_inference, prompt positions use gen_text[:, t-1] = pad_id pad_id = self.model.stt_model.text_pad_id pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) - pad_emb = self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) # [1, H] - pad_asr_emb = self.model.stt_model.embed_asr_tokens(pad_token).to(dtype=self.dtype) # [1, H] + pad_emb = self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) + pad_asr_emb = 
self.model.stt_model.embed_asr_tokens(pad_token).to(dtype=self.dtype) - # For positions t > 0, add pad embeddings (simulating gen_text[:, t-1] = pad_id) has_fc = self.model.stt_model.function_head is not None if prompt_len > 1: prompt_embedded[:, 1:, :] += pad_emb prompt_embedded[:, 1:, :] += pad_asr_emb if has_fc: - prompt_embedded[:, 1:, :] += pad_emb # FC channel also uses pad at t > 0 + prompt_embedded[:, 1:, :] += pad_emb - # Step 4: For position 0, add BOS embeddings - bos_emb = self._get_bos_embedding() # [1, H] - asr_bos_emb = self._get_asr_bos_embedding() # [1, H] + bos_emb = self._get_bos_embedding() + asr_bos_emb = self._get_asr_bos_embedding() prompt_embedded[:, 0, :] += bos_emb.squeeze(0) prompt_embedded[:, 0, :] += asr_bos_emb.squeeze(0) if has_fc: - prompt_embedded[:, 0, :] += pad_emb.squeeze(0) # FC channel uses pad at t=0 - - logging.info(f" System prompt embeddings prepared: shape {prompt_embedded.shape}") + prompt_embedded[:, 0, :] += pad_emb.squeeze(0) return prompt_embedded, prompt_len @@ -751,13 +713,50 @@ def _clone_cache(self, cache): return type(cache)(self._clone_cache(x) for x in cache) if isinstance(cache, dict): return {k: self._clone_cache(v) for k, v in cache.items()} - # Handle complex objects (e.g., DynamicCache with __dict__ attributes) - # Use deepcopy to ensure complete isolation between streams if hasattr(cache, '__dict__'): import copy return copy.deepcopy(cache) return cache + def _build_prompt_token_ids(self, system_prompt: str | None) -> list[int]: + if not system_prompt or not system_prompt.strip(): + return [] + return [self.tokenizer.bos_id] + self.tokenizer.text_to_ids(system_prompt) + [self.tokenizer.eos_id] + + def _create_generation_workspace(self, max_len: int): + stt_model = self.model.stt_model + gen_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) + gen_asr_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) + gen_function_text = 
None + if getattr(stt_model, "function_head", None) is not None: + gen_function_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) + return gen_text, gen_asr_text, gen_function_text + + def _create_llm_cache(self): + if not self.use_llm_cache or self.use_vllm_llm: + return None + pretrained_llm = str(self.model.stt_model.cfg.get("pretrained_llm", "")) + if "Nemotron" in pretrained_llm: + return self.model.stt_model._create_nemotron_cache(batch_size=1) + return DynamicCache() + + def _create_codec_state(self, max_len: int): + if not self.decode_audio or not hasattr(self.model, "tts_model"): + return None, None, None + + audio_toks_buffer = None + codec_cache = None + if self.use_codec_cache: + codec_cache = CausalConv1dCache() + elif self.codec_token_history_size > 0: + silence_tokens = self.model.tts_model.codec_silence_tokens.detach().clone() + audio_toks_buffer = silence_tokens.view(1, 1, -1).expand( + 1, self.codec_token_history_size, -1 + ).contiguous().to(self.device) + + subword_mask = torch.ones((1, max_len), device=self.device, dtype=torch.bool) + return audio_toks_buffer, subword_mask, codec_cache + def _prepare_tts_initial_state(self): if not self.decode_audio: return @@ -803,9 +802,6 @@ def _prepare_tts_initial_state(self): outputs = self.model.tts_model.tts_model(**init_inputs) code = init_inputs["code"][:, -1:] - # code, _, _ = self.model.tts_model.tts_model.generate_step( - # outputs.hidden_states[:, -1:], **self.generation_config - # ) self.first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) self.first_tts_code_input = code.detach().clone() @@ -814,51 +810,35 @@ def _prepare_tts_initial_state(self): logging.info("TTS warmup state prepared") - def _update_audio_buffer(self, audio_buffer, buffer_fill_level, new_audio, buffer_size_samples): - """ - Append incoming samples to the sliding-window buffer and produce the view used for inference. 
- - Parameters: - audio_buffer (torch.Tensor): Tensor of shape `[1, buffer_size_samples]` holding the latest audio samples. - buffer_fill_level (int): Number of valid samples currently stored in `audio_buffer`. - new_audio (torch.Tensor): Incoming samples of shape `[1, slice_n_samples]` for the current step. - buffer_size_samples (int): Total capacity of the buffer in samples. - - Returns: - Tuple[torch.Tensor, int, torch.Tensor]: - - Updated `audio_buffer` containing the newest samples (always capped to `buffer_size_samples`). - - Updated `buffer_fill_level`, reflecting how many contiguous samples are valid. - - `current_buffer`, a view over the valid portion of the buffer used for the model input. - - Notes: - `audio_buffer` always retains the last `buffer_size_samples` samples even when overfilled, - whereas `current_buffer` may be shorter during the initial warm-up phase when the buffer - is not yet full. - """ - if new_audio.shape[1] == 0: - current_buffer = audio_buffer[:, :buffer_fill_level] - return audio_buffer, buffer_fill_level, current_buffer - - remaining = new_audio - - if buffer_fill_level < buffer_size_samples and remaining.shape[1] > 0: - warmup_take = min(buffer_size_samples - buffer_fill_level, remaining.shape[1]) - if warmup_take > 0: - audio_buffer[:, buffer_fill_level:buffer_fill_level + warmup_take] = remaining[:, :warmup_take] - buffer_fill_level += warmup_take - remaining = remaining[:, warmup_take:] - - if remaining.shape[1] > 0: - if remaining.shape[1] >= buffer_size_samples: - audio_buffer = remaining[:, -buffer_size_samples:] - else: - audio_buffer = torch.cat([ - audio_buffer[:, remaining.shape[1]:], - remaining - ], dim=1) - buffer_fill_level = buffer_size_samples - current_buffer = audio_buffer if buffer_fill_level == buffer_size_samples else audio_buffer[:, :buffer_fill_level] - return audio_buffer, buffer_fill_level, current_buffer + def create_decode_state(self, max_len: int): + gen_text, gen_asr_text, gen_function_text = 
self._create_generation_workspace(max_len) + llm_cache = self._create_llm_cache() + audio_toks_buffer, subword_mask, codec_cache = self._create_codec_state(max_len) + perception_cache = None + if self.use_perception_cache and self.perception_cache_mgr is not None: + perception_cache = self.perception_cache_mgr.get_initial_state(batch_size=1) + + past_key_values = None + code = None + if self.decode_audio and self.first_tts_code_input is not None: + past_key_values = self._clone_cache(self.first_tts_past_key_values_input) + code = self.first_tts_code_input.detach().clone() + + return { + "frame_idx": 0, + "gen_text": gen_text, + "gen_asr_text": gen_asr_text, + "gen_function_text": gen_function_text, + "audio_toks_buffer": audio_toks_buffer, + "input_embeds_history": [], + "dynamic_cache": llm_cache, + "past_key_values": past_key_values, + "code": code, + "subword_mask": subword_mask, + "perception_cache": perception_cache, + "codec_cache": codec_cache, + "cache_position_offset": 0, + } def infer_one_step(self, audio_input, @@ -877,7 +857,8 @@ def infer_one_step(self, perception_cache: Optional[PerceptionCacheState] = None, has_prompt: bool = False, codec_cache=None, - cache_position_offset: int = 0): + cache_position_offset: int = 0, + return_debug: bool = False): # Set up effective request ID for vLLM streaming effective_request_id = request_id or self.request_id @@ -889,6 +870,10 @@ def infer_one_step(self, predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) asr_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) function_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) + debug_text_logits = [] + debug_asr_logits = [] + debug_input_embeds = [] + selected_frame_indices = [] # Do "perception" step outside the for-loop start_perception = time.time() @@ -942,16 +927,22 @@ def 
infer_one_step(self, current_frame_idx = frame_idx + frame_offset current_frame_index = base_frame_index + frame_offset current_frame_index = min(current_frame_index, total_encoded_frames - 1) + selected_frame_indices.append(current_frame_index) current_frame_embedding = source_encoded[:, current_frame_index:current_frame_index + 1, :] current_input_emb = current_frame_embedding.clone() + current_input_emb *= self.model.stt_model.cfg.get("duplex_user_channel_weight", 1.0) has_fc = gen_function_text is not None if current_frame_idx == 0 and not has_prompt: # Only add BOS if there's no prompt (BOS is already in prompt's position 0) - current_input_emb += self._get_bos_embedding() - current_input_emb += self._get_asr_bos_embedding() + current_input_emb += self._get_bos_embedding() * self.model.stt_model.cfg.get( + "duplex_text_channel_weight", 1.0 + ) + current_input_emb += self._get_asr_bos_embedding() * self.model.stt_model.cfg.get( + "duplex_asr_text_weight", 1.0 + ) if has_fc: pad_id = self.model.stt_model.text_pad_id fc_pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) @@ -969,13 +960,18 @@ def infer_one_step(self, current_input_emb += self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) else: # t > 0: add embeddings from model's own predictions at t-1 - last_token_emb = self.model.stt_model.embed_tokens(gen_text[:, current_frame_idx - 1]) - current_input_emb += last_token_emb - last_asr_token_emb = self.model.stt_model.embed_asr_tokens(gen_asr_text[:, current_frame_idx - 1]) - current_input_emb += last_asr_token_emb + last_token_emb = self.model.stt_model.embed_tokens( + gen_text[:, current_frame_idx - 1] + ) * self.model.stt_model.cfg.get("duplex_text_channel_weight", 1.0) + last_asr_token_emb = self.model.stt_model.embed_asr_tokens( + gen_asr_text[:, current_frame_idx - 1] + ) * self.model.stt_model.cfg.get("duplex_asr_text_weight", 1.0) + current_input_emb += last_token_emb + last_asr_token_emb if has_fc: 
last_fc_token_emb = self.model.stt_model.embed_tokens(gen_function_text[:, current_frame_idx - 1]) current_input_emb += last_fc_token_emb.to(dtype=self.dtype) + if return_debug: + debug_input_embeds.append(current_input_emb.detach().cpu()) start_stt_model = time.time() @@ -997,7 +993,8 @@ def infer_one_step(self, cache=dynamic_cache, cache_position=cache_pos, generated_tokens=gen_text, - current_step=current_frame_idx + current_step=current_frame_idx, + return_logits=return_debug, ) dynamic_cache = ans["cache"] else: @@ -1007,7 +1004,8 @@ def infer_one_step(self, full_input_embeds, cache=None, generated_tokens=gen_text, - current_step=current_frame_idx + current_step=current_frame_idx, + return_logits=return_debug, ) torch.cuda.synchronize() @@ -1016,6 +1014,10 @@ def infer_one_step(self, predicted_token = ans["predicted_token"] asr_predicted_token = ans["asr_predicted_token"] + if return_debug and "text_logits" in ans: + debug_text_logits.append(ans["text_logits"][:, -1].detach().cpu()) + if return_debug and "asr_logits" in ans and ans["asr_logits"] is not None: + debug_asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) gen_text[:, current_frame_idx] = predicted_token predicted_tokens[:, frame_offset] = predicted_token @@ -1183,6 +1185,16 @@ def infer_one_step(self, } if self.model.stt_model.function_head is not None: result['function_predicted_text_tokens'] = function_predicted_tokens + if return_debug: + result["debug"] = { + "source_encoded": source_encoded.detach().cpu(), + "selected_frame_indices": selected_frame_indices, + "input_embeds": torch.cat(debug_input_embeds, dim=1) if debug_input_embeds else None, + "gen_text": gen_text.detach().cpu(), + "gen_asr": gen_asr_text.detach().cpu() if gen_asr_text is not None else None, + "text_logits": torch.stack(debug_text_logits, dim=1) if debug_text_logits else None, + "asr_logits": torch.stack(debug_asr_logits, dim=1) if debug_asr_logits else None, + } return result def abort_request(self, request_id: 
Optional[str]) -> bool: @@ -1262,746 +1274,11 @@ def _maybe_apply_forced_turn_taking(self, t, gen_text, gen_asr): gen_text[batch_idx, t] = self.model.stt_model.text_eos_id logging.info(f"Forced turn-taking at frame {t}: inserted agent EOS (reason: user started speaking)") - @torch.no_grad() - def inference_realtime_streaming(self, audio_path: str, num_frames_per_chunk: int = None, request_id: Optional[str] = None, pad_audio_to_sec: Optional[float] = None, pad_silence_ratio: Optional[float] = None, pad_audio_by_sec: Optional[float] = None, system_prompt: Optional[str] = None): - """ - Perform realtime streaming inference simulating microphone capture. - - Args: - audio_path: Path to input audio file (simulates microphone input) - num_frames_per_chunk: Number of frames to process per inference step (default: 1) - request_id: Optional request ID for vLLM streaming - pad_audio_to_sec: Optional duration to pad audio to (in seconds) - pad_silence_ratio: Optional ratio of original duration to append as silence (e.g. 
0.2 = 20%) - pad_audio_by_sec: Optional fixed number of extra seconds of silence to append - system_prompt: Optional system prompt to provide context to the model - - Returns: - Dictionary with 'text', 'tokens_text', 'tokens_audio', 'audio', 'audio_len', 'system_prompt' - """ - # Use provided value or default - if num_frames_per_chunk is None: - num_frames_per_chunk = DEFAULT_NUM_FRAMES_PER_CHUNK - if num_frames_per_chunk < 1: - raise ValueError("num_frames_per_chunk must be at least 1") - start_time = time.time() - - logging.info("\n" + "=" * 70) - logging.info("STARTING REALTIME STREAMING INFERENCE") - logging.info("=" * 70) - - # Set up request ID for vLLM streaming - stream_request_id = request_id or self.request_id - - buffer_size_frames = int(self.model_cfg.get("buffer_size_frames", DEFAULT_BUFFER_SIZE_FRAMES)) - buffer_size_samples = buffer_size_frames * FRAME_SIZE_SAMPLES - if num_frames_per_chunk > buffer_size_frames: - raise ValueError( - f"num_frames_per_chunk ({num_frames_per_chunk}) must be " - f"less than or equal to buffer_size_frames ({buffer_size_frames})." - ) - - att_context_size = self.model.stt_model.perception.encoder._cfg.att_context_size - if self.use_perception_cache: - min_buffer = num_frames_per_chunk * (att_context_size[1] + 1) + 2 - reason = ( - f"must be >= num_frames_per_chunk * (att_context_size[1] + 1) + 2 = " - f"{num_frames_per_chunk} * ({att_context_size[1]} + 1) + 2 = {min_buffer} " - f"when using perception cache (+2 to minimize windowing artifacts)" - ) - else: - min_buffer = att_context_size[0] + att_context_size[1] + 1 - reason = ( - f"must be >= att_context_size[0] + att_context_size[1] + 1 = " - f"{att_context_size[0]} + {att_context_size[1]} + 1 = {min_buffer} " - f"without perception cache" - ) - if buffer_size_frames < min_buffer: - raise ValueError( - f"buffer_size_frames ({buffer_size_frames}) is too small: {reason}." 
- ) - if self.decode_audio and not self.use_codec_cache and num_frames_per_chunk > self.codec_token_history_size: - raise ValueError( - f"num_frames_per_chunk ({num_frames_per_chunk}) must be " - f"<= codec_token_history_size ({self.codec_token_history_size}) when decode_audio=True " - f"and use_codec_cache=False. " - f"Either reduce num_frames_per_chunk, increase codec_token_history_size, or enable use_codec_cache." - ) - logging.info(f"Buffer size: {buffer_size_frames} frames ({buffer_size_frames * FRAME_SIZE_SEC}s)") - logging.info(f"Frames per inference step: {num_frames_per_chunk}") - - # Load audio file (simulating microphone stream) - logging.info(f"Loading audio file: {audio_path}") - audio_signal, sr = librosa.load(audio_path, sr=SAMPLE_RATE) - total_samples = len(audio_signal) - total_duration = total_samples / SAMPLE_RATE - - logging.info(f" Total duration: {total_duration:.2f}s") - logging.info(f" Total samples: {total_samples}") - - # Optionally pad audio (at most one of these is set; enforced by caller) - if pad_audio_to_sec is not None and pad_audio_to_sec > total_duration: - target_samples = int(pad_audio_to_sec * SAMPLE_RATE) - audio_signal = np.pad(audio_signal, (0, target_samples - total_samples), mode='constant') - total_samples = len(audio_signal) - logging.info(f" Padded to {pad_audio_to_sec:.2f}s ({total_samples} samples)") - elif pad_silence_ratio is not None: - extra_samples = int(total_duration * pad_silence_ratio * SAMPLE_RATE) - audio_signal = np.pad(audio_signal, (0, extra_samples), mode='constant') - total_samples = len(audio_signal) - logging.info(f" Padded with {pad_silence_ratio*100:.1f}% extra silence ({extra_samples} samples)") - elif pad_audio_by_sec is not None: - extra_samples = int(pad_audio_by_sec * SAMPLE_RATE) - audio_signal = np.pad(audio_signal, (0, extra_samples), mode='constant') - total_samples = len(audio_signal) - logging.info(f" Padded with {pad_audio_by_sec:.2f}s extra silence ({extra_samples} samples)") - - # 
derive num_inference_steps - total_frames_maybe = int(np.ceil(total_samples / FRAME_SIZE_SAMPLES)) # "maybe" because we might need to add padding - num_inference_steps = (total_frames_maybe // num_frames_per_chunk) - if total_frames_maybe % num_frames_per_chunk != 0: - num_inference_steps += 1 - total_frames = num_inference_steps * num_frames_per_chunk - - # pad audio signal so that it is divisible by num_inference_steps - padded_total_samples = num_inference_steps * num_frames_per_chunk * FRAME_SIZE_SAMPLES - if padded_total_samples > total_samples: - audio_signal = np.pad(audio_signal, (0, padded_total_samples - total_samples), mode='constant') - logging.info(f" Padded to: {padded_total_samples} samples") - logging.info(f" {num_frames_per_chunk=} => {total_frames=}, {num_inference_steps=}") - - # convert audio signal to tensor - audio_signal_tensor = torch.tensor(audio_signal, dtype=self.dtype, device=self.device).unsqueeze(0) - - use_cache = self.use_llm_cache - is_nemotron = 'Nemotron' in self.model.stt_model.cfg.pretrained_llm - logging.info(f"Model: {self.model.stt_model.cfg.pretrained_llm}") - logging.info(f" Use LLM cache: {use_cache}, is_nemotron: {is_nemotron}") - - # Initialize buffer and state - audio_buffer = torch.zeros(1, buffer_size_samples, dtype=self.dtype, device=self.device) - buffer_fill_level = 0 # How many samples currently in buffer - - # Initialize LLM cache (skip for vLLM -- it manages its own KV cache) - if not use_cache or self.use_vllm_llm: - llm_cache = None - if not use_cache: - input_embeds_history = [] - elif is_nemotron: - llm_cache = self.model.stt_model._create_nemotron_cache(batch_size=1) - else: - llm_cache = DynamicCache() - cache_position_offset = 0 - - # Process system prompt if provided (before streaming audio) - prompt_embedded = None - prompt_len = 0 - - if system_prompt: - start_get_prompt_embeddings = time.time() - prompt_embedded, prompt_len = self._prepare_system_prompt_embeddings(system_prompt) - logging.info(f"Time 
taken to get prompt embeddings: {time.time() - start_get_prompt_embeddings:.3f}s") - if prompt_embedded is not None and "vllm" in self.engine_type.lower(): - # Prepare token IDs for the prompt - prompt_token_ids = ( - [self.tokenizer.bos_id] + - self.tokenizer.text_to_ids(system_prompt) + - [self.tokenizer.eos_id] - ) - - # For vLLM mode: use efficient BATCH prefill (~20x faster than sequential) - logging.info(f" Batch prefilling {prompt_len} prompt embeddings...") - start_batch_prefill = time.time() - with torch.no_grad(): - success = self.model_llm_interface( - prompt_embedded, - request_id=stream_request_id, - decode_steps=0, - prompt_token_ids=prompt_token_ids, - ) - logging.info(f"Time taken to batch prefill stt model: {time.time() - start_batch_prefill:.3f}s") - if success: - logging.info(f" System prompt prefilled ({prompt_len} tokens)") - else: - raise RuntimeError("vLLM batch prefill for system prompt failed.") - elif prompt_embedded is not None and not use_cache: - # For no-cache mode (Nemotron): add prompt embeddings to history - # Split into individual frames for consistent processing - for t in range(prompt_len): - input_embeds_history.append(prompt_embedded[:, t:t+1, :]) - logging.info(f" Added {prompt_len} prompt embeddings to input_embeds_history") - elif prompt_embedded is not None and use_cache: - # For cache mode: process prompt through LLM to update cache - with torch.no_grad(): - cache_pos = torch.arange(prompt_len, device=self.device) - ans = self.model.stt_model(prompt_embedded, cache=llm_cache, cache_position=cache_pos) - llm_cache = ans.get("cache", llm_cache) - cache_position_offset = prompt_len - logging.info(f" System prompt processed, cache updated (offset={cache_position_offset})") - - # Initialize TTS - code = None - past_key_values = None - subword_mask = None - audio_toks_buffer = None - if self.decode_audio and hasattr(self.model, 'tts_model'): - - # Sliding-window buffer is only needed when codec_cache is off - if not 
self.use_codec_cache: - audio_toks_buffer = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand( - -1, self.codec_token_history_size, -1 - ).to(self.device) - - if ( - self.first_context_subword_id is None - or self.generation_config is None - or self.first_tts_code_input is None - or self.first_tts_past_key_values_input is None - ) and not self.use_vllm_eartts: - raise RuntimeError("TTS warmup state was not prepared during initialization.") - - if not self.use_vllm_eartts: - past_key_values = self._clone_cache(self.first_tts_past_key_values_input) - code = self.first_tts_code_input.detach().clone() - else: - start_batch_prefill = time.time() - logging.info(f" Batch prefilling TTS model with speaker embedding...") - # use speaker embedding to prefill EarTTS's vLLM - tts_result = self.model.tts_model.tts_model( - self.tts_init_inputs, - request_id=stream_request_id, - prompt_token_ids=self.tts_prompt_token_ids - ) - code = self.first_tts_code_input.detach().clone() - past_key_values = None - logging.info(f"Time taken to batch prefill tts model: {time.time() - start_batch_prefill:.3f}s") - # Initialize subword_mask for vLLM path as well - subword_mask = torch.ones(1, total_frames, device=self.device, dtype=torch.bool) - logging.info(f"TTS initialized") - - # Initialize perception cache if enabled - perception_cache = None - if self.use_perception_cache: - perception_cache = self.perception_cache_mgr.get_initial_state(batch_size=1) - logging.info(f"Perception cache initialized") - - # Initialize codec streaming cache to remove clicking sounds and wasted inference computation - codec_cache = None - if self.decode_audio and self.use_codec_cache: - from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache - codec_cache = CausalConv1dCache() - logging.info(f"Codec streaming cache initialized") - - gen_text = torch.full((1, total_frames), self.model.stt_model.text_pad_id, device=self.device, dtype=torch.long) - gen_asr_text = 
torch.full((1, total_frames), self.model.stt_model.text_pad_id, device=self.device, dtype=torch.long) - has_function_head = self.model.stt_model.function_head is not None - if has_function_head: - gen_function_text = torch.full((1, total_frames), self.model.stt_model.text_pad_id, device=self.device, dtype=torch.long) - - # initialize list to which we will append generated audio segments - audio_segments = [] - - logging.info("\n" + "=" * 70) - logging.info("STARTING FRAME-BY-FRAME PROCESSING") - logging.info("=" * 70) - - # frame_idx corresponds to index of the first frame passed to infer_one_step - # (we need this distinction in the case that num_frames_per_chunk > 1) - frame_idx = 0 - while frame_idx < total_frames: - slice_start = frame_idx * FRAME_SIZE_SAMPLES - slice_n_samples = num_frames_per_chunk * FRAME_SIZE_SAMPLES - slice_end = slice_start + slice_n_samples - new_audio = audio_signal_tensor[:, slice_start:slice_end] - - audio_buffer, buffer_fill_level, current_buffer = self._update_audio_buffer( - audio_buffer, buffer_fill_level, new_audio, buffer_size_samples - ) - - result = self.infer_one_step( - audio_input=current_buffer, - num_frames_per_chunk=num_frames_per_chunk, - frame_idx=frame_idx, - gen_text=gen_text, - audio_toks_buffer=audio_toks_buffer if self.decode_audio else None, - input_embeds_history=input_embeds_history if not use_cache else [], - dynamic_cache=llm_cache if use_cache else None, - past_key_values=past_key_values if self.decode_audio else None, - code=code if self.decode_audio else None, - subword_mask=subword_mask if self.decode_audio else None, - gen_asr_text=gen_asr_text, - gen_function_text=gen_function_text if has_function_head else None, - request_id=stream_request_id, - perception_cache=perception_cache, - has_prompt=(prompt_len > 0), - codec_cache=codec_cache, - cache_position_offset=cache_position_offset, - ) - - # handle results from infer_one_step - if has_function_head and 'function_predicted_text_tokens' in result: - for 
fi in range(num_frames_per_chunk): - gen_function_text[:, frame_idx + fi] = result['function_predicted_text_tokens'][:, fi] - input_embeds_history = result['input_embeds_history'] - llm_cache = result['dynamic_cache'] - cache_position_offset = result.get('cache_position_offset', cache_position_offset) - if self.use_perception_cache: - perception_cache = result.get('perception_cache', perception_cache) - if self.decode_audio: - audio_toks_buffer = result['audio_toks_buffer'] - decoded_audio_new = result['decoded_audio_new'] - if decoded_audio_new is not None: - audio_segments.append(decoded_audio_new) - - past_key_values = result['past_key_values'] - code = result['code'] - codec_cache = result.get('codec_cache', codec_cache) - else: - decoded_audio_new = None - - if frame_idx % 10 == 0 or frame_idx < 3 or gen_text[:, frame_idx].item() == self.model.stt_model.text_eos_id: - token_str = self.tokenizer.ids_to_text([gen_text[0, frame_idx].item()]) - buffer_status = f"{buffer_fill_level}/{buffer_size_samples}" if buffer_fill_level < buffer_size_samples else "FULL" - special_label = "" - if gen_text[0, frame_idx].item() == self.model.stt_model.text_bos_id: - special_label = " [BOS]" - elif gen_text[0, frame_idx].item() == self.model.stt_model.text_eos_id: - special_label = " [EOS]" - elif gen_text[0, frame_idx].item() == self.model.stt_model.text_pad_id: - special_label = " [PAD]" - logging.info(f"Frame {frame_idx:3d}/{total_frames} | Buffer: {buffer_status:20s} | Token: {gen_text[0, frame_idx].item():5d}{special_label} | '{token_str}'") - - frame_idx += num_frames_per_chunk - - # Prepare results - elapsed_time = time.time() - start_time - logging.info("\n" + "=" * 70) - logging.info("STREAMING INFERENCE COMPLETED") - logging.info("=" * 70) - logging.info(f"Total time: {elapsed_time:.2f}s") - logging.info(f"Audio duration: {total_duration:.2f}s") - logging.info(f"RTF (Real-Time Factor): {elapsed_time / total_duration:.2f}x") - logging.info(f"Processed frames: 
{total_frames}") - - # Trim to actual length - # TODO: this is currently redundant since we iterate over all frames in the while loop - gen_text = gen_text[:, :total_frames] - gen_asr_text = gen_asr_text[:, :total_frames] - - # Decode text - lengths = torch.tensor([total_frames], dtype=torch.long, device=self.device) - text_output = tokens_to_str(gen_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id, eval_text_turn_taking=True) - - # Decode ASR text - asr_text_output = tokens_to_str(gen_asr_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id, eval_text_turn_taking=True) - - # Also create raw versions with kept for comparison - text_output_raw = tokens_to_str_raw(gen_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id) - asr_text_output_raw = tokens_to_str_raw(gen_asr_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id) - - logging.info(f"Generated text: {text_output[0]}") - logging.info(f"Generated ASR text: {asr_text_output[0]}") - - # Decode function calling channel - if has_function_head: - gen_function_text = gen_function_text[:, :total_frames] - function_text_output = tokens_to_str(gen_function_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id, eval_text_turn_taking=False) - function_text_output_raw = tokens_to_str_raw(gen_function_text, lengths, tokenizer=self.tokenizer, pad_id=self.model.stt_model.text_pad_id) - logging.info(f"Generated function text: {function_text_output[0]}") - - ans = { - "text": text_output, - "text_raw": text_output_raw, - "tokens_text": gen_text, - "tokens_len": lengths, - "audio": torch.cat(audio_segments, dim=-1) if audio_segments else None, - "asr_text": asr_text_output, - "asr_text_raw": asr_text_output_raw, - "asr_tokens": gen_asr_text, - "system_prompt": system_prompt if system_prompt else "", - } - if has_function_head: - ans["function_text"] = function_text_output - 
ans["function_text_raw"] = function_text_output_raw - ans["function_tokens"] = gen_function_text - - if self.use_vllm_llm or self.use_vllm_eartts: - self.abort_request(stream_request_id) - - return ans - - def main(): - parser = argparse.ArgumentParser(description="Realtime Streaming Inference") - parser.add_argument("--model_path", type=str, required=True, - help="Path to eartts's checkpoint with TTS (HF format)") - parser.add_argument("--llm_checkpoint_path", type=str, required=True, - help="Path to checkpoint with LLM/perception (HF format)") - parser.add_argument("--audio_path", type=str, default=None, - help="Path to input audio file (for single-file mode)") - parser.add_argument("--input_json", type=str, default=None, - help="Path to input JSON file containing list of records with audio_filepath and text fields (for batch mode)") - parser.add_argument("--output_json", type=str, default=None, - help="Path to output JSON file with predictions") - parser.add_argument("--output_dir", type=str, default="output_streaming", - help="Output directory for audio files and JSON results") - parser.add_argument("--pad_audio_to_sec", type=float, default=None, - help="Pad audio to this duration in seconds (useful for consistent buffer behavior)") - parser.add_argument("--pad_silence_ratio", type=float, default=None, - help="Append silence equal to this ratio of the original audio duration (e.g. 
0.2 = 20%% extra)") - parser.add_argument("--pad_audio_by_sec", type=float, default=None, - help="Append this many seconds of extra silence after the audio") - parser.add_argument("--speaker_reference", type=str, default=None, - help="Path to speaker reference audio file") - parser.add_argument("--speaker_name", type=str, default=None, - help="Name of a registered speaker whose latent is cached in the checkpoint") - parser.add_argument("--buffer_size_frames", type=int, default=DEFAULT_BUFFER_SIZE_FRAMES, - help=f"Size of audio buffer in frames (each frame = 80ms, default: {DEFAULT_BUFFER_SIZE_FRAMES})") - parser.add_argument("--num_frames_per_chunk", type=int, default=DEFAULT_NUM_FRAMES_PER_CHUNK, - help="Number of frames per inference step (default: 1)") - parser.add_argument("--decode_audio", action="store_true", - help="Whether to decode audio") - parser.add_argument("--combine_inp_out_audio", action="store_true", - help="Whether to combine input and output audio into a stereo file") - - # Deterministic inference - parser.add_argument("--deterministic", action="store_true", - help="Enable fully deterministic inference (disables FlashAttention, forces deterministic " - "CUDA algorithms). Useful for reproducible benchmarking. Not compatible with vLLM engines. 
" - "Note: results may differ slightly from non-deterministic mode due to different compute path.") - - # Perception cache argument - parser.add_argument("--use_perception_cache", action="store_true", - help="Enable cache-aware streaming for perception encoder") - parser.add_argument("--use_perception_cudagraph", action="store_true", - help="Use CUDA graphs for perception encoder (requires --use_perception_cache)") - # LLM KV cache argument - parser.add_argument("--use_llm_cache", action="store_true", - help="Use KV cache for the STT LLM (DynamicCache or HybridMambaAttentionDynamicCache for Nemotron)") - # Codec streaming cache argument - parser.add_argument("--use_codec_cache", action="store_true", - help="Enable incremental codec decode to remove clicking sounds and wasted inference computation (recommended)") - - # torch.compile for native inference - parser.add_argument("--use_tts_torch_compile", action="store_true", - help="Compile TTS backbone with torch.compile for faster native inference (mode='default')") - - # TTS model speedup flags (applied inside ear_tts_model.py) - parser.add_argument("--use_tts_subword_cache", action="store_true", - help="Cache CharAwareSubwordEncoder embeddings at inference time (skip backbone for repeated tokens)") - - # vLLM arguments - parser.add_argument("--engine_type", type=str, default="native", choices=["native", "vllm_llm", "vllm_eartts", "vllm_llm_vllm_eartts"], - help="Engine type for inference (default: native)") - parser.add_argument("--vllm_llm_engine_path", type=str, default=None, - help="Path to vLLM-compatible model checkpoint if the path not exists, it will be auto-converted") - parser.add_argument("--vllm_max_model_len", type=int, default=768, - help="Maximum sequence length for vLLM (default: 768)") - parser.add_argument("--vllm_gpu_memory_utilization", type=float, nargs='+', default=[0.4], - help="GPU memory utilization for vLLM. 
Single value shared by both engines; two values assign to LLM and TTS respectively.") - parser.add_argument("--vllm_llm_dtype", type=str, default="bfloat16", - help="Data type for vLLM (default: bfloat16)") - - # vLLM EarTTS arguments - parser.add_argument("--vllm_eartts_engine_path", type=str, default=None, - help="Path to vLLM-compatible EarTTS model checkpoint if the path not exists, it will be auto-converted") - parser.add_argument("--vllm_eartts_dtype", type=str, default="float32", - help="Data type for vLLM (default: float32)") - - # Sampling parameters - parser.add_argument("--top_p", type=float, default=1.0, - help="Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0") - parser.add_argument("--repetition_penalty", type=float, default=1.0, - help="Repetition penalty for generated tokens. 1.0 disables it. Default: 1.0. Recommended: 1.2") - parser.add_argument("--temperature", type=float, default=1.0, - help="Temperature for sampling. 1.0 = no change, <1.0 = sharper, >1.0 = flatter, 0.0 = greedy. 
Default: 1.0") - - # Turn-taking - parser.add_argument("--force_turn_taking", action="store_true", - help="Enable forced turn-taking based on ASR channel tokens") - parser.add_argument("--force_turn_taking_threshold", type=int, default=40, - help="Number of lookback steps for turn-taking detection (default: 40)") - parser.add_argument("--force_turn_taking_pad_window", type=int, default=25, - help="Number of consecutive ASR pad tokens to trigger turn-taking (default: 25)") - - # Inference logit boosts - parser.add_argument("--inference_pad_boost", type=float, default=None, - help="Boost for agent pad logit at inference time") - parser.add_argument("--inference_bos_boost", type=float, default=None, - help="Boost for agent BOS logit at inference time") - parser.add_argument("--inference_eos_boost", type=float, default=None, - help="Boost for agent EOS logit at inference time") - parser.add_argument("--inference_user_pad_boost", type=float, default=None, - help="Boost for ASR pad logit at inference time") - parser.add_argument("--inference_user_bos_boost", type=float, default=None, - help="Boost for ASR BOS logit at inference time") - parser.add_argument("--inference_user_eos_boost", type=float, default=None, - help="Boost for ASR EOS logit at inference time") - - # System prompt - parser.add_argument("--system_prompt", type=str, default=None, - help="System prompt to provide context to the model. 
Can also be specified per-record in input JSON.") - parser.add_argument("--tts_system_prompt", type=str, default=None, - help="System prompt for EARTTS model.") - args = parser.parse_args() - - # Validate arguments: either audio_path OR input_json must be provided - if args.audio_path is None and args.input_json is None: - parser.error("Either --audio_path (single-file mode) or --input_json (batch mode) must be provided") - if args.audio_path is not None and args.input_json is not None: - parser.error("Cannot use both --audio_path and --input_json at the same time") - - if sum(x is not None for x in [args.pad_audio_to_sec, args.pad_silence_ratio, args.pad_audio_by_sec]) > 1: - raise ValueError("Set at most one of: --pad_audio_to_sec, --pad_silence_ratio, --pad_audio_by_sec") - if args.speaker_reference is None and args.speaker_name is None: - parser.error("At least one of --speaker_reference or --speaker_name must be provided") - if not math.isfinite(args.temperature) or args.temperature < 0.0: - parser.error(f"--temperature must be a finite value >= 0.0, got {args.temperature}") - - try: - import json - import soundfile as sf - - model_cfg_dict = { - "model_path": args.model_path, - "llm_checkpoint_path": args.llm_checkpoint_path, - "speaker_reference": args.speaker_reference, - "speaker_name": args.speaker_name, - "buffer_size_frames": args.buffer_size_frames, - "decode_audio": bool(args.decode_audio), - "engine_type": args.engine_type, - "deterministic": bool(args.deterministic), - "use_perception_cache": bool(args.use_perception_cache), - "use_perception_cudagraph": bool(args.use_perception_cudagraph), - "use_llm_cache": bool(args.use_llm_cache), - "use_codec_cache": bool(args.use_codec_cache), - "use_tts_torch_compile": bool(args.use_tts_torch_compile), - "use_tts_subword_cache": bool(args.use_tts_subword_cache), - "top_p": args.top_p, - "repetition_penalty": args.repetition_penalty, - "temperature": args.temperature, - "tts_system_prompt": 
args.tts_system_prompt, - "force_turn_taking": args.force_turn_taking, - "force_turn_taking_threshold": args.force_turn_taking_threshold, - "force_turn_taking_pad_window": args.force_turn_taking_pad_window, - "inference_pad_boost": args.inference_pad_boost, - "inference_bos_boost": args.inference_bos_boost, - "inference_eos_boost": args.inference_eos_boost, - "inference_user_pad_boost": args.inference_user_pad_boost, - "inference_user_bos_boost": args.inference_user_bos_boost, - "inference_user_eos_boost": args.inference_user_eos_boost, - } - - # Pop GPU memory utilization values: first for LLM, second (or same) for TTS - _gpu_mem = list(args.vllm_gpu_memory_utilization) - gpu_mem_llm = _gpu_mem.pop(0) - gpu_mem_tts = _gpu_mem.pop(0) if _gpu_mem else gpu_mem_llm - - # Add vLLM configuration if using vLLM engine - if "vllm_llm" in args.engine_type: - model_cfg_dict["vllm_llm_config"] = { - "model_path": args.model_path, - "max_model_len": args.vllm_max_model_len, - "gpu_memory_utilization": gpu_mem_llm, - "dtype": args.vllm_llm_dtype, - "engine_path": args.vllm_llm_engine_path, # Will auto-convert if needed - "pretrained_llm": args.llm_checkpoint_path, - } - - if "vllm_eartts" in args.engine_type: - model_cfg_dict["vllm_tts_config"] = { - "model_path": args.model_path, # we use exactly the same whole duplexs2s ckpt - "max_model_len": args.vllm_max_model_len, - "gpu_memory_utilization": gpu_mem_tts, - "dtype": args.vllm_eartts_dtype, - "engine_path": args.vllm_eartts_engine_path, - "pretrained_llm": None, - "skip_tokenizer_init": True - } - - model_cfg = OmegaConf.create(model_cfg_dict) - - model = NemotronVoicechatInferenceWrapper(model_cfg=model_cfg) - - # ========================================= - # Load input records (from JSON manifest or single audio file) - # ========================================= - if args.input_json is not None: - logging.info(f"Loading input JSON: {args.input_json}") - with open(args.input_json, 'r') as f: - input_records = 
[json.loads(line) for line in f] - else: - input_records = [{"audio_filepath": args.audio_path, "text": ""}] - - logging.info(f"Found {len(input_records)} records to process") - - os.makedirs(args.output_dir, exist_ok=True) - - if args.output_json: - base_path = args.output_json.rsplit('.', 1)[0] if '.' in args.output_json else args.output_json - output_json_processed = f"{base_path}_processed.json" - output_json_raw = f"{base_path}_raw.json" - else: - output_json_processed = os.path.join(args.output_dir, "output_results_processed.json") - output_json_raw = os.path.join(args.output_dir, "output_results_raw.json") - - logging.info(f"Output will be saved incrementally to:") - logging.info(f" Processed: {output_json_processed}") - logging.info(f" Raw: {output_json_raw}") - output_file_processed = open(output_json_processed, 'w', encoding='utf-8') - output_file_raw = open(output_json_raw, 'w', encoding='utf-8') - - output_records = [] - wer_scores = [] - - try: - for idx, record in enumerate(input_records): - logging.info("\n" + "=" * 70) - logging.info(f"Processing record {idx + 1}/{len(input_records)}") - logging.info("=" * 70) - - audio_path = record.get('audio_filepath') - ground_truth_text = record.get('text', '') - record_system_prompt = record.get('system_prompt', args.system_prompt) - - if not audio_path: - logging.warning(f"Record {idx} missing audio_filepath, skipping...") - continue - - if not os.path.exists(audio_path): - logging.warning(f"Audio file not found: {audio_path}, skipping...") - continue - - logging.info(f" Audio: {audio_path}") - logging.info(f" Ground truth: {ground_truth_text}") - - audio_id = os.path.splitext(os.path.basename(audio_path))[0] - - results = model.inference_realtime_streaming( - audio_path, - num_frames_per_chunk=args.num_frames_per_chunk, - pad_audio_to_sec=args.pad_audio_to_sec, - pad_silence_ratio=args.pad_silence_ratio, - pad_audio_by_sec=args.pad_audio_by_sec, - request_id=f"streaming_request_{idx}", - 
system_prompt=record_system_prompt, - ) - - pred_asr_text = results['asr_text'][0] if 'asr_text' in results else '' - pred_asr_text_raw = results['asr_text_raw'][0] if 'asr_text_raw' in results else '' - pred_text = results['text'][0] if 'text' in results else '' - pred_text_raw = results['text_raw'][0] if 'text_raw' in results else '' - - try: - cleaned_pred = clean_pred_text(pred_asr_text) - cleaned_gt = clean_pred_text(ground_truth_text) - if cleaned_gt.strip() and cleaned_pred.strip(): - utterance_wer = wer(cleaned_gt, cleaned_pred) - wer_scores.append(utterance_wer) - else: - utterance_wer = None - except Exception as e: - utterance_wer = None - logging.warning(f"Error calculating WER: {e}") - - if utterance_wer is not None: - logging.info(f"WER for utterance {idx + 1}: {utterance_wer:.4f} ({utterance_wer * 100:.2f}%)") - - pred_audio_path = None - if args.decode_audio and 'audio' in results and results['audio'] is not None: - input_basename = os.path.splitext(os.path.basename(audio_path))[0] - audio_filename = f"{idx:04d}_{input_basename}_output.wav" - pred_audio_path = os.path.join(args.output_dir, audio_filename) - - audio_np = results['audio'].float().cpu().numpy().flatten() - - sf.write(pred_audio_path, audio_np, model.target_sample_rate) - logging.info(f"Audio saved: {pred_audio_path}") - - if args.combine_inp_out_audio: - stereo_filename = f"{idx:04d}_{input_basename}_combined.wav" - stereo_path_out = os.path.join(args.output_dir, stereo_filename) - - inp_audio, sr = librosa.load(audio_path, sr=model.target_sample_rate) - - delay_samples = int(args.num_frames_per_chunk * FRAME_SIZE_SEC * model.target_sample_rate) - out_audio_delayed = np.concatenate([np.zeros(delay_samples, dtype=audio_np.dtype), audio_np]) - - max_len = max(len(inp_audio), len(out_audio_delayed)) - inp_audio_padded = np.pad(inp_audio, (0, max_len - len(inp_audio))) - out_audio_padded = np.pad(out_audio_delayed, (0, max_len - len(out_audio_delayed))) - - stereo_audio = 
np.stack([inp_audio_padded, out_audio_padded], axis=1) - sf.write(stereo_path_out, stereo_audio, model.target_sample_rate) - logging.info(f"Stereo audio saved: {stereo_path_out}") - - result_system_prompt = results.get('system_prompt', '') - - output_record_processed = { - 'id': audio_id, - 'target_text': '', - 'pred_audio': pred_audio_path, - 'src_text': ground_truth_text, - 'pred_src_text': pred_asr_text, - 'pred_text': pred_text, - 'system_prompt': result_system_prompt, - } - - output_record_raw = { - 'id': audio_id, - 'target_text': '', - 'pred_audio': pred_audio_path, - 'src_text': ground_truth_text, - 'pred_src_text': pred_asr_text_raw, - 'pred_text': pred_text_raw, - 'system_prompt': result_system_prompt, - } - - output_records.append(output_record_processed) - - json.dump(output_record_processed, output_file_processed, ensure_ascii=False) - output_file_processed.write('\n') - output_file_processed.flush() - - json.dump(output_record_raw, output_file_raw, ensure_ascii=False) - output_file_raw.write('\n') - output_file_raw.flush() - - logging.info(f"Record {idx + 1} completed and saved") - - finally: - output_file_processed.close() - output_file_raw.close() - - logging.info("\n" + "=" * 70) - logging.info("ALL RESULTS SAVED") - logging.info("=" * 70) - logging.info(f"Results saved to:") - logging.info(f" Processed: {output_json_processed}") - logging.info(f" Raw: {output_json_raw}") - logging.info(f" Processed {len(output_records)}/{len(input_records)} records successfully") - - if wer_scores: - avg_wer = np.mean(wer_scores) - logging.info("\n" + "=" * 70) - logging.info("WER STATISTICS") - logging.info("=" * 70) - logging.info(f" Total utterances with WER: {len(wer_scores)}") - logging.info(f" Average WER: {avg_wer:.4f} ({avg_wer * 100:.2f}%)") - logging.info(f" Min WER: {np.min(wer_scores):.4f} ({np.min(wer_scores) * 100:.2f}%)") - logging.info(f" Max WER: {np.max(wer_scores):.4f} ({np.max(wer_scores) * 100:.2f}%)") - - logging.info("=" * 70) - 
logging.info("ALL DONE!") - logging.info("=" * 70) - - except Exception as e: - logging.error(f"ERROR during inference: {e}") - import traceback - traceback.print_exc() - return 1 - - return 0 + raise RuntimeError( + "This module cannot be called directly. " + "Use examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py instead." + ) if __name__ == "__main__": diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index ed49815246a5..1dba0d838c64 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -49,6 +49,7 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper # ------------------------------------------------------------------ self.s2s_model = s2s_model self.device = self.s2s_model.device + self.collect_debug = False # ------------------------------------------------------------------ # Streaming configuration @@ -58,6 +59,11 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper self.output_sample_rate = getattr(self.streaming_cfg, "output_sample_rate", 22050) self.batch_size = getattr(self.streaming_cfg, "batch_size", 1) self.max_len = getattr(self.streaming_cfg, "max_len", 200) + if self.batch_size != 1: + raise ValueError( + "StreamingS2SPipeline currently supports only single-stream inference " + "(streaming.batch_size must be 1)." 
+ ) # ------------------------------------------------------------------ @@ -238,8 +244,15 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_p has_prompt=has_prompt, codec_cache=context.codec_cache, cache_position_offset=context.cache_position_offset, + return_debug=self.collect_debug, ) + if self.collect_debug and "debug" in result: + state = self.get_or_create_state(stream_ids[0]) + if not hasattr(state, "debug_steps"): + state.debug_steps = [] + state.debug_steps.append(result["debug"]) + # Persist updated cache & clean finished streams self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk) @@ -583,6 +596,11 @@ def run( asr_texts_with_timestamps = [] raw_texts = [] raw_asr_texts = [] + token_texts = [] + token_asr_texts = [] + token_function_texts = [] + token_lengths = [] + audio_paths = [] tokenizer = self.s2s_model.tokenizer pad_id = self.s2s_model.model.stt_model.text_pad_id @@ -593,6 +611,7 @@ def run( if not text_value: text_value = saved_paths_by_stream.get(idx, "") texts.append(text_value) + audio_paths.append(saved_paths_by_stream.get(idx)) per_stream_words = state.get_output_words() if hasattr(state, "get_output_words") else [] words.append(per_stream_words) asr_text_value = state.get_output_asr_text() if hasattr(state, "get_output_asr_text") else "" @@ -601,6 +620,10 @@ def run( token_data = state.get_token_tensors() if token_data is not None: gen_text, gen_asr_text, total_frames, gen_function_text = token_data + token_texts.append(gen_text) + token_asr_texts.append(gen_asr_text) + token_function_texts.append(gen_function_text) + token_lengths.append(total_frames) lengths = torch.tensor([total_frames], dtype=torch.long) texts_with_timestamps.append( tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0] @@ -619,11 +642,21 @@ def run( fc_text_raw = tokens_to_str_raw(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] 
logging.info(f"Function calling channel: {fc_text}") else: + token_texts.append(None) + token_asr_texts.append(None) + token_function_texts.append(None) + token_lengths.append(None) texts_with_timestamps.append("") asr_texts_with_timestamps.append("") raw_texts.append("") raw_asr_texts.append("") + debug_data = [] + if self.collect_debug: + for idx in range(len(audio_filepaths)): + state = self.get_or_create_state(idx) + debug_data.append(getattr(state, "debug_steps", [])) + self.close_session() return PipelineOutput( @@ -634,6 +667,12 @@ def run( asr_texts_with_timestamps=asr_texts_with_timestamps, raw_texts=raw_texts, raw_asr_texts=raw_asr_texts, + token_texts=token_texts, + token_asr_texts=token_asr_texts, + token_function_texts=token_function_texts, + token_lengths=token_lengths, + audio_filepaths=audio_paths, + debug_data=debug_data if debug_data else None, ) def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = None) -> Optional[torch.Tensor]: diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py index 9d0282eb7951..da829e8c634d 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -17,9 +17,6 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import torch -from transformers import DynamicCache - -from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache from nemo.utils import logging if TYPE_CHECKING: @@ -27,7 +24,7 @@ @dataclass -class StreamingRealtimeContext: +class StreamingDecodeState: frame_idx: int gen_text: torch.Tensor gen_asr_text: torch.Tensor @@ -39,7 +36,7 @@ class StreamingRealtimeContext: code: Optional[torch.Tensor] subword_mask: Optional[torch.Tensor] perception_cache: Optional["PerceptionCacheState"] = None - codec_cache: 
Optional[CausalConv1dCache] = None + codec_cache: Any = None cache_position_offset: int = 0 @@ -57,29 +54,6 @@ def __init__( self.max_len = max_len self.device = getattr(self.s2s_model, "device", torch.device("cpu")) self.dtype = getattr(self.s2s_model, "dtype", torch.float32) - self.text_pad_id = getattr(getattr(self.s2s_model, "model", None), "text_pad_id", 0) - self.codec_token_history_size = int(getattr(self.s2s_model, "codec_token_history_size", 0)) - self.decode_audio = bool(getattr(self.s2s_model, "decode_audio", False)) - self.use_perception_cache = bool(getattr(self.s2s_model, "use_perception_cache", False)) - self.use_codec_cache = bool(getattr(self.s2s_model, "use_codec_cache", True)) - self.use_llm_cache = bool(getattr(self.s2s_model, "use_llm_cache", True)) - - self.is_nemotron = False - stt_model = getattr(self.s2s_model.model, "stt_model", None) - if stt_model is not None: - pretrained_llm = stt_model.cfg.get("pretrained_llm", "") - if "Nemotron" in pretrained_llm: - self.is_nemotron = True - if self.use_llm_cache: - logging.info( - f"Detected Nemotron model ({pretrained_llm}). " - "Will use HybridMambaAttentionDynamicCache for KV caching." - ) - else: - logging.info( - f"Detected Nemotron model ({pretrained_llm}). " - "LLM cache is disabled (use_llm_cache=False)." 
- ) self.reset() @@ -90,91 +64,27 @@ def reset(self) -> None: self.free_slots = Queue(self.num_slots) for i in range(self.num_slots): self.free_slots.put(i) - self.slot_contexts: List[Optional[StreamingRealtimeContext]] = [None] * self.num_slots + self.slot_contexts: List[Optional[StreamingDecodeState]] = [None] * self.num_slots - def _create_context(self) -> StreamingRealtimeContext: + def _create_context(self) -> StreamingDecodeState: """Allocate a fresh context backed by the realtime inference model.""" - gen_text = torch.full( - (1, self.max_len), - fill_value=self.text_pad_id, - device=self.device, - dtype=torch.long, - ) - - gen_asr_text = torch.full( - (1, self.max_len), - fill_value=self.text_pad_id, - device=self.device, - dtype=torch.long, - ) - - stt_model = getattr(self.s2s_model.model, "stt_model", None) - has_function_head = stt_model is not None and getattr(stt_model, "function_head", None) is not None - gen_function_text = None - if has_function_head: - gen_function_text = torch.full( - (1, self.max_len), - fill_value=self.text_pad_id, - device=self.device, - dtype=torch.long, - ) - - use_vllm_llm = bool(getattr(self.s2s_model, "use_vllm_llm", False)) - if not self.use_llm_cache or use_vllm_llm: - dynamic_cache = None - elif self.is_nemotron: - stt_model = getattr(self.s2s_model.model, "stt_model", None) - dynamic_cache = stt_model._create_nemotron_cache(batch_size=1) - else: - dynamic_cache = DynamicCache() - audio_toks_buffer: Optional[torch.Tensor] = None - past_key_values: Any = None - code: Optional[torch.Tensor] = None - subword_mask: Optional[torch.Tensor] = None - perception_cache = None - codec_cache = None - - if self.decode_audio and hasattr(getattr(self.s2s_model, "model", None), "tts_model"): - tts_model = self.s2s_model.model.tts_model - if self.use_codec_cache: - # Incremental decode path: CausalConv1dCache maintains all codec - # context internally, so no audio_toks_buffer is needed and - # codec_token_history_size is irrelevant. 
- codec_cache = CausalConv1dCache() - elif self.codec_token_history_size > 0: - # Sliding-window fallback: allocate silence buffer of - # codec_token_history_size tokens that is re-decoded every step. - silence_tokens_base = tts_model.codec_silence_tokens.detach().clone() - silence_tokens = silence_tokens_base.view(1, 1, -1).expand( - -1, self.codec_token_history_size, -1 - ).contiguous() # contiguous() ensures it's a real copy, not a view - audio_toks_buffer = silence_tokens.to(self.device).clone() - subword_mask = torch.ones((1, self.max_len), device=self.device, dtype=torch.bool) - - if getattr(self.s2s_model, "first_tts_past_key_values_input", None) is not None: - past_key_values = self.s2s_model._clone_cache(self.s2s_model.first_tts_past_key_values_input) - if getattr(self.s2s_model, "first_tts_code_input", None) is not None: - code = self.s2s_model.first_tts_code_input.detach().clone() - - # Initialize perception cache if enabled - if self.use_perception_cache: - mgr = getattr(self.s2s_model, "perception_cache_mgr", None) - if mgr is not None: - perception_cache = mgr.get_initial_state(batch_size=1) - - return StreamingRealtimeContext( - frame_idx=0, - gen_text=gen_text, - gen_asr_text=gen_asr_text, - gen_function_text=gen_function_text, - audio_toks_buffer=audio_toks_buffer, - input_embeds_history=[], - dynamic_cache=dynamic_cache, - past_key_values=past_key_values, - code=code, - subword_mask=subword_mask, - perception_cache=perception_cache, - codec_cache=codec_cache, + if not hasattr(self.s2s_model, "create_decode_state"): + raise RuntimeError("s2s_model must provide create_decode_state(max_len)") + decode_state = self.s2s_model.create_decode_state(self.max_len) + return StreamingDecodeState( + frame_idx=decode_state["frame_idx"], + gen_text=decode_state["gen_text"], + gen_asr_text=decode_state["gen_asr_text"], + gen_function_text=decode_state["gen_function_text"], + audio_toks_buffer=decode_state["audio_toks_buffer"], + 
input_embeds_history=decode_state["input_embeds_history"], + dynamic_cache=decode_state["dynamic_cache"], + past_key_values=decode_state["past_key_values"], + code=decode_state["code"], + subword_mask=decode_state["subword_mask"], + perception_cache=decode_state["perception_cache"], + codec_cache=decode_state["codec_cache"], + cache_position_offset=decode_state["cache_position_offset"], ) def _ensure_slot(self, stream_id: int) -> int: @@ -292,7 +202,7 @@ def reset_slots(self, stream_ids: List[int], eos_flags: List[bool]) -> None: if eos_flag and stream_id in self.streamidx2slotidx: self.reset_slot(self.streamidx2slotidx[stream_id]) - def get_context(self, stream_ids: List[int]) -> Tuple[StreamingRealtimeContext, Dict[int, int]]: + def get_context(self, stream_ids: List[int]) -> Tuple[StreamingDecodeState, Dict[int, int]]: """Return the cached context associated with the provided stream ids.""" if len(stream_ids) == 0: return self._create_context(), {} diff --git a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py index 7022c6783975..4ad5634172f2 100644 --- a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py +++ b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py @@ -15,6 +15,8 @@ import re from typing import List, Optional +import torch + from nemo.collections.asr.inference.utils.text_segment import Word @@ -50,6 +52,12 @@ def __init__( asr_texts_with_timestamps: Optional[List[str]] = None, raw_texts: Optional[List[str]] = None, raw_asr_texts: Optional[List[str]] = None, + token_texts: Optional[List[torch.Tensor | None]] = None, + token_asr_texts: Optional[List[torch.Tensor | None]] = None, + token_function_texts: Optional[List[torch.Tensor | None]] = None, + token_lengths: Optional[List[int | None]] = None, + audio_filepaths: Optional[List[str | None]] = None, + debug_data: Optional[List[list]] = None, ): if texts is None and words is None: raise ValueError("At least 
one of the 'texts' or 'words' should be provided.") @@ -60,3 +68,9 @@ def __init__( self.asr_texts_with_timestamps = asr_texts_with_timestamps self.raw_texts = raw_texts self.raw_asr_texts = raw_asr_texts + self.token_texts = token_texts + self.token_asr_texts = token_asr_texts + self.token_function_texts = token_function_texts + self.token_lengths = token_lengths + self.audio_filepaths = audio_filepaths + self.debug_data = debug_data From a7c61d9ceb0e5498761c72739198e6f6e353c53f Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Thu, 19 Mar 2026 19:45:45 +0000 Subject: [PATCH 08/40] skip pretrained ASR/LLM downloads in from_pretrained; simplify inference wrapper loading Signed-off-by: Elena Rastorgueva --- .../conf/s2s_streaming.yaml | 3 +- .../nemotron_voicechat_inference_wrapper.py | 183 +++--------------- .../speechlm2/models/duplex_stt_model.py | 30 ++- .../speechlm2/models/nemotron_voicechat.py | 53 +++++ .../collections/speechlm2/parts/pretrained.py | 36 ++++ 5 files changed, 140 insertions(+), 165 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml index 1fdaf5998637..9f31b11bcd25 100644 --- a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -27,7 +27,6 @@ pipeline_type: s2s_streaming # S2S model block s2s: model_path: ??? - llm_checkpoint_path: ??? speaker_reference: null speaker_name: null engine_type: ??? 
# Engine type: 'native' or 'vllm_llm' or 'vllm_eartts' or 'vllm_llm_vllm_eartts' @@ -37,7 +36,7 @@ s2s: gpu_memory_utilization: 0.35 # GPU memory utilization (0.0-1.0) dtype: bfloat16 # Data type for vLLM inference engine_path: null # Path to vLLM engine (null = auto-convert if needed) - pretrained_llm: ${s2s.llm_checkpoint_path} # Inherits from s2s.llm_checkpoint_path + pretrained_llm: ${s2s.model_path} vllm_tts_config: model_path: ${s2s.vllm_llm_config.model_path} # Inherits from s2s.model_path diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 7de38b990586..4ad77f8be83c 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -172,10 +172,6 @@ def __init__(self, model_cfg: DictConfig): if not self.model_path: raise ValueError("`model_cfg.model_path` must be provided.") - self.llm_checkpoint_path = model_cfg.get("llm_checkpoint_path") - if not self.llm_checkpoint_path: - raise ValueError("`model_cfg.llm_checkpoint_path` must be provided.") - self.decode_audio = bool(model_cfg.get("decode_audio", True)) # Number of past codec tokens kept in the sliding-window decode buffer. # Only used when use_codec_cache=False (the fallback path). 
When the @@ -331,168 +327,35 @@ def _samples_per_audio_output_frame(self): samples = int(float(rate) * FRAME_SIZE_SEC) return samples - def _load_and_merge_configs(self): - """Load and merge configurations from both nano and eartts checkpoints.""" - logging.info("Loading and merging configurations...") - - # Load nano's config (for LLM, perception) - nano_config_file = os.path.join(self.llm_checkpoint_path, "config.json") - logging.info(f" Loading nano config: {nano_config_file}") - with open(nano_config_file, 'r') as f: - import json - nano_cfg_dict = json.load(f) - nano_cfg = DictConfig(nano_cfg_dict) - - # Load eartts's config (for TTS) - eartts_config_file = os.path.join(self.model_path, "config.json") - logging.info(f" Loading eartts config: {eartts_config_file}") - with open(eartts_config_file, 'r') as f: - eartts_cfg_dict = json.load(f) - eartts_cfg = DictConfig(eartts_cfg_dict) - - # Start with nano's config as base - merged_cfg = nano_cfg - - # Override TTS-related parts with eartts's config - logging.info(" Merging: Using nano's config for LLM/perception, eartts's for TTS") - if 'model' in eartts_cfg and 'speech_generation' in eartts_cfg.model: - merged_cfg.model.speech_generation = eartts_cfg.model.speech_generation - logging.info(" TTS config from eartts") - - # Set speaker reference - if 'model' not in merged_cfg: - merged_cfg.model = {} - merged_cfg.model.inference_speaker_reference = self.speaker_reference - - # Ensure data section has correct sample rates - if 'data' not in merged_cfg: - merged_cfg.data = eartts_cfg.data - - logging.info(f" Final config:") - logging.info(f" - pretrained_llm: {merged_cfg.model.stt.model.pretrained_llm}") - logging.info(f" - perception.d_model: {merged_cfg.model.stt.model.perception.modality_adapter.d_model}") - logging.info(f" - speech_generation: {'present' if 'speech_generation' in merged_cfg.model else 'missing'}") - - return merged_cfg - def _initialize_model(self): - """Initialize the NemotronVoiceChat with 
hybrid loading.""" - from safetensors.torch import load_file - from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init - - logging.info("Initializing model with hybrid loading strategy...") - - - # Step 1: Load and merge configs - cfg = self._load_and_merge_configs() - - # Step 2: DO NOT set pretrained_s2s_model - we'll load weights manually - cfg.model.stt.model.pretrained_s2s_model = None - cfg.model.speech_generation.model.pretrained_model = None - - # Convert to dict for model initialization - cfg_dict = OmegaConf.to_container(cfg, resolve=True) - - # Step 3: Initialize model structure + """Initialize the NemotronVoiceChat model from an HF checkpoint.""" logging.info("Initializing model structure...") start_DuplexS2S_init = time.time() - self.model = NemotronVoiceChat(cfg_dict) + + self.model = NemotronVoiceChat.from_pretrained( + self.model_path, + local_files_only=True, + ) logging.info(f"Time taken to initialize NemotronVoiceChat: {time.time() - start_DuplexS2S_init} seconds") logging.info(" Model structure initialized") - # Step 4: Load nano's checkpoint (LLM + perception) - if self.llm_checkpoint_path is not None: - logging.info("Loading LLM + perception:") - logging.info(f" Path: {self.llm_checkpoint_path}") - - nano_state_dict = load_file(os.path.join(self.llm_checkpoint_path, "model.safetensors")) - - # Filter to non-TTS weights - tts_keys = ['tts_model.', 'speech_generation.'] - - # If using vLLM for LLM, also exclude LLM weights to save memory - # vLLM will load its own copy of the LLM - if self.use_vllm_llm: - llm_keys = ['stt_model.llm.'] - exclude_keys = tts_keys + llm_keys - logging.info(f" Using vLLM - excluding LLM weights from nano checkpoint") - else: - exclude_keys = tts_keys - - nano_filtered = {k: v for k, v in nano_state_dict.items() - if not any(k.startswith(prefix) for prefix in exclude_keys)} - - logging.info(f" Loading {len(nano_filtered)} parameters (excluded: {exclude_keys})...") - - # Free the full 
state dict immediately to save CPU memory - del nano_state_dict - gc.collect() - - nano_filtered = set_model_dict_for_partial_init(nano_filtered, self.model.state_dict()) - missing, unexpected = self.model.load_state_dict(nano_filtered, strict=False) - - # Free filtered dict - del nano_filtered + if self.use_vllm_eartts: + # Use object.__setattr__ to bypass PyTorch's module registration + # since VllmEARTTSModel is not a torch.nn.Module + del self.model.tts_model.tts_model gc.collect() - - missing_non_excluded = [k for k in missing if not any(k.startswith(prefix) for prefix in exclude_keys)] - unexpected_non_excluded = [k for k in unexpected if not any(k.startswith(prefix) for prefix in exclude_keys)] - - if missing_non_excluded: - logging.info(f" {len(missing_non_excluded)} keys missing (might be OK)") - if unexpected_non_excluded: - logging.info(f" {len(unexpected_non_excluded)} unexpected keys") - - # Step 5: Load eartts's checkpoint (TTS only) - if self.model_path is not None: - logging.info("Loading TTS checkpoint:") - logging.info(f" Path: {self.model_path}") - - eartts_state_dict = load_file(os.path.join(self.model_path, "model.safetensors")) - - # Filter to only TTS weights - tts_keys_filter = ['tts_model.'] - eartts_tts_only = {k: v for k, v in eartts_state_dict.items() - if any(k.startswith(prefix) for prefix in tts_keys_filter)} - - logging.info(f" Loading {len(eartts_tts_only)} TTS parameters...") - - start_tts_load_state_dict = time.time() - missing, unexpected = self.model.load_state_dict(eartts_tts_only, strict=False) - logging.info(f"Time taken to load TTS state dict: {time.time() - start_tts_load_state_dict} seconds") - - missing_tts = [k for k in missing if any(k.startswith(prefix) for prefix in tts_keys_filter)] - unexpected_tts = [k for k in unexpected if any(k.startswith(prefix) for prefix in tts_keys_filter)] - - if missing_tts: - logging.info(f" {len(missing_tts)} TTS keys missing") - for mk in missing_tts: - logging.info(f" missing: {mk}") - 
if unexpected_tts: - logging.info(f" {len(unexpected_tts)} unexpected TTS keys") - - if self.use_vllm_eartts: - # gonna convert and load vllm eartts engine - # Use object.__setattr__ to bypass PyTorch's module registration - # since VllmEARTTSModel is not a torch.nn.Module - del self.model.tts_model.tts_model - gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() - object.__setattr__( - self.model.tts_model, - 'tts_model', - create_model( - model=self.model_path, - engine_type="vllm_eartts", - vllm_config=self.vllm_tts_config) - ) - from nemo.collections.speechlm2.inference.vllm.vllm_patch import patched_infer_codes_one_step - self.model.tts_model.infer_codes_one_step = types.MethodType(patched_infer_codes_one_step, self.model.tts_model) - - logging.info(f" eartts checkpoint loaded (TTS only)") - - logging.info("\nHybrid loading completed!") + torch.cuda.empty_cache() + torch.cuda.synchronize() + object.__setattr__( + self.model.tts_model, + 'tts_model', + create_model( + model=self.model_path, + engine_type="vllm_eartts", + vllm_config=self.vllm_tts_config) + ) + from nemo.collections.speechlm2.inference.vllm.vllm_patch import patched_infer_codes_one_step + self.model.tts_model.infer_codes_one_step = types.MethodType(patched_infer_codes_one_step, self.model.tts_model) # If using vLLM for LLM, delete native LLM BEFORE moving to device to save memory if self.use_vllm_llm: @@ -579,7 +442,7 @@ def _initialize_model(self): if self.use_vllm_llm: logging.info("\nWrapping model with VllmLLMModel interface...") if self.vllm_llm_config is None: - raise ValueError("vllm_llm_config must be provided when engine_type contains'vllm_llm'") + raise ValueError("vllm_llm_config must be provided when engine_type contains 'vllm_llm'") # LLM already deleted above, just ensure cleanup gc.collect() diff --git a/nemo/collections/speechlm2/models/duplex_stt_model.py b/nemo/collections/speechlm2/models/duplex_stt_model.py index f6f0919e2749..65092bf92da0 100644 --- 
a/nemo/collections/speechlm2/models/duplex_stt_model.py +++ b/nemo/collections/speechlm2/models/duplex_stt_model.py @@ -14,6 +14,9 @@ import copy import os import re +import warnings +from pathlib import Path +from typing import Optional, Union import torch from lightning import LightningModule @@ -45,6 +48,7 @@ from nemo.collections.speechlm2.parts.pretrained import ( load_pretrained_hf, maybe_load_pretrained_models, + resolve_pretrained_config, set_model_dict_for_partial_init, setup_speech_encoder, ) @@ -79,16 +83,18 @@ def __init__(self, cfg: dict) -> None: self.predict_user_text = self.cfg.get("predict_user_text", False) + pretrained_weights, tokenizer_path = resolve_pretrained_config(self.cfg) + # Load LLM first llm = load_pretrained_hf( self.cfg.pretrained_llm, - pretrained_weights=self.cfg.pretrained_weights, + pretrained_weights=pretrained_weights, trust_remote_code=self.cfg.get("trust_remote_code", True), ).train() # Initialize tokenizer with optional special tokens from config self.tokenizer = AutoTokenizer( - self.cfg.pretrained_llm, + tokenizer_path, use_fast=True, bos_token=self.cfg.get("bos_token", None), eos_token=self.cfg.get("eos_token", None), @@ -119,7 +125,7 @@ def __init__(self, cfg: dict) -> None: maybe_install_lora(self) # Load the pretrained streaming ASR model - setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights) + setup_speech_encoder(self, pretrained_weights=pretrained_weights) maybe_load_pretrained_models(self) @@ -129,6 +135,24 @@ def __init__(self, cfg: dict) -> None: # Initialize streaming inference engine self.streaming_inference = DuplexSTTStreamingInference(self) + def save_pretrained( + self, + save_directory: Union[str, Path], + **kwargs, + ) -> Optional[str]: + """Save model and also export LLM artifacts (config + tokenizer) for offline inference.""" + result = super().save_pretrained(save_directory, **kwargs) + + try: + llm_dir = Path(save_directory) / "llm_artifacts" + llm_dir.mkdir(parents=True, 
exist_ok=True) + self.tokenizer.tokenizer.save_pretrained(str(llm_dir)) + logging.info(f"Saved LLM tokenizer to {llm_dir}") + except Exception as e: + warnings.warn(f"Failed to save LLM tokenizer: {e}. Inference will fall back to downloading from HF.") + + return result + @property def text_vocab_size(self): """Return the size of the text tokenizer.""" diff --git a/nemo/collections/speechlm2/models/nemotron_voicechat.py b/nemo/collections/speechlm2/models/nemotron_voicechat.py index 8c849a5c1548..7631a7d22d99 100644 --- a/nemo/collections/speechlm2/models/nemotron_voicechat.py +++ b/nemo/collections/speechlm2/models/nemotron_voicechat.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import gc +import json import os +import warnings from pathlib import Path from typing import Optional, Union @@ -131,6 +133,39 @@ def __init__(self, cfg: dict) -> None: self._use_fsdp = False self._use_tp = False + def save_pretrained( + self, + save_directory: Union[str, Path], + **kwargs, + ) -> Optional[str]: + """Save model and export LLM artifacts (tokenizer + perception config) for offline inference.""" + result = super().save_pretrained(save_directory, **kwargs) + + # Save tokenizer for offline loading + try: + llm_dir = Path(save_directory) / "llm_artifacts" + llm_dir.mkdir(parents=True, exist_ok=True) + self.stt_model.tokenizer.tokenizer.save_pretrained(str(llm_dir)) + logging.info(f"Saved LLM tokenizer to {llm_dir}") + except Exception as e: + warnings.warn(f"Failed to save LLM tokenizer: {e}. Inference will fall back to downloading from HF.") + + # Save full perception config at the top level of config.json so that + # resolve_pretrained_config() can skip pretrained ASR/LLM downloads. 
+ try: + config_path = Path(save_directory) / "config.json" + if config_path.exists(): + with open(config_path) as f: + config = json.load(f) + config["perception"] = OmegaConf.to_container(self.stt_model.cfg.perception, resolve=True) + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + logging.info(f"Saved perception config to {config_path}") + except Exception as e: + warnings.warn(f"Failed to save perception config: {e}") + + return result + def init_from_model_from_ckpt(self, checkpoint_path): if checkpoint_path is not None: checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu')['state_dict'] @@ -183,6 +218,24 @@ def _from_pretrained( # Skip loading child module weights natively model_kwargs['cfg']['pretrained_weights'] = False + # Propagate pretrained_weights=False into nested configs so child + # modules skip downloading pretrained ASR, LLM, and codec models. + cfg = model_kwargs['cfg'] + try: + stt_model_cfg = cfg['model']['stt']['model'] + stt_model_cfg['pretrained_weights'] = False + if 'perception' in cfg: + stt_model_cfg['perception'] = cfg['perception'] + logging.info("Injected saved perception config into STT model config") + except (KeyError, TypeError): + logging.warning("Could not propagate pretrained_weights=False into nested STT config") + try: + tts_model_cfg = cfg['model']['speech_generation']['model'] + tts_model_cfg['pretrained_model'] = None + tts_model_cfg['pretrained_codec_model'] = None + except (KeyError, TypeError): + pass + # Instantiate the empty model skeleton model = cls(model_kwargs['cfg']) diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index bc158ce24c42..7aa61398218a 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import json from contextlib import contextmanager from pathlib import Path from typing import Dict @@ -65,6 +66,41 @@ def load_pretrained_hf( return AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=trust_remote_code) +def resolve_pretrained_config(cfg): + """Resolve pretrained config when pretrained_s2s_model points to an HF checkpoint. + + When the HF checkpoint contains a config.json with perception config, this function: + - Sets pretrained_weights to False (weights will be loaded from pretrained_s2s_model) + - Loads the perception config from the HF checkpoint into cfg + - Resolves the tokenizer path to local llm_artifacts if available + + Args: + cfg: DictConfig with model configuration (modified in-place for perception config). + + Returns: + Tuple of (pretrained_weights, tokenizer_path). + """ + tokenizer_path = cfg.pretrained_llm + pretrained_weights = cfg.pretrained_weights + pretrained_s2s = cfg.get("pretrained_s2s_model", None) + if pretrained_s2s is not None: + hf_config_path = Path(pretrained_s2s) / "config.json" + if hf_config_path.exists(): + with open(hf_config_path) as f: + hf_config = json.load(f) + if "perception" in hf_config: + pretrained_weights = False + with open_dict(cfg): + cfg.perception = hf_config["perception"] + logging.info(f"Loaded perception config from {hf_config_path}, skipping pretrained downloads") + # Use local tokenizer if available + llm_artifacts_dir = Path(pretrained_s2s) / "llm_artifacts" + if llm_artifacts_dir.is_dir(): + tokenizer_path = str(llm_artifacts_dir) + logging.info(f"Using local tokenizer from {llm_artifacts_dir}") + return pretrained_weights, tokenizer_path + + @contextmanager def move_embedding(model): """Temporarily restores the embedding layer into HF LLM. 
Supports LoRA models.""" From 40475f0dfabb1f2a23cc6713b0216cb093b742cb Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Thu, 19 Mar 2026 19:46:51 +0000 Subject: [PATCH 09/40] quickfix for parity harness regarding speaker name / reference in tts Signed-off-by: Elena Rastorgueva --- examples/speechlm2/nemotron_voicechat_parity_harness.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/speechlm2/nemotron_voicechat_parity_harness.py b/examples/speechlm2/nemotron_voicechat_parity_harness.py index 582a3be8b722..db7083b3fb53 100644 --- a/examples/speechlm2/nemotron_voicechat_parity_harness.py +++ b/examples/speechlm2/nemotron_voicechat_parity_harness.py @@ -588,6 +588,11 @@ def run_parity_harness(args) -> dict[str, Any]: prompt_tokens = torch.tensor(prompt_token_ids, device=wrapper.device, dtype=torch.long).unsqueeze(0) prompt_token_lens = torch.tensor([len(prompt_token_ids)], device=wrapper.device, dtype=torch.long) + if wrapper.speaker_name is not None: + OmegaConf.update(wrapper.model.cfg, "inference_speaker_name", wrapper.speaker_name, force_add=True) + elif wrapper.speaker_reference: + OmegaConf.update(wrapper.model.cfg, "inference_speaker_reference", wrapper.speaker_reference, force_add=True) + offline = wrapper.model.offline_inference( input_signal=audio, input_signal_lens=audio_lens, From 1449183e361730b90be7575754b4483cabecce0c Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Thu, 19 Mar 2026 23:01:43 +0000 Subject: [PATCH 10/40] speed up model loading: use meta device, dont get codec silence tokens which will be ignored anyway Signed-off-by: Elena Rastorgueva --- .../speechlm2/models/duplex_ear_tts.py | 9 +++- .../speechlm2/models/duplex_stt_model.py | 1 + .../speechlm2/models/nemotron_voicechat.py | 43 +++++++++++++++++-- .../collections/speechlm2/parts/pretrained.py | 11 ++++- 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py 
b/nemo/collections/speechlm2/models/duplex_ear_tts.py index e1e982f71a63..4a9aa1f49830 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -111,8 +111,13 @@ def __init__(self, cfg: dict) -> None: # compute samples per frame self.source_samples_per_frame = int(self.source_sample_rate * cfg.data.frame_length) - # get codec silence tokens - codec_silence_tokens = self.get_codec_silence_frame() + # get codec silence tokens (skip when codec has random weights — the + # buffer will be overwritten from the checkpoint) + if self.cfg.get('pretrained_codec_model', None) is not None: + codec_silence_tokens = self.get_codec_silence_frame() + else: + num_q = self.tts_model.config.num_quantizers + codec_silence_tokens = torch.zeros(num_q, dtype=torch.long) self.register_buffer("codec_silence_tokens", codec_silence_tokens) # cached for quicker audio decoding diff --git a/nemo/collections/speechlm2/models/duplex_stt_model.py b/nemo/collections/speechlm2/models/duplex_stt_model.py index 65092bf92da0..3e9607d15456 100644 --- a/nemo/collections/speechlm2/models/duplex_stt_model.py +++ b/nemo/collections/speechlm2/models/duplex_stt_model.py @@ -90,6 +90,7 @@ def __init__(self, cfg: dict) -> None: self.cfg.pretrained_llm, pretrained_weights=pretrained_weights, trust_remote_code=self.cfg.get("trust_remote_code", True), + use_meta_device=self.cfg.get("use_meta_device", False), ).train() # Initialize tokenizer with optional special tokens from config diff --git a/nemo/collections/speechlm2/models/nemotron_voicechat.py b/nemo/collections/speechlm2/models/nemotron_voicechat.py index 7631a7d22d99..84a3af2e96df 100644 --- a/nemo/collections/speechlm2/models/nemotron_voicechat.py +++ b/nemo/collections/speechlm2/models/nemotron_voicechat.py @@ -122,8 +122,10 @@ def __init__(self, cfg: dict) -> None: # Load Duplex TTS model self.tts_model = DuplexEARTTS(OmegaConf.to_container(self.cfg.speech_generation, resolve=True)) - # 
reset silence tokens to avoid inference issues - self.tts_model.codec_silence_tokens = self.tts_model.get_codec_silence_frame() + # reset silence tokens to avoid inference issues (skip when codec + # has random weights — the buffer will be loaded from checkpoint) + if self.tts_model.cfg.get('pretrained_codec_model', None) is not None: + self.tts_model.codec_silence_tokens = self.tts_model.get_codec_silence_frame() self.target_fps = self.tts_model.target_fps # compute source fps self.source_fps = self.source_sample_rate / ( @@ -224,6 +226,7 @@ def _from_pretrained( try: stt_model_cfg = cfg['model']['stt']['model'] stt_model_cfg['pretrained_weights'] = False + stt_model_cfg['use_meta_device'] = True if 'perception' in cfg: stt_model_cfg['perception'] = cfg['perception'] logging.info("Injected saved perception config into STT model config") @@ -257,12 +260,32 @@ def _from_pretrained( if resolved_weights_file is None: raise RuntimeError(f"Missing model.safetensors file for {model_id=}") - # Stream the weights safely using your custom memory-efficient loader! + # Stream the weights from safetensors ckpt_dir = os.path.dirname(resolved_weights_file) model.init_from_safetensors_ckpt(ckpt_dir) return model + def _replace_tensor(self, full_key, value): + """Replace a parameter or buffer on its parent module. + + Meta-device tensors (from torch.device('meta')) have no storage, so the + usual ``target.data.copy_(tensor)`` raises an error. Instead we walk + the module tree to find the parent and swap the entry directly in + ``module._parameters`` or ``module._buffers``. 
+ """ + parts = full_key.split(".") + module = self + for part in parts[:-1]: + module = getattr(module, part) + name = parts[-1] + if name in module._parameters: + module._parameters[name] = torch.nn.Parameter( + value, requires_grad=module._parameters[name].requires_grad + ) + elif name in module._buffers: + module._buffers[name] = value + def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): """ Memory-efficient streaming safetensors loader with dynamic @@ -304,6 +327,8 @@ def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): if target.shape != tensor.shape: logging.warning(f"Shape mismatch for {key}: " f"model {target.shape} vs ckpt {tensor.shape}") + elif target.is_meta: + self._replace_tensor(prefix + key, tensor) else: target.data.copy_(tensor) @@ -316,6 +341,8 @@ def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): logging.warning( f"Buffer shape mismatch for {key}: " f"model {target.shape} vs ckpt {tensor.shape}" ) + elif target.is_meta: + self._replace_tensor(prefix + key, tensor) else: target.data.copy_(tensor) @@ -334,6 +361,16 @@ def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): if missing_keys: logging.warning(f"{len(missing_keys)} keys in checkpoint not found in model") + # Fail if any parameters/buffers are still on meta device — this means + # the checkpoint is missing weights the model requires. 
+ meta_remaining = [n for n, p in self.named_parameters() if p.is_meta] + meta_remaining += [n for n, b in self.named_buffers() if b.is_meta] + if meta_remaining: + raise RuntimeError( + f"{len(meta_remaining)} tensors still on meta device after checkpoint load " + f"(missing from checkpoint): {meta_remaining[:20]}" + ) + gc.collect() def training_step(self, batch: dict, batch_idx: int): diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 7aa61398218a..626ac25bebdd 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -43,7 +43,11 @@ def load_pretrained_nemo(cls, model_path_or_name: str): def load_pretrained_hf( - model_path_or_name: str, pretrained_weights: bool = True, dtype=torch.float32, trust_remote_code: bool = False + model_path_or_name: str, + pretrained_weights: bool = True, + dtype=torch.float32, + trust_remote_code: bool = False, + use_meta_device: bool = False, ): """ Load pretrained HuggingFace AutoModelForCausalLM. @@ -56,6 +60,8 @@ def load_pretrained_hf( pretrained_weights: Whether to load pretrained weights (True) or random init (False) dtype: Data type for the model trust_remote_code: Whether to trust remote code when loading model (needed for some models like Nemotron) + use_meta_device: If True, create the model on the meta device (no memory allocation). + The caller must handle materializing meta tensors from a checkpoint. 
""" if pretrained_weights: return AutoModelForCausalLM.from_pretrained( @@ -63,6 +69,9 @@ def load_pretrained_hf( ) else: config = AutoConfig.from_pretrained(model_path_or_name, trust_remote_code=trust_remote_code) + if use_meta_device: + with torch.device('meta'): + return AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=trust_remote_code) return AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=trust_remote_code) From ef06833a68d1be163595eec098195c25cfde3b13 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 20 Mar 2026 17:42:31 +0000 Subject: [PATCH 11/40] normalize indentation to 4-space Signed-off-by: Elena Rastorgueva --- .../pipelines/streaming_s2s_pipeline.py | 1527 ++++++++--------- .../streaming/state/s2s_context_manager.py | 374 ++-- .../inference/streaming/state/s2s_state.py | 246 +-- 3 files changed, 1073 insertions(+), 1074 deletions(-) diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 1dba0d838c64..b7cbb06a34cb 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import os import time @@ -39,767 +40,765 @@ class StreamingS2SPipeline(S2SPipelineInterface): - """ - Streaming S2S pipeline. 
- """ - - def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper): - # ------------------------------------------------------------------ - # Model & device - # ------------------------------------------------------------------ - self.s2s_model = s2s_model - self.device = self.s2s_model.device - self.collect_debug = False - - # ------------------------------------------------------------------ - # Streaming configuration - # ------------------------------------------------------------------ - self.streaming_cfg = cfg.get("streaming", {}) - self.input_sample_rate = getattr(self.streaming_cfg, "input_sample_rate", 16000) - self.output_sample_rate = getattr(self.streaming_cfg, "output_sample_rate", 22050) - self.batch_size = getattr(self.streaming_cfg, "batch_size", 1) - self.max_len = getattr(self.streaming_cfg, "max_len", 200) - if self.batch_size != 1: - raise ValueError( - "StreamingS2SPipeline currently supports only single-stream inference " - "(streaming.batch_size must be 1)." - ) - - - # ------------------------------------------------------------------ - # Chunk & buffer sizes - # Terminology: "frame" = 80ms audio unit, "chunk" = 1 or more frames - # A chunk is the amount of audio that is processed per inference step. - # ------------------------------------------------------------------ - self.chunk_size_in_secs = getattr(self.streaming_cfg, "chunk_size_in_secs", 0.08) - # Check if self.chunk_size_in_secs is a multiple of 0.08. - # Because of quirks of floating point arithmetic, the remainder could be either ~0 or ~0.08, - # so we check for both cases. 
- remainder = self.chunk_size_in_secs % 0.08 - if not (math.isclose(remainder, 0, abs_tol=1e-9) or math.isclose(remainder, 0.08, abs_tol=1e-9)): - raise ValueError(f"Chunk size must be a multiple of 0.08s, but got {self.chunk_size_in_secs}") - - self.num_frames_per_chunk = int(self.chunk_size_in_secs / 0.08) - - # Buffer size determines how much audio is passed to the perception encoder - # Default: 5.68 seconds (71 * 0.08). This is the minimum valid buffer size without the perception cache. - # i.e. att_context_size[0] + att_context_size[1] + 1 frames = 70+0+1 = 71 frames = 5.68 seconds - self.buffer_size_in_secs = getattr(self.streaming_cfg, "buffer_size_in_secs", 71 * 0.08) - - self.att_context_size = getattr(self.streaming_cfg, "att_context_size", [70,0]) - - # ------------------------------------------------------------------ - # bufferer – reused from ASR utilities - # ------------------------------------------------------------------ - self.bufferer = BatchedAudioBufferer( - sample_rate=self.input_sample_rate, - buffer_size_in_secs=self.buffer_size_in_secs, - ) - - # ------------------------------------------------------------------ - # System prompt configuration - # ------------------------------------------------------------------ - s2s_cfg = cfg.get("s2s", {}) - self.system_prompt: Optional[str] = getattr(s2s_cfg, "system_prompt", None) - if self.system_prompt: - logging.info(f"System prompt configured: {self.system_prompt[:100]}{'...' if len(self.system_prompt) > 100 else ''}") - - # Context manager - self.context_manager = S2SContextManager( - s2s_model=self.s2s_model, - num_slots=self.batch_size, - max_len=self.max_len, - ) - - # Output directory for generated files - self.output_dir = getattr(cfg, "output_dir", "./generated") - - # Parse and validate request type early, with a safe default - req_type_cfg = getattr(self.streaming_cfg, "request_type", "frame") - - # Parse and validate the request type; only 'frame' is supported for s2s. 
- self.request_type = RequestType.from_str(req_type_cfg) - if self.request_type is not RequestType.FRAME: - raise ValueError(f"Request type {self.request_type} is not supported for s2s.") - - self._stream_has_prompt: bool = False - - # ------------------------------------------------------------------ - # Input audio padding (silence appended after real audio) - # ------------------------------------------------------------------ - self.pad_audio_to_sec: float | None = cfg.get("pad_audio_to_sec", None) - self.pad_silence_ratio: float | None = cfg.get("pad_silence_ratio", None) - self.pad_audio_by_sec: float | None = cfg.get("pad_audio_by_sec", None) - if sum(x is not None for x in [self.pad_audio_to_sec, self.pad_silence_ratio, self.pad_audio_by_sec]) > 1: - raise ValueError("Set at most one of: pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec") - - super().__init__() - - # -------------------------------- ---------------------------------- - # State helpers - # ------------------------------------------------------------------ - def create_state(self) -> S2SStreamingState: - """Create new empty state.""" - num_audio_codebooks = getattr(self.s2s_model.model, "_num_codebooks", 1) - dtype = getattr(self.s2s_model, "compute_dtype", torch.float32) - state = S2SStreamingState( - device=self.device, - dtype=dtype, - max_len=self.max_len, - num_audio_codebooks=num_audio_codebooks, - output_sample_rate=self.output_sample_rate, - ) - return state - - - # ------------------------------------------------------------------ - # Output helpers - # ------------------------------------------------------------------ - def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[bool], text_pieces: List[str], asr_text_pieces: List[str] = None): - """Append generated audio waveform and text to per-stream state.""" - for idx, frame in enumerate(frames): - if not ready_feats[idx]: - continue - state = self.get_or_create_state(frame.stream_id) - # audio_wave is 
[B, S]; take sample idx - sample_audio = audio_wave[idx:idx+1, ...] - # Determine text piece for this index - piece = None - if text_pieces and idx < len(text_pieces): - candidate = text_pieces[idx] - if isinstance(candidate, str) and candidate: - piece = candidate - - # Determine ASR text piece - asr_piece = None - if asr_text_pieces and idx < len(asr_text_pieces): - candidate = asr_text_pieces[idx] - if isinstance(candidate, str) and candidate: - asr_piece = candidate - - state.update_state(sample_audio, output_text=piece, output_asr_text=asr_piece) - - - def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_paddings: List[int], ready_feats: List[bool]): - """Generate speech for chunks in *batch* using a shared ContextManager.""" - if len(frames) == 0: - return - - stream_ids = [f.stream_id for f in frames] - eos_flags = [f.is_last for f in frames] - bos_flags = [f.is_first for f in frames] - - logging.debug(f"stream_ids={stream_ids} bos_flags={bos_flags} eos_flags={eos_flags}") - - if len(frames) != 1: - raise NotImplementedError("NemotronVoicechatInferenceWrapper currently supports batch_size == 1") - - # If this is the first audio frame and prefill was already done via a - # zero-length prefill frame, skip context init -- it's already set up. - # Otherwise (no system prompt), create a fresh context_manager. 
- has_prompt = False - if bos_flags[0]: - if self._stream_has_prompt: - logging.debug(f"Prefill already done for stream {stream_ids[0]}, skipping context init") - else: - logging.debug(f"No prefill for stream {stream_ids[0]}, creating fresh context_manager") - self.context_manager = S2SContextManager( - s2s_model=self.s2s_model, - num_slots=self.batch_size, - max_len=self.max_len, - ) - - has_prompt = self._stream_has_prompt - self._stream_has_prompt = False - - request_id = self._request_id_for_stream(stream_ids[0]) - - context, _ = self.context_manager.get_context(stream_ids) - - audio_buffer = buffers[0] - if audio_buffer.dim() == 1: - audio_buffer = audio_buffer.unsqueeze(0) - audio_buffer = audio_buffer.to(self.s2s_model.device, dtype=self.s2s_model.dtype) - - # Trim the buffer to exclude left padding (zeros at the beginning before buffer is filled) - left_pad = left_paddings[0] - if left_pad > 0: - audio_buffer = audio_buffer[:, left_pad:] - - result = self.s2s_model.infer_one_step( - audio_input=audio_buffer, - num_frames_per_chunk=self.num_frames_per_chunk, - frame_idx=context.frame_idx, - gen_text=context.gen_text, - audio_toks_buffer=context.audio_toks_buffer, - input_embeds_history=context.input_embeds_history, - dynamic_cache=context.dynamic_cache, - past_key_values=context.past_key_values, - code=context.code, - subword_mask=context.subword_mask, - gen_asr_text=context.gen_asr_text, - gen_function_text=context.gen_function_text, - request_id=request_id, - perception_cache=context.perception_cache, - has_prompt=has_prompt, - codec_cache=context.codec_cache, - cache_position_offset=context.cache_position_offset, - return_debug=self.collect_debug, - ) - - if self.collect_debug and "debug" in result: - state = self.get_or_create_state(stream_ids[0]) - if not hasattr(state, "debug_steps"): - state.debug_steps = [] - state.debug_steps.append(result["debug"]) - - # Persist updated cache & clean finished streams - 
self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk) - - # Save full token tensors to state before the context is destroyed, - # so we can run tokens_to_str / tokens_to_str_raw post-hoc. - for stream_id, eos_flag in zip(stream_ids, eos_flags): - if eos_flag: - ctx = self.context_manager.slot_contexts[ - self.context_manager.streamidx2slotidx[stream_id] - ] - if ctx is not None: - state = self.get_or_create_state(stream_id) - state.save_token_tensors(ctx.gen_text, ctx.gen_asr_text, ctx.frame_idx, - gen_function_text=ctx.gen_function_text) - - self.context_manager.reset_slots(stream_ids, eos_flags) - - # Explicitly clean up bufferer and state for finished streams - for stream_id, eos_flag in zip(stream_ids, eos_flags): - if eos_flag: - logging.debug(f"Ending stream {stream_id} - cleaning up bufferer and context") - self.bufferer.rm_bufferer(stream_id) - self._abort_stream_request(stream_id) - # Note: We keep the state in _state_pool until finalization to save audio - # It will be cleaned up in close_session() - - # Log audio and attach text to state - self.log_output(frames, result["decoded_audio_new"], ready_feats, result["predicted_text_strs"], result.get("asr_predicted_text_strs")) - - def prefill_for_new_stream(self, stream_id: int, system_prompt: str | None = None) -> bool: - """Prepare the pipeline for a new stream by resetting context and prefilling the system prompt. - - This is the public API for prefill-only calls (e.g. from the Triton backend) - that need to initialize TTS speaker embeddings and/or inject a system prompt - into the LLM KV cache *without* processing any audio. - - Args: - stream_id: Unique identifier for the new stream. - system_prompt: System prompt text. If *None*, falls back to - the YAML-configured ``self.system_prompt``. - - Returns: - True if a system prompt was prefilled, False otherwise. 
- """ - t0 = time.time() - if system_prompt is None: - system_prompt = self.system_prompt - - self.context_manager = S2SContextManager( - s2s_model=self.s2s_model, - num_slots=self.batch_size, - max_len=self.max_len, - ) - t_ctx = time.time() - - with torch.no_grad(), torch.inference_mode(): - self._prefill_system_prompt(stream_id, system_prompt) - t_prefill = time.time() - - self._stream_has_prompt = bool(system_prompt) - logging.debug(f"prefill_for_new_stream: context_manager={1000*(t_ctx-t0):.1f}ms, " - f"_prefill_system_prompt={1000*(t_prefill-t_ctx):.1f}ms, " - f"total={1000*(t_prefill-t0):.1f}ms, has_prompt={self._stream_has_prompt}") - return self._stream_has_prompt - - _WARMUP_FALLBACK_PROMPT = "Mock system prompt for warmup." - - def warmup(self, system_prompt: str | None = None) -> None: - """Run a throwaway prefill cycle to warm up the inference engine. - - The very first prefill incurs one-time overhead (e.g. CUDA graph - compilation, memory pool allocation, DynamicCache initialization). - Calling this once during startup moves that cost out of the - critical path so the first real client request is fast. - - The method performs a full prefill (TTS speaker embedding + LLM - system prompt), then aborts the request and resets all pipeline - state so the next real stream starts cleanly. - - Args: - system_prompt: Prompt text to use for warmup. Falls back to - the YAML-configured ``self.system_prompt``, then to a - short fallback string so the LLM prefill path is always - exercised. 
- """ - prompt = system_prompt if system_prompt is not None else self.system_prompt - if not prompt: - prompt = self._WARMUP_FALLBACK_PROMPT - logging.info(f"No system prompt configured — using fallback prompt for warmup: \"{prompt}\"") - - warmup_stream_id = -1 - - logging.info("Running pipeline warmup prefill...") - t0 = time.time() - - self.prefill_for_new_stream(warmup_stream_id, prompt) - - # Tear down the warmup request so the engine is clean for real traffic - self._abort_stream_request(warmup_stream_id) - self.context_manager.reset() - self._stream_has_prompt = False - - logging.info(f"Pipeline warmup complete in {time.time() - t0:.3f}s") - - def generate_step(self, frames: List[Frame]): - """Main streaming API similar to *transcribe_step* in recognizers. - - If the batch contains a single zero-length first frame with a system - prompt in ``options``, this is treated as a **prefill-only** request: - the context manager and system prompt are initialized but no audio - inference runs. This is the unified protocol used by both the CLI - (``run()``) and the Triton backend. 
- """ - # Detect prefill-only frame: is_first + zero-length audio - if (len(frames) == 1 - and frames[0].is_first - and frames[0].samples.numel() == 0): - opts = frames[0].options - prompt = None - if opts is not None and hasattr(opts, "system_prompt"): - prompt = opts.system_prompt - self.prefill_for_new_stream(frames[0].stream_id, prompt) - return - - buffers, left_paddings = self.bufferer.update(frames) - ready_feats = [True] * len(frames) - - with torch.no_grad(), torch.inference_mode(): - self.inner_generate_step(frames, buffers, left_paddings, ready_feats) - - # ------------------------------------------------------------------ - # Finalization helpers - # ------------------------------------------------------------------ - def _finalize_and_save_finished_streams( - self, - frames: List[Frame], - audio_filepaths: List[str], - saved_paths_by_stream: dict[int, str], - ) -> None: - """Finalize any streams that ended in this batch and save their audio.""" - for frame in frames: - if frame.is_last: - stream_id = frame.stream_id - state = self.get_or_create_state(stream_id) - - # Flush remaining buffered samples and assemble waveform - if hasattr(state, "finalize"): - state.finalize() - # Concatenate emitted chunks and squeeze (B=1,C=1) to mono waveform - generated_audio = torch.cat(state.speech_frames, dim=-1) - # Ensure 1D mono waveform and float32 dtype for soundfile - if generated_audio.dim() == 3 and generated_audio.size(0) == 1 and generated_audio.size(1) == 1: - generated_audio = generated_audio.squeeze(0).squeeze(0) - elif generated_audio.dim() == 2 and generated_audio.size(0) == 1: - generated_audio = generated_audio.squeeze(0) - generated_audio = generated_audio.to(torch.float32) - - # Build output paths in subdirectories under output_dir - in_path = audio_filepaths[stream_id] - base = os.path.splitext(os.path.basename(in_path))[0] - - wav_dir = os.path.join(self.output_dir, "wav") - stereo_dir = os.path.join(self.output_dir, "stereo") - txt_dir = 
os.path.join(self.output_dir, "txt") - os.makedirs(wav_dir, exist_ok=True) - os.makedirs(stereo_dir, exist_ok=True) - os.makedirs(txt_dir, exist_ok=True) - - out_path = os.path.join(wav_dir, f"{base}.wav") - - # Write audio to disk - if generated_audio.numel() > 0: - sf.write(out_path, generated_audio.detach().cpu().numpy(), self.output_sample_rate) - - # Also save a stereo file with input (ch0) and output (ch1) - # Load input with librosa (handles mono conversion and resampling) - input_np, _ = librosa.load(in_path, sr=self.output_sample_rate, mono=True) - input_audio = torch.from_numpy(input_np).to(torch.float32) - gen_cpu = generated_audio.detach().cpu().to(input_audio.dtype) - - # Prepend silence to output channel to account for - # the one-chunk processing delay: the server can't - # produce output until it has received a full input chunk. - delay_samples = int(self.chunk_size_in_secs * self.output_sample_rate) - silence = torch.zeros(delay_samples, dtype=gen_cpu.dtype) - gen_cpu = torch.cat([silence, gen_cpu], dim=-1) - - gen_len = int(gen_cpu.shape[-1]) - in_len = int(input_audio.shape[-1]) - max_len = max(gen_len, in_len) - if in_len < max_len: - input_audio = torch.cat([input_audio, torch.zeros(max_len - in_len, dtype=input_audio.dtype)], dim=-1) - if gen_len < max_len: - gen_cpu = torch.cat([gen_cpu, torch.zeros(max_len - gen_len, dtype=gen_cpu.dtype)], dim=-1) - stereo = torch.stack([input_audio, gen_cpu], dim=0).transpose(0, 1) - stereo_path = os.path.join(stereo_dir, f"{base}_input_output.wav") - sf.write(stereo_path, stereo.detach().cpu().numpy(), self.output_sample_rate) - - # Save accumulated text - text_out = state.get_output_text() if hasattr(state, "get_output_text") else "" - if isinstance(text_out, str): - try: - with open(os.path.join(txt_dir, f"{base}.txt"), "w", encoding="utf-8") as f: - f.write(text_out) - except Exception: - pass - - # Save accumulated ASR text - asr_text_out = state.get_output_asr_text() if hasattr(state, 
"get_output_asr_text") else "" - if isinstance(asr_text_out, str) and asr_text_out: - try: - with open(os.path.join(txt_dir, f"{base}_asr.txt"), "w", encoding="utf-8") as f: - f.write(asr_text_out) - except Exception: - pass - - saved_paths_by_stream[stream_id] = out_path - - # Keep state until outputs are assembled; will be cleared on close_session - - - # ------------------------------------------------------------------ - # Session helpers (extend S2SPipelineInterface) - # ------------------------------------------------------------------ - - def reset_session(self) -> None: - """Reset feature buffer and ContextManager together.""" - for stream_id in list(self.context_manager.streamidx2slotidx.keys()): - self._abort_stream_request(stream_id) - self.bufferer.reset() - self.context_manager.reset() - - super().reset_session() # clears state pool - - # ------------------------------------------------------------------ - # Orchestrator – mirrors recognizers' *run* method - # ------------------------------------------------------------------ - def run( - self, - audio_filepaths: List[str], - options: List[S2SRequestOptions] | None = None, - progress_bar: Optional[ProgressBar] = None, - ) -> PipelineOutput: - """Stream all *audio_filepaths* through the pipeline and save outputs. - - Saves one generated ``.wav`` per input under ``self.output_dir`` and - returns their paths in ``PipelineOutput.texts``. 
- """ - if progress_bar and not isinstance(progress_bar, ProgressBar): - raise ValueError("progress_bar must be an instance of ProgressBar.") - - if options is None: - options = [S2SRequestOptions(system_prompt=self.system_prompt) for _ in audio_filepaths] - - streamer = ContinuousBatchedFrameStreamer( - n_frames_per_stream=1, - frame_size_in_secs=self.chunk_size_in_secs, - sample_rate=self.input_sample_rate, - batch_size=self.batch_size, - pad_last_frame=True, - ) - - streamer.set_audio_filepaths(audio_filepaths, options) - streamer.set_progress_bar(progress_bar) - - # Ensure output directory exists - os.makedirs(self.output_dir, exist_ok=True) - - # Track saved paths by stream id to preserve input order - saved_paths_by_stream: dict[int, str] = {} - chunk_samples = int(self.chunk_size_in_secs * self.input_sample_rate) - - self.open_session() - for frames in streamer: - # Unified prefill protocol: if the first frame of a new stream - # carries a system prompt, emit a zero-length prefill frame first. - if (len(frames) == 1 - and frames[0].is_first - and frames[0].options is not None - and hasattr(frames[0].options, "system_prompt") - and frames[0].options.system_prompt): - prefill_frame = Frame( - samples=torch.empty(0), - stream_id=frames[0].stream_id, - is_first=True, - is_last=False, - options=frames[0].options, - ) - self.generate_step([prefill_frame]) - - # If padding is configured, intercept last frames so the - # bufferer/context stay alive for the silence-padding phase. - # Padding is generated immediately (same iteration) to avoid - # the next stream's setup destroying this stream's context. 
- pad_targets: dict[int, float] = {} - if self.pad_audio_to_sec or self.pad_silence_ratio or self.pad_audio_by_sec: - processed_frames = [] - for frame in frames: - if frame.is_last: - elapsed = streamer.elapsed_durations[frame.stream_id] - remaining = self._padding_remaining_secs(elapsed) - if remaining > 0: - processed_frames.append(Frame( - samples=frame.samples, - stream_id=frame.stream_id, - is_first=frame.is_first, - is_last=False, - length=frame.length, - options=frame.options, - )) - pad_targets[frame.stream_id] = remaining - continue - processed_frames.append(frame) - frames = processed_frames - - self.generate_step(frames) - self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream) - - # Generate silence padding before the next iteration adds a new stream - for stream_id, remaining_secs in pad_targets.items(): - num_pad_frames = max(1, round(remaining_secs / self.chunk_size_in_secs)) - for i in range(num_pad_frames): - is_last = (i == num_pad_frames - 1) - silence_frame = Frame( - samples=torch.zeros(chunk_samples), - stream_id=stream_id, - is_first=False, - is_last=is_last, - length=chunk_samples, - ) - self.generate_step([silence_frame]) - if is_last: - self._finalize_and_save_finished_streams( - [silence_frame], audio_filepaths, saved_paths_by_stream - ) - # Build outputs before closing the session - texts = [] - words = [] - asr_texts = [] - texts_with_timestamps = [] - asr_texts_with_timestamps = [] - raw_texts = [] - raw_asr_texts = [] - token_texts = [] - token_asr_texts = [] - token_function_texts = [] - token_lengths = [] - audio_paths = [] - - tokenizer = self.s2s_model.tokenizer - pad_id = self.s2s_model.model.stt_model.text_pad_id - - for idx in range(len(audio_filepaths)): - state = self.get_or_create_state(idx) - text_value = state.get_output_text() if hasattr(state, "get_output_text") else "" - if not text_value: - text_value = saved_paths_by_stream.get(idx, "") - texts.append(text_value) - 
audio_paths.append(saved_paths_by_stream.get(idx)) - per_stream_words = state.get_output_words() if hasattr(state, "get_output_words") else [] - words.append(per_stream_words) - asr_text_value = state.get_output_asr_text() if hasattr(state, "get_output_asr_text") else "" - asr_texts.append(asr_text_value) - - token_data = state.get_token_tensors() - if token_data is not None: - gen_text, gen_asr_text, total_frames, gen_function_text = token_data - token_texts.append(gen_text) - token_asr_texts.append(gen_asr_text) - token_function_texts.append(gen_function_text) - token_lengths.append(total_frames) - lengths = torch.tensor([total_frames], dtype=torch.long) - texts_with_timestamps.append( - tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0] - ) - asr_texts_with_timestamps.append( - tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0] - ) - raw_texts.append( - tokens_to_str_raw(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] - ) - raw_asr_texts.append( - tokens_to_str_raw(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] - ) - if gen_function_text is not None: - fc_text = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=False)[0] - fc_text_raw = tokens_to_str_raw(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] - logging.info(f"Function calling channel: {fc_text}") - else: - token_texts.append(None) - token_asr_texts.append(None) - token_function_texts.append(None) - token_lengths.append(None) - texts_with_timestamps.append("") - asr_texts_with_timestamps.append("") - raw_texts.append("") - raw_asr_texts.append("") - - debug_data = [] - if self.collect_debug: - for idx in range(len(audio_filepaths)): - state = self.get_or_create_state(idx) - debug_data.append(getattr(state, "debug_steps", [])) - - self.close_session() - - return PipelineOutput( - texts=texts, - words=words, - 
asr_texts=asr_texts, - texts_with_timestamps=texts_with_timestamps, - asr_texts_with_timestamps=asr_texts_with_timestamps, - raw_texts=raw_texts, - raw_asr_texts=raw_asr_texts, - token_texts=token_texts, - token_asr_texts=token_asr_texts, - token_function_texts=token_function_texts, - token_lengths=token_lengths, - audio_filepaths=audio_paths, - debug_data=debug_data if debug_data else None, - ) - - def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = None) -> Optional[torch.Tensor]: - """Prefill the system prompt for a new stream. - - This prepares the system prompt embeddings and processes them through - the LLM to update the KV cache before audio streaming begins. - Also prefills the TTS model with speaker embeddings when using vLLM EarTTS. - - Args: - stream_id: The stream identifier. - system_prompt: The system prompt text for this stream. If *None*, - TTS prefill still runs (for vLLM EarTTS) but no LLM prompt - is injected. - - Note on TTS prefill codes: - The TTS prefill generates output codes, but these should NOT be used - to initialize context.code for inference. The batch approach uses - first_tts_code_input (INPUT codes from speaker reference) instead. - Using prefill OUTPUT codes causes audio quality issues (mumbling). - - Returns: - Optional[torch.Tensor]: The TTS prefill output codes if vLLM EarTTS prefill - happened, None otherwise. These are returned for logging/debugging but - should NOT be used to update context.code. 
- """ - request_id = self._request_id_for_stream(stream_id) - engine_type = getattr(self.s2s_model, "engine_type", "native") - tts_output_code = None - - # Prefill TTS with speaker embedding when using vLLM EarTTS - # This initializes the vLLM TTS engine with the speaker context via prompt_token_ids - use_vllm_eartts = "vllm_eartts" in engine_type.lower() - if use_vllm_eartts: - tts_init_inputs = getattr(self.s2s_model, "tts_init_inputs", None) - tts_prompt_token_ids = getattr(self.s2s_model, "tts_prompt_token_ids", None) - if tts_init_inputs is not None and tts_prompt_token_ids is not None: - logging.info(f"Prefilling TTS speaker embedding for stream {stream_id}...") - start_tts_prefill = time.time() - with torch.no_grad(): - # Clone tts_init_inputs to avoid any tensor sharing issues - import copy - tts_inputs_copy = copy.deepcopy(tts_init_inputs) - tts_result = self.s2s_model.model.tts_model.tts_model( - tts_inputs_copy, - request_id=request_id, - prompt_token_ids=tts_prompt_token_ids - ) - # Capture the generated codes to sync context with vLLM state - if hasattr(tts_result, 'codes') and tts_result.codes is not None: - tts_output_code = tts_result.codes.detach().clone() - logging.debug(f"TTS prefill generated codes shape: {tts_output_code.shape}") - logging.info(f"TTS speaker embedding prefilled in {time.time() - start_tts_prefill:.3f}s") - else: - logging.warning("TTS init inputs not available, skipping TTS prefill") - - if not system_prompt: - return tts_output_code - - logging.info(f"Prefilling system prompt for stream {stream_id}...") - start_get_prompt_embeddings = time.time() - prompt_embedded, prompt_len = self.s2s_model._prepare_system_prompt_embeddings(system_prompt) - logging.debug(f"Time taken to get prompt embeddings: {time.time() - start_get_prompt_embeddings:.3f}s") - - if prompt_embedded is None: - logging.warning("System prompt embedding returned None, skipping prefill") - return tts_output_code - - # Check if using vLLM for LLM (matches 
vllm_llm, vllm_llm_vllm_eartts, etc.) - use_vllm_llm = "vllm_llm" in engine_type.lower() - - if use_vllm_llm: - # For vLLM LLM: prefill all prompt embeddings in one shot - # (decode_steps=0 triggers a single bulk prefill in the vLLM engine) - logging.info(f"Prefilling {prompt_len} prompt embeddings for vLLM LLM...") - start_prefill = time.time() - with torch.no_grad(): - _ = self.s2s_model.model_llm_interface( - prompt_embedded, - request_id=request_id, - decode_steps=0, - prompt_token_ids=None, - ) - logging.info(f"System prompt prefilled ({prompt_len} tokens) in {time.time() - start_prefill:.3f}s") - - else: - context, _ = self.context_manager.get_context([stream_id]) - if context.dynamic_cache is not None: - # Native cache mode: process prompt through LLM to update KV cache - with torch.no_grad(): - cache_pos = torch.arange(prompt_len, device=self.s2s_model.device) - llm_cache = context.dynamic_cache - ans = self.s2s_model.model_llm_interface( - prompt_embedded, - cache=llm_cache, - cache_position=cache_pos, - generated_tokens=None, - current_step=0 - ) - context.dynamic_cache = ans.get("cache", llm_cache) - context.cache_position_offset = prompt_len - logging.info(f"System prompt processed, cache updated ({prompt_len} tokens, offset={prompt_len})") - else: - for t in range(prompt_len): - context.input_embeds_history.append(prompt_embedded[:, t:t+1, :]) - logging.info(f"Added {prompt_len} prompt embeddings to input_embeds_history") - - return tts_output_code - - def _padding_remaining_secs(self, elapsed_secs: float) -> float: - """Return how many seconds of silence padding are still needed.""" - if self.pad_audio_to_sec is not None: - return max(0.0, self.pad_audio_to_sec - elapsed_secs) - if self.pad_silence_ratio is not None: - return elapsed_secs * self.pad_silence_ratio - if self.pad_audio_by_sec is not None: - return self.pad_audio_by_sec - return 0.0 - - def _request_id_for_stream(self, stream_id: int) -> str: - return str(stream_id) - - def 
_abort_stream_request(self, stream_id: int) -> None: - request_id = self._request_id_for_stream(stream_id) - abort_fn = getattr(self.s2s_model, "abort_request", None) - if callable(abort_fn): - try: - abort_fn(request_id) - except Exception as exc: - logging.warning(f"Failed to abort request {request_id} for stream {stream_id}: {exc}") + """ + Streaming S2S pipeline. + """ + + def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper): + # ------------------------------------------------------------------ + # Model & device + # ------------------------------------------------------------------ + self.s2s_model = s2s_model + self.device = self.s2s_model.device + self.collect_debug = False + + # ------------------------------------------------------------------ + # Streaming configuration + # ------------------------------------------------------------------ + self.streaming_cfg = cfg.get("streaming", {}) + self.input_sample_rate = getattr(self.streaming_cfg, "input_sample_rate", 16000) + self.output_sample_rate = getattr(self.streaming_cfg, "output_sample_rate", 22050) + self.batch_size = getattr(self.streaming_cfg, "batch_size", 1) + self.max_len = getattr(self.streaming_cfg, "max_len", 200) + if self.batch_size != 1: + raise ValueError( + "StreamingS2SPipeline currently supports only single-stream inference " + "(streaming.batch_size must be 1)." + ) + + + # ------------------------------------------------------------------ + # Chunk & buffer sizes + # Terminology: "frame" = 80ms audio unit, "chunk" = 1 or more frames + # A chunk is the amount of audio that is processed per inference step. + # ------------------------------------------------------------------ + self.chunk_size_in_secs = getattr(self.streaming_cfg, "chunk_size_in_secs", 0.08) + # Check if self.chunk_size_in_secs is a multiple of 0.08. + # Because of quirks of floating point arithmetic, the remainder could be either ~0 or ~0.08, + # so we check for both cases. 
+ remainder = self.chunk_size_in_secs % 0.08 + if not (math.isclose(remainder, 0, abs_tol=1e-9) or math.isclose(remainder, 0.08, abs_tol=1e-9)): + raise ValueError(f"Chunk size must be a multiple of 0.08s, but got {self.chunk_size_in_secs}") + + self.num_frames_per_chunk = int(self.chunk_size_in_secs / 0.08) + + # Buffer size determines how much audio is passed to the perception encoder + # Default: 5.68 seconds (71 * 0.08). This is the minimum valid buffer size without the perception cache. + # i.e. att_context_size[0] + att_context_size[1] + 1 frames = 70+0+1 = 71 frames = 5.68 seconds + self.buffer_size_in_secs = getattr(self.streaming_cfg, "buffer_size_in_secs", 71 * 0.08) + + self.att_context_size = getattr(self.streaming_cfg, "att_context_size", [70,0]) + + # ------------------------------------------------------------------ + # bufferer – reused from ASR utilities + # ------------------------------------------------------------------ + self.bufferer = BatchedAudioBufferer( + sample_rate=self.input_sample_rate, + buffer_size_in_secs=self.buffer_size_in_secs, + ) + + # ------------------------------------------------------------------ + # System prompt configuration + # ------------------------------------------------------------------ + s2s_cfg = cfg.get("s2s", {}) + self.system_prompt: Optional[str] = getattr(s2s_cfg, "system_prompt", None) + if self.system_prompt: + logging.info(f"System prompt configured: {self.system_prompt[:100]}{'...' if len(self.system_prompt) > 100 else ''}") + + # Context manager + self.context_manager = S2SContextManager( + s2s_model=self.s2s_model, + num_slots=self.batch_size, + max_len=self.max_len, + ) + + # Output directory for generated files + self.output_dir = getattr(cfg, "output_dir", "./generated") + + # Parse and validate request type early, with a safe default + req_type_cfg = getattr(self.streaming_cfg, "request_type", "frame") + + # Parse and validate the request type; only 'frame' is supported for s2s. 
+ self.request_type = RequestType.from_str(req_type_cfg) + if self.request_type is not RequestType.FRAME: + raise ValueError(f"Request type {self.request_type} is not supported for s2s.") + + self._stream_has_prompt: bool = False + + # ------------------------------------------------------------------ + # Input audio padding (silence appended after real audio) + # ------------------------------------------------------------------ + self.pad_audio_to_sec: float | None = cfg.get("pad_audio_to_sec", None) + self.pad_silence_ratio: float | None = cfg.get("pad_silence_ratio", None) + self.pad_audio_by_sec: float | None = cfg.get("pad_audio_by_sec", None) + if sum(x is not None for x in [self.pad_audio_to_sec, self.pad_silence_ratio, self.pad_audio_by_sec]) > 1: + raise ValueError("Set at most one of: pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec") + + super().__init__() + + # -------------------------------- ---------------------------------- + # State helpers + # ------------------------------------------------------------------ + def create_state(self) -> S2SStreamingState: + """Create new empty state.""" + num_audio_codebooks = getattr(self.s2s_model.model, "_num_codebooks", 1) + dtype = getattr(self.s2s_model, "compute_dtype", torch.float32) + state = S2SStreamingState( + device=self.device, + dtype=dtype, + max_len=self.max_len, + num_audio_codebooks=num_audio_codebooks, + output_sample_rate=self.output_sample_rate, + ) + return state + + + # ------------------------------------------------------------------ + # Output helpers + # ------------------------------------------------------------------ + def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[bool], text_pieces: List[str], asr_text_pieces: List[str] = None): + """Append generated audio waveform and text to per-stream state.""" + for idx, frame in enumerate(frames): + if not ready_feats[idx]: + continue + state = self.get_or_create_state(frame.stream_id) + # audio_wave is 
[B, S]; take sample idx + sample_audio = audio_wave[idx:idx+1, ...] + # Determine text piece for this index + piece = None + if text_pieces and idx < len(text_pieces): + candidate = text_pieces[idx] + if isinstance(candidate, str) and candidate: + piece = candidate + + # Determine ASR text piece + asr_piece = None + if asr_text_pieces and idx < len(asr_text_pieces): + candidate = asr_text_pieces[idx] + if isinstance(candidate, str) and candidate: + asr_piece = candidate + + state.update_state(sample_audio, output_text=piece, output_asr_text=asr_piece) + + + def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_paddings: List[int], ready_feats: List[bool]): + """Generate speech for chunks in *batch* using a shared ContextManager.""" + if len(frames) == 0: + return + + stream_ids = [f.stream_id for f in frames] + eos_flags = [f.is_last for f in frames] + bos_flags = [f.is_first for f in frames] + + logging.debug(f"stream_ids={stream_ids} bos_flags={bos_flags} eos_flags={eos_flags}") + + if len(frames) != 1: + raise NotImplementedError("NemotronVoicechatInferenceWrapper currently supports batch_size == 1") + + # If this is the first audio frame and prefill was already done via a + # zero-length prefill frame, skip context init -- it's already set up. + # Otherwise (no system prompt), create a fresh context_manager. 
+ has_prompt = False + if bos_flags[0]: + if self._stream_has_prompt: + logging.debug(f"Prefill already done for stream {stream_ids[0]}, skipping context init") + else: + logging.debug(f"No prefill for stream {stream_ids[0]}, creating fresh context_manager") + self.context_manager = S2SContextManager( + s2s_model=self.s2s_model, + num_slots=self.batch_size, + max_len=self.max_len, + ) + + has_prompt = self._stream_has_prompt + self._stream_has_prompt = False + + request_id = self._request_id_for_stream(stream_ids[0]) + + context, _ = self.context_manager.get_context(stream_ids) + + audio_buffer = buffers[0] + if audio_buffer.dim() == 1: + audio_buffer = audio_buffer.unsqueeze(0) + audio_buffer = audio_buffer.to(self.s2s_model.device, dtype=self.s2s_model.dtype) + + # Trim the buffer to exclude left padding (zeros at the beginning before buffer is filled) + left_pad = left_paddings[0] + if left_pad > 0: + audio_buffer = audio_buffer[:, left_pad:] + + result = self.s2s_model.infer_one_step( + audio_input=audio_buffer, + num_frames_per_chunk=self.num_frames_per_chunk, + frame_idx=context.frame_idx, + gen_text=context.gen_text, + audio_toks_buffer=context.audio_toks_buffer, + input_embeds_history=context.input_embeds_history, + dynamic_cache=context.dynamic_cache, + past_key_values=context.past_key_values, + code=context.code, + subword_mask=context.subword_mask, + gen_asr_text=context.gen_asr_text, + gen_function_text=context.gen_function_text, + request_id=request_id, + perception_cache=context.perception_cache, + has_prompt=has_prompt, + codec_cache=context.codec_cache, + cache_position_offset=context.cache_position_offset, + return_debug=self.collect_debug, + ) + + if self.collect_debug and "debug" in result: + state = self.get_or_create_state(stream_ids[0]) + if not hasattr(state, "debug_steps"): + state.debug_steps = [] + state.debug_steps.append(result["debug"]) + + # Persist updated cache & clean finished streams + 
self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk) + + # Save full token tensors to state before the context is destroyed, + # so we can run tokens_to_str / tokens_to_str_raw post-hoc. + for stream_id, eos_flag in zip(stream_ids, eos_flags): + if eos_flag: + ctx = self.context_manager.slot_contexts[ + self.context_manager.streamidx2slotidx[stream_id] + ] + if ctx is not None: + state = self.get_or_create_state(stream_id) + state.save_token_tensors(ctx.gen_text, ctx.gen_asr_text, ctx.frame_idx, + gen_function_text=ctx.gen_function_text) + + self.context_manager.reset_slots(stream_ids, eos_flags) + + # Explicitly clean up bufferer and state for finished streams + for stream_id, eos_flag in zip(stream_ids, eos_flags): + if eos_flag: + logging.debug(f"Ending stream {stream_id} - cleaning up bufferer and context") + self.bufferer.rm_bufferer(stream_id) + self._abort_stream_request(stream_id) + # Note: We keep the state in _state_pool until finalization to save audio + # It will be cleaned up in close_session() + + # Log audio and attach text to state + self.log_output(frames, result["decoded_audio_new"], ready_feats, result["predicted_text_strs"], result.get("asr_predicted_text_strs")) + + def prefill_for_new_stream(self, stream_id: int, system_prompt: str | None = None) -> bool: + """Prepare the pipeline for a new stream by resetting context and prefilling the system prompt. + + This is the public API for prefill-only calls (e.g. from the Triton backend) + that need to initialize TTS speaker embeddings and/or inject a system prompt + into the LLM KV cache *without* processing any audio. + + Args: + stream_id: Unique identifier for the new stream. + system_prompt: System prompt text. If *None*, falls back to + the YAML-configured ``self.system_prompt``. + + Returns: + True if a system prompt was prefilled, False otherwise. 
+ """ + t0 = time.time() + if system_prompt is None: + system_prompt = self.system_prompt + + self.context_manager = S2SContextManager( + s2s_model=self.s2s_model, + num_slots=self.batch_size, + max_len=self.max_len, + ) + t_ctx = time.time() + + with torch.no_grad(), torch.inference_mode(): + self._prefill_system_prompt(stream_id, system_prompt) + t_prefill = time.time() + + self._stream_has_prompt = bool(system_prompt) + logging.debug(f"prefill_for_new_stream: context_manager={1000*(t_ctx-t0):.1f}ms, " + f"_prefill_system_prompt={1000*(t_prefill-t_ctx):.1f}ms, " + f"total={1000*(t_prefill-t0):.1f}ms, has_prompt={self._stream_has_prompt}") + return self._stream_has_prompt + + _WARMUP_FALLBACK_PROMPT = "Mock system prompt for warmup." + + def warmup(self, system_prompt: str | None = None) -> None: + """Run a throwaway prefill cycle to warm up the inference engine. + + The very first prefill incurs one-time overhead (e.g. CUDA graph + compilation, memory pool allocation, DynamicCache initialization). + Calling this once during startup moves that cost out of the + critical path so the first real client request is fast. + + The method performs a full prefill (TTS speaker embedding + LLM + system prompt), then aborts the request and resets all pipeline + state so the next real stream starts cleanly. + + Args: + system_prompt: Prompt text to use for warmup. Falls back to + the YAML-configured ``self.system_prompt``, then to a + short fallback string so the LLM prefill path is always + exercised. 
+ """ + prompt = system_prompt if system_prompt is not None else self.system_prompt + if not prompt: + prompt = self._WARMUP_FALLBACK_PROMPT + logging.info(f"No system prompt configured — using fallback prompt for warmup: \"{prompt}\"") + + warmup_stream_id = -1 + + logging.info("Running pipeline warmup prefill...") + t0 = time.time() + + self.prefill_for_new_stream(warmup_stream_id, prompt) + + # Tear down the warmup request so the engine is clean for real traffic + self._abort_stream_request(warmup_stream_id) + self.context_manager.reset() + self._stream_has_prompt = False + + logging.info(f"Pipeline warmup complete in {time.time() - t0:.3f}s") + + def generate_step(self, frames: List[Frame]): + """Main streaming API similar to *transcribe_step* in recognizers. + + If the batch contains a single zero-length first frame with a system + prompt in ``options``, this is treated as a **prefill-only** request: + the context manager and system prompt are initialized but no audio + inference runs. This is the unified protocol used by both the CLI + (``run()``) and the Triton backend. 
+ """ + # Detect prefill-only frame: is_first + zero-length audio + if (len(frames) == 1 + and frames[0].is_first + and frames[0].samples.numel() == 0): + opts = frames[0].options + prompt = None + if opts is not None and hasattr(opts, "system_prompt"): + prompt = opts.system_prompt + self.prefill_for_new_stream(frames[0].stream_id, prompt) + return + + buffers, left_paddings = self.bufferer.update(frames) + ready_feats = [True] * len(frames) + + with torch.no_grad(), torch.inference_mode(): + self.inner_generate_step(frames, buffers, left_paddings, ready_feats) + + # ------------------------------------------------------------------ + # Finalization helpers + # ------------------------------------------------------------------ + def _finalize_and_save_finished_streams( + self, + frames: List[Frame], + audio_filepaths: List[str], + saved_paths_by_stream: dict[int, str], + ) -> None: + """Finalize any streams that ended in this batch and save their audio.""" + for frame in frames: + if frame.is_last: + stream_id = frame.stream_id + state = self.get_or_create_state(stream_id) + + # Flush remaining buffered samples and assemble waveform + if hasattr(state, "finalize"): + state.finalize() + # Concatenate emitted chunks and squeeze (B=1,C=1) to mono waveform + generated_audio = torch.cat(state.speech_frames, dim=-1) + # Ensure 1D mono waveform and float32 dtype for soundfile + if generated_audio.dim() == 3 and generated_audio.size(0) == 1 and generated_audio.size(1) == 1: + generated_audio = generated_audio.squeeze(0).squeeze(0) + elif generated_audio.dim() == 2 and generated_audio.size(0) == 1: + generated_audio = generated_audio.squeeze(0) + generated_audio = generated_audio.to(torch.float32) + + # Build output paths in subdirectories under output_dir + in_path = audio_filepaths[stream_id] + base = os.path.splitext(os.path.basename(in_path))[0] + + wav_dir = os.path.join(self.output_dir, "wav") + stereo_dir = os.path.join(self.output_dir, "stereo") + txt_dir = 
os.path.join(self.output_dir, "txt") + os.makedirs(wav_dir, exist_ok=True) + os.makedirs(stereo_dir, exist_ok=True) + os.makedirs(txt_dir, exist_ok=True) + + out_path = os.path.join(wav_dir, f"{base}.wav") + + # Write audio to disk + if generated_audio.numel() > 0: + sf.write(out_path, generated_audio.detach().cpu().numpy(), self.output_sample_rate) + + # Also save a stereo file with input (ch0) and output (ch1) + # Load input with librosa (handles mono conversion and resampling) + input_np, _ = librosa.load(in_path, sr=self.output_sample_rate, mono=True) + input_audio = torch.from_numpy(input_np).to(torch.float32) + gen_cpu = generated_audio.detach().cpu().to(input_audio.dtype) + + # Prepend silence to output channel to account for + # the one-chunk processing delay: the server can't + # produce output until it has received a full input chunk. + delay_samples = int(self.chunk_size_in_secs * self.output_sample_rate) + silence = torch.zeros(delay_samples, dtype=gen_cpu.dtype) + gen_cpu = torch.cat([silence, gen_cpu], dim=-1) + + gen_len = int(gen_cpu.shape[-1]) + in_len = int(input_audio.shape[-1]) + max_len = max(gen_len, in_len) + if in_len < max_len: + input_audio = torch.cat([input_audio, torch.zeros(max_len - in_len, dtype=input_audio.dtype)], dim=-1) + if gen_len < max_len: + gen_cpu = torch.cat([gen_cpu, torch.zeros(max_len - gen_len, dtype=gen_cpu.dtype)], dim=-1) + stereo = torch.stack([input_audio, gen_cpu], dim=0).transpose(0, 1) + stereo_path = os.path.join(stereo_dir, f"{base}_input_output.wav") + sf.write(stereo_path, stereo.detach().cpu().numpy(), self.output_sample_rate) + + # Save accumulated text + text_out = state.get_output_text() if hasattr(state, "get_output_text") else "" + if isinstance(text_out, str): + try: + with open(os.path.join(txt_dir, f"{base}.txt"), "w", encoding="utf-8") as f: + f.write(text_out) + except Exception: + pass + + # Save accumulated ASR text + asr_text_out = state.get_output_asr_text() if hasattr(state, 
"get_output_asr_text") else "" + if isinstance(asr_text_out, str) and asr_text_out: + try: + with open(os.path.join(txt_dir, f"{base}_asr.txt"), "w", encoding="utf-8") as f: + f.write(asr_text_out) + except Exception: + pass + + saved_paths_by_stream[stream_id] = out_path + + # Keep state until outputs are assembled; will be cleared on close_session + + + # ------------------------------------------------------------------ + # Session helpers (extend S2SPipelineInterface) + # ------------------------------------------------------------------ + + def reset_session(self) -> None: + """Reset feature buffer and ContextManager together.""" + for stream_id in list(self.context_manager.streamidx2slotidx.keys()): + self._abort_stream_request(stream_id) + self.bufferer.reset() + self.context_manager.reset() + + super().reset_session() # clears state pool + + # ------------------------------------------------------------------ + # Orchestrator – mirrors recognizers' *run* method + # ------------------------------------------------------------------ + def run( + self, + audio_filepaths: List[str], + options: List[S2SRequestOptions] | None = None, + progress_bar: Optional[ProgressBar] = None, + ) -> PipelineOutput: + """Stream all *audio_filepaths* through the pipeline and save outputs. + + Saves one generated ``.wav`` per input under ``self.output_dir`` and + returns their paths in ``PipelineOutput.texts``. 
+ """ + if progress_bar and not isinstance(progress_bar, ProgressBar): + raise ValueError("progress_bar must be an instance of ProgressBar.") + + if options is None: + options = [S2SRequestOptions(system_prompt=self.system_prompt) for _ in audio_filepaths] + + streamer = ContinuousBatchedFrameStreamer( + n_frames_per_stream=1, + frame_size_in_secs=self.chunk_size_in_secs, + sample_rate=self.input_sample_rate, + batch_size=self.batch_size, + pad_last_frame=True, + ) + + streamer.set_audio_filepaths(audio_filepaths, options) + streamer.set_progress_bar(progress_bar) + + # Ensure output directory exists + os.makedirs(self.output_dir, exist_ok=True) + + # Track saved paths by stream id to preserve input order + saved_paths_by_stream: dict[int, str] = {} + chunk_samples = int(self.chunk_size_in_secs * self.input_sample_rate) + + self.open_session() + for frames in streamer: + # Unified prefill protocol: if the first frame of a new stream + # carries a system prompt, emit a zero-length prefill frame first. + if (len(frames) == 1 + and frames[0].is_first + and frames[0].options is not None + and hasattr(frames[0].options, "system_prompt") + and frames[0].options.system_prompt): + prefill_frame = Frame( + samples=torch.empty(0), + stream_id=frames[0].stream_id, + is_first=True, + is_last=False, + options=frames[0].options, + ) + self.generate_step([prefill_frame]) + + # If padding is configured, intercept last frames so the + # bufferer/context stay alive for the silence-padding phase. + # Padding is generated immediately (same iteration) to avoid + # the next stream's setup destroying this stream's context. 
+ pad_targets: dict[int, float] = {} + if self.pad_audio_to_sec or self.pad_silence_ratio or self.pad_audio_by_sec: + processed_frames = [] + for frame in frames: + if frame.is_last: + elapsed = streamer.elapsed_durations[frame.stream_id] + remaining = self._padding_remaining_secs(elapsed) + if remaining > 0: + processed_frames.append(Frame( + samples=frame.samples, + stream_id=frame.stream_id, + is_first=frame.is_first, + is_last=False, + length=frame.length, + options=frame.options, + )) + pad_targets[frame.stream_id] = remaining + continue + processed_frames.append(frame) + frames = processed_frames + + self.generate_step(frames) + self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream) + + # Generate silence padding before the next iteration adds a new stream + for stream_id, remaining_secs in pad_targets.items(): + num_pad_frames = max(1, round(remaining_secs / self.chunk_size_in_secs)) + for i in range(num_pad_frames): + is_last = (i == num_pad_frames - 1) + silence_frame = Frame( + samples=torch.zeros(chunk_samples), + stream_id=stream_id, + is_first=False, + is_last=is_last, + length=chunk_samples, + ) + self.generate_step([silence_frame]) + if is_last: + self._finalize_and_save_finished_streams( + [silence_frame], audio_filepaths, saved_paths_by_stream + ) + # Build outputs before closing the session + texts = [] + words = [] + asr_texts = [] + texts_with_timestamps = [] + asr_texts_with_timestamps = [] + raw_texts = [] + raw_asr_texts = [] + token_texts = [] + token_asr_texts = [] + token_function_texts = [] + token_lengths = [] + audio_paths = [] + + tokenizer = self.s2s_model.tokenizer + pad_id = self.s2s_model.model.stt_model.text_pad_id + + for idx in range(len(audio_filepaths)): + state = self.get_or_create_state(idx) + text_value = state.get_output_text() if hasattr(state, "get_output_text") else "" + if not text_value: + text_value = saved_paths_by_stream.get(idx, "") + texts.append(text_value) + 
audio_paths.append(saved_paths_by_stream.get(idx)) + per_stream_words = state.get_output_words() if hasattr(state, "get_output_words") else [] + words.append(per_stream_words) + asr_text_value = state.get_output_asr_text() if hasattr(state, "get_output_asr_text") else "" + asr_texts.append(asr_text_value) + + token_data = state.get_token_tensors() + if token_data is not None: + gen_text, gen_asr_text, total_frames, gen_function_text = token_data + token_texts.append(gen_text) + token_asr_texts.append(gen_asr_text) + token_function_texts.append(gen_function_text) + token_lengths.append(total_frames) + lengths = torch.tensor([total_frames], dtype=torch.long) + texts_with_timestamps.append( + tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0] + ) + asr_texts_with_timestamps.append( + tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0] + ) + raw_texts.append( + tokens_to_str_raw(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] + ) + raw_asr_texts.append( + tokens_to_str_raw(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] + ) + if gen_function_text is not None: + fc_text = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=False)[0] + fc_text_raw = tokens_to_str_raw(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] + logging.info(f"Function calling channel: {fc_text}") + else: + token_texts.append(None) + token_asr_texts.append(None) + token_function_texts.append(None) + token_lengths.append(None) + texts_with_timestamps.append("") + asr_texts_with_timestamps.append("") + raw_texts.append("") + raw_asr_texts.append("") + + debug_data = [] + if self.collect_debug: + for idx in range(len(audio_filepaths)): + state = self.get_or_create_state(idx) + debug_data.append(getattr(state, "debug_steps", [])) + + self.close_session() + + return PipelineOutput( + texts=texts, + words=words, + 
asr_texts=asr_texts, + texts_with_timestamps=texts_with_timestamps, + asr_texts_with_timestamps=asr_texts_with_timestamps, + raw_texts=raw_texts, + raw_asr_texts=raw_asr_texts, + token_texts=token_texts, + token_asr_texts=token_asr_texts, + token_function_texts=token_function_texts, + token_lengths=token_lengths, + audio_filepaths=audio_paths, + debug_data=debug_data if debug_data else None, + ) + + def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = None) -> Optional[torch.Tensor]: + """Prefill the system prompt for a new stream. + + This prepares the system prompt embeddings and processes them through + the LLM to update the KV cache before audio streaming begins. + Also prefills the TTS model with speaker embeddings when using vLLM EarTTS. + + Args: + stream_id: The stream identifier. + system_prompt: The system prompt text for this stream. If *None*, + TTS prefill still runs (for vLLM EarTTS) but no LLM prompt + is injected. + + Note on TTS prefill codes: + The TTS prefill generates output codes, but these should NOT be used + to initialize context.code for inference. The batch approach uses + first_tts_code_input (INPUT codes from speaker reference) instead. + Using prefill OUTPUT codes causes audio quality issues (mumbling). + + Returns: + Optional[torch.Tensor]: The TTS prefill output codes if vLLM EarTTS prefill + happened, None otherwise. These are returned for logging/debugging but + should NOT be used to update context.code. 
+ """ + request_id = self._request_id_for_stream(stream_id) + engine_type = getattr(self.s2s_model, "engine_type", "native") + tts_output_code = None + + # Prefill TTS with speaker embedding when using vLLM EarTTS + # This initializes the vLLM TTS engine with the speaker context via prompt_token_ids + use_vllm_eartts = "vllm_eartts" in engine_type.lower() + if use_vllm_eartts: + tts_init_inputs = getattr(self.s2s_model, "tts_init_inputs", None) + tts_prompt_token_ids = getattr(self.s2s_model, "tts_prompt_token_ids", None) + if tts_init_inputs is not None and tts_prompt_token_ids is not None: + logging.info(f"Prefilling TTS speaker embedding for stream {stream_id}...") + start_tts_prefill = time.time() + with torch.no_grad(): + tts_inputs_copy = copy.deepcopy(tts_init_inputs) + tts_result = self.s2s_model.model.tts_model.tts_model( + tts_inputs_copy, + request_id=request_id, + prompt_token_ids=tts_prompt_token_ids + ) + # Capture the generated codes to sync context with vLLM state + if hasattr(tts_result, 'codes') and tts_result.codes is not None: + tts_output_code = tts_result.codes.detach().clone() + logging.debug(f"TTS prefill generated codes shape: {tts_output_code.shape}") + logging.info(f"TTS speaker embedding prefilled in {time.time() - start_tts_prefill:.3f}s") + else: + logging.warning("TTS init inputs not available, skipping TTS prefill") + + if not system_prompt: + return tts_output_code + + logging.info(f"Prefilling system prompt for stream {stream_id}...") + start_get_prompt_embeddings = time.time() + prompt_embedded, prompt_len = self.s2s_model._prepare_system_prompt_embeddings(system_prompt) + logging.debug(f"Time taken to get prompt embeddings: {time.time() - start_get_prompt_embeddings:.3f}s") + + if prompt_embedded is None: + logging.warning("System prompt embedding returned None, skipping prefill") + return tts_output_code + + # Check if using vLLM for LLM (matches vllm_llm, vllm_llm_vllm_eartts, etc.) 
+ use_vllm_llm = "vllm_llm" in engine_type.lower() + + if use_vllm_llm: + # For vLLM LLM: prefill all prompt embeddings in one shot + # (decode_steps=0 triggers a single bulk prefill in the vLLM engine) + logging.info(f"Prefilling {prompt_len} prompt embeddings for vLLM LLM...") + start_prefill = time.time() + with torch.no_grad(): + _ = self.s2s_model.model_llm_interface( + prompt_embedded, + request_id=request_id, + decode_steps=0, + prompt_token_ids=None, + ) + logging.info(f"System prompt prefilled ({prompt_len} tokens) in {time.time() - start_prefill:.3f}s") + + else: + context, _ = self.context_manager.get_context([stream_id]) + if context.dynamic_cache is not None: + # Native cache mode: process prompt through LLM to update KV cache + with torch.no_grad(): + cache_pos = torch.arange(prompt_len, device=self.s2s_model.device) + llm_cache = context.dynamic_cache + ans = self.s2s_model.model_llm_interface( + prompt_embedded, + cache=llm_cache, + cache_position=cache_pos, + generated_tokens=None, + current_step=0 + ) + context.dynamic_cache = ans.get("cache", llm_cache) + context.cache_position_offset = prompt_len + logging.info(f"System prompt processed, cache updated ({prompt_len} tokens, offset={prompt_len})") + else: + for t in range(prompt_len): + context.input_embeds_history.append(prompt_embedded[:, t:t+1, :]) + logging.info(f"Added {prompt_len} prompt embeddings to input_embeds_history") + + return tts_output_code + + def _padding_remaining_secs(self, elapsed_secs: float) -> float: + """Return how many seconds of silence padding are still needed.""" + if self.pad_audio_to_sec is not None: + return max(0.0, self.pad_audio_to_sec - elapsed_secs) + if self.pad_silence_ratio is not None: + return elapsed_secs * self.pad_silence_ratio + if self.pad_audio_by_sec is not None: + return self.pad_audio_by_sec + return 0.0 + + def _request_id_for_stream(self, stream_id: int) -> str: + return str(stream_id) + + def _abort_stream_request(self, stream_id: int) -> None: 
+ request_id = self._request_id_for_stream(stream_id) + abort_fn = getattr(self.s2s_model, "abort_request", None) + if callable(abort_fn): + try: + abort_fn(request_id) + except Exception as exc: + logging.warning(f"Failed to abort request {request_id} for stream {stream_id}: {exc}") diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py index da829e8c634d..0cf1ce8633a7 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -25,194 +25,194 @@ @dataclass class StreamingDecodeState: - frame_idx: int - gen_text: torch.Tensor - gen_asr_text: torch.Tensor - gen_function_text: Optional[torch.Tensor] - audio_toks_buffer: Optional[torch.Tensor] - input_embeds_history: List[torch.Tensor] - dynamic_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache - past_key_values: Any - code: Optional[torch.Tensor] - subword_mask: Optional[torch.Tensor] - perception_cache: Optional["PerceptionCacheState"] = None - codec_cache: Any = None - cache_position_offset: int = 0 + frame_idx: int + gen_text: torch.Tensor + gen_asr_text: torch.Tensor + gen_function_text: Optional[torch.Tensor] + audio_toks_buffer: Optional[torch.Tensor] + input_embeds_history: List[torch.Tensor] + dynamic_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache + past_key_values: Any + code: Optional[torch.Tensor] + subword_mask: Optional[torch.Tensor] + perception_cache: Optional["PerceptionCacheState"] = None + codec_cache: Any = None + cache_position_offset: int = 0 class S2SContextManager: - def __init__( - self, - s2s_model, - num_slots: int, - max_len: int, - ): - self.s2s_model = s2s_model - self.num_slots = num_slots - - self.max_len = max_len - self.device = getattr(self.s2s_model, "device", torch.device("cpu")) - self.dtype = getattr(self.s2s_model, "dtype", 
torch.float32) - - self.reset() - - def reset(self) -> None: - """Reset all bookkeeping for a new streaming session.""" - self.streamidx2slotidx: Dict[int, int] = {} - self.slotidx2streamidx: Dict[int, int] = {} - self.free_slots = Queue(self.num_slots) - for i in range(self.num_slots): - self.free_slots.put(i) - self.slot_contexts: List[Optional[StreamingDecodeState]] = [None] * self.num_slots - - def _create_context(self) -> StreamingDecodeState: - """Allocate a fresh context backed by the realtime inference model.""" - if not hasattr(self.s2s_model, "create_decode_state"): - raise RuntimeError("s2s_model must provide create_decode_state(max_len)") - decode_state = self.s2s_model.create_decode_state(self.max_len) - return StreamingDecodeState( - frame_idx=decode_state["frame_idx"], - gen_text=decode_state["gen_text"], - gen_asr_text=decode_state["gen_asr_text"], - gen_function_text=decode_state["gen_function_text"], - audio_toks_buffer=decode_state["audio_toks_buffer"], - input_embeds_history=decode_state["input_embeds_history"], - dynamic_cache=decode_state["dynamic_cache"], - past_key_values=decode_state["past_key_values"], - code=decode_state["code"], - subword_mask=decode_state["subword_mask"], - perception_cache=decode_state["perception_cache"], - codec_cache=decode_state["codec_cache"], - cache_position_offset=decode_state["cache_position_offset"], - ) - - def _ensure_slot(self, stream_id: int) -> int: - if stream_id not in self.streamidx2slotidx: - if self.free_slots.empty(): - # Emergency cleanup: force-release all slots for a fresh start - # This handles cases where previous streams didn't end properly - # (e.g., exceptions, client disconnects, missing is_last=True) - logging.warning(f"No free slots available - forcing cleanup of all {self.num_slots} slots") - orphaned_streams = list(self.slotidx2streamidx.values()) - if orphaned_streams: - logging.warning(f"Orphaned streams being cleaned up: {orphaned_streams}") - for slot_idx in range(self.num_slots): 
- self.reset_slot(slot_idx) - slot_idx = self.free_slots.get() - # Ensure the slot is completely clean before assigning to new stream - if self.slot_contexts[slot_idx] is not None: - logging.warning(f"Slot {slot_idx} was not properly cleaned. Forcing cleanup.") - self.slot_contexts[slot_idx] = None - self.streamidx2slotidx[stream_id] = slot_idx - self.slotidx2streamidx[slot_idx] = stream_id - return self.streamidx2slotidx[stream_id] - - def reset_slot(self, slot_idx: int) -> None: - """Release a slot back to the pool.""" - if slot_idx < 0 or slot_idx >= self.num_slots: - return - # Set to None to break reference and allow garbage collection - self.slot_contexts[slot_idx] = None - stream_id = self.slotidx2streamidx.get(slot_idx) - if stream_id is not None: - del self.slotidx2streamidx[slot_idx] - del self.streamidx2slotidx[stream_id] - self.free_slots.put(slot_idx) - - def update_context( - self, - stream_ids: List[int], - step_result: Dict[str, Any], - num_frames: int, - ) -> None: - """Persist model outputs back into the cached context.""" - if len(stream_ids) == 0: - return - if len(stream_ids) != 1: - raise NotImplementedError("update_context currently supports batch_size == 1") - - stream_id = stream_ids[0] - slot_idx = self.streamidx2slotidx.get(stream_id) - if slot_idx is None: - raise RuntimeError(f"Stream {stream_id} is not registered in the context manager") - - context = self.slot_contexts[slot_idx] - if context is None: - context = self._create_context() - self.slot_contexts[slot_idx] = context - - start_idx = context.frame_idx - end_idx = start_idx + num_frames - if end_idx > context.gen_text.shape[1]: - raise RuntimeError( - "Context maximum length exceeded. Consider increasing `streaming.max_len` in the configuration." 
- ) - - predicted_tokens = step_result.get("predicted_text_tokens") - if predicted_tokens is not None: - if predicted_tokens.dim() == 1: - token_slice = predicted_tokens.unsqueeze(0) - else: - token_slice = predicted_tokens[0:1] - context.gen_text[:, start_idx:end_idx] = token_slice.to(context.gen_text.device) - - asr_predicted_tokens = step_result.get("asr_predicted_text_tokens") - if asr_predicted_tokens is not None: - if asr_predicted_tokens.dim() == 1: - asr_token_slice = asr_predicted_tokens.unsqueeze(0) - else: - asr_token_slice = asr_predicted_tokens[0:1] - context.gen_asr_text[:, start_idx:end_idx] = asr_token_slice.to(context.gen_asr_text.device) - - func_predicted_tokens = step_result.get("function_predicted_text_tokens") - if func_predicted_tokens is not None and context.gen_function_text is not None: - if func_predicted_tokens.dim() == 1: - func_token_slice = func_predicted_tokens.unsqueeze(0) - else: - func_token_slice = func_predicted_tokens[0:1] - context.gen_function_text[:, start_idx:end_idx] = func_token_slice.to(context.gen_function_text.device) - - context.frame_idx = end_idx - - if step_result.get("dynamic_cache") is not None: - context.dynamic_cache = step_result["dynamic_cache"] - if "audio_toks_buffer" in step_result: - context.audio_toks_buffer = step_result["audio_toks_buffer"] - if "input_embeds_history" in step_result: - context.input_embeds_history = step_result["input_embeds_history"] - if "past_key_values" in step_result: - context.past_key_values = step_result["past_key_values"] - if "code" in step_result: - context.code = step_result["code"] - if context.subword_mask is not None: - context.subword_mask[:, start_idx:end_idx] = True - if "perception_cache" in step_result and step_result["perception_cache"] is not None: - context.perception_cache = step_result["perception_cache"] - if "codec_cache" in step_result and step_result["codec_cache"] is not None: - context.codec_cache = step_result["codec_cache"] - if "cache_position_offset" 
in step_result: - context.cache_position_offset = step_result["cache_position_offset"] - - def reset_slots(self, stream_ids: List[int], eos_flags: List[bool]) -> None: - """Release contexts for streams that signalled end-of-stream.""" - if len(stream_ids) != len(eos_flags): - raise ValueError("stream_ids and eos_flags must have the same length") - for stream_id, eos_flag in zip(stream_ids, eos_flags): - if eos_flag and stream_id in self.streamidx2slotidx: - self.reset_slot(self.streamidx2slotidx[stream_id]) - - def get_context(self, stream_ids: List[int]) -> Tuple[StreamingDecodeState, Dict[int, int]]: - """Return the cached context associated with the provided stream ids.""" - if len(stream_ids) == 0: - return self._create_context(), {} - if len(stream_ids) != 1: - raise NotImplementedError("get_context currently supports batch_size == 1") - - stream_id = stream_ids[0] - slot_idx = self._ensure_slot(stream_id) - - if self.slot_contexts[slot_idx] is None: - self.slot_contexts[slot_idx] = self._create_context() - - return self.slot_contexts[slot_idx], {slot_idx: 0} + def __init__( + self, + s2s_model, + num_slots: int, + max_len: int, + ): + self.s2s_model = s2s_model + self.num_slots = num_slots + + self.max_len = max_len + self.device = getattr(self.s2s_model, "device", torch.device("cpu")) + self.dtype = getattr(self.s2s_model, "dtype", torch.float32) + + self.reset() + + def reset(self) -> None: + """Reset all bookkeeping for a new streaming session.""" + self.streamidx2slotidx: Dict[int, int] = {} + self.slotidx2streamidx: Dict[int, int] = {} + self.free_slots = Queue(self.num_slots) + for i in range(self.num_slots): + self.free_slots.put(i) + self.slot_contexts: List[Optional[StreamingDecodeState]] = [None] * self.num_slots + + def _create_context(self) -> StreamingDecodeState: + """Allocate a fresh context backed by the realtime inference model.""" + if not hasattr(self.s2s_model, "create_decode_state"): + raise RuntimeError("s2s_model must provide 
create_decode_state(max_len)") + decode_state = self.s2s_model.create_decode_state(self.max_len) + return StreamingDecodeState( + frame_idx=decode_state["frame_idx"], + gen_text=decode_state["gen_text"], + gen_asr_text=decode_state["gen_asr_text"], + gen_function_text=decode_state["gen_function_text"], + audio_toks_buffer=decode_state["audio_toks_buffer"], + input_embeds_history=decode_state["input_embeds_history"], + dynamic_cache=decode_state["dynamic_cache"], + past_key_values=decode_state["past_key_values"], + code=decode_state["code"], + subword_mask=decode_state["subword_mask"], + perception_cache=decode_state["perception_cache"], + codec_cache=decode_state["codec_cache"], + cache_position_offset=decode_state["cache_position_offset"], + ) + + def _ensure_slot(self, stream_id: int) -> int: + if stream_id not in self.streamidx2slotidx: + if self.free_slots.empty(): + # Emergency cleanup: force-release all slots for a fresh start + # This handles cases where previous streams didn't end properly + # (e.g., exceptions, client disconnects, missing is_last=True) + logging.warning(f"No free slots available - forcing cleanup of all {self.num_slots} slots") + orphaned_streams = list(self.slotidx2streamidx.values()) + if orphaned_streams: + logging.warning(f"Orphaned streams being cleaned up: {orphaned_streams}") + for slot_idx in range(self.num_slots): + self.reset_slot(slot_idx) + slot_idx = self.free_slots.get() + # Ensure the slot is completely clean before assigning to new stream + if self.slot_contexts[slot_idx] is not None: + logging.warning(f"Slot {slot_idx} was not properly cleaned. 
Forcing cleanup.") + self.slot_contexts[slot_idx] = None + self.streamidx2slotidx[stream_id] = slot_idx + self.slotidx2streamidx[slot_idx] = stream_id + return self.streamidx2slotidx[stream_id] + + def reset_slot(self, slot_idx: int) -> None: + """Release a slot back to the pool.""" + if slot_idx < 0 or slot_idx >= self.num_slots: + return + # Set to None to break reference and allow garbage collection + self.slot_contexts[slot_idx] = None + stream_id = self.slotidx2streamidx.get(slot_idx) + if stream_id is not None: + del self.slotidx2streamidx[slot_idx] + del self.streamidx2slotidx[stream_id] + self.free_slots.put(slot_idx) + + def update_context( + self, + stream_ids: List[int], + step_result: Dict[str, Any], + num_frames: int, + ) -> None: + """Persist model outputs back into the cached context.""" + if len(stream_ids) == 0: + return + if len(stream_ids) != 1: + raise NotImplementedError("update_context currently supports batch_size == 1") + + stream_id = stream_ids[0] + slot_idx = self.streamidx2slotidx.get(stream_id) + if slot_idx is None: + raise RuntimeError(f"Stream {stream_id} is not registered in the context manager") + + context = self.slot_contexts[slot_idx] + if context is None: + context = self._create_context() + self.slot_contexts[slot_idx] = context + + start_idx = context.frame_idx + end_idx = start_idx + num_frames + if end_idx > context.gen_text.shape[1]: + raise RuntimeError( + "Context maximum length exceeded. Consider increasing `streaming.max_len` in the configuration." 
+ ) + + predicted_tokens = step_result.get("predicted_text_tokens") + if predicted_tokens is not None: + if predicted_tokens.dim() == 1: + token_slice = predicted_tokens.unsqueeze(0) + else: + token_slice = predicted_tokens[0:1] + context.gen_text[:, start_idx:end_idx] = token_slice.to(context.gen_text.device) + + asr_predicted_tokens = step_result.get("asr_predicted_text_tokens") + if asr_predicted_tokens is not None: + if asr_predicted_tokens.dim() == 1: + asr_token_slice = asr_predicted_tokens.unsqueeze(0) + else: + asr_token_slice = asr_predicted_tokens[0:1] + context.gen_asr_text[:, start_idx:end_idx] = asr_token_slice.to(context.gen_asr_text.device) + + func_predicted_tokens = step_result.get("function_predicted_text_tokens") + if func_predicted_tokens is not None and context.gen_function_text is not None: + if func_predicted_tokens.dim() == 1: + func_token_slice = func_predicted_tokens.unsqueeze(0) + else: + func_token_slice = func_predicted_tokens[0:1] + context.gen_function_text[:, start_idx:end_idx] = func_token_slice.to(context.gen_function_text.device) + + context.frame_idx = end_idx + + if step_result.get("dynamic_cache") is not None: + context.dynamic_cache = step_result["dynamic_cache"] + if "audio_toks_buffer" in step_result: + context.audio_toks_buffer = step_result["audio_toks_buffer"] + if "input_embeds_history" in step_result: + context.input_embeds_history = step_result["input_embeds_history"] + if "past_key_values" in step_result: + context.past_key_values = step_result["past_key_values"] + if "code" in step_result: + context.code = step_result["code"] + if context.subword_mask is not None: + context.subword_mask[:, start_idx:end_idx] = True + if "perception_cache" in step_result and step_result["perception_cache"] is not None: + context.perception_cache = step_result["perception_cache"] + if "codec_cache" in step_result and step_result["codec_cache"] is not None: + context.codec_cache = step_result["codec_cache"] + if "cache_position_offset" 
in step_result: + context.cache_position_offset = step_result["cache_position_offset"] + + def reset_slots(self, stream_ids: List[int], eos_flags: List[bool]) -> None: + """Release contexts for streams that signalled end-of-stream.""" + if len(stream_ids) != len(eos_flags): + raise ValueError("stream_ids and eos_flags must have the same length") + for stream_id, eos_flag in zip(stream_ids, eos_flags): + if eos_flag and stream_id in self.streamidx2slotidx: + self.reset_slot(self.streamidx2slotidx[stream_id]) + + def get_context(self, stream_ids: List[int]) -> Tuple[StreamingDecodeState, Dict[int, int]]: + """Return the cached context associated with the provided stream ids.""" + if len(stream_ids) == 0: + return self._create_context(), {} + if len(stream_ids) != 1: + raise NotImplementedError("get_context currently supports batch_size == 1") + + stream_id = stream_ids[0] + slot_idx = self._ensure_slot(stream_id) + + if self.slot_contexts[slot_idx] is None: + self.slot_contexts[slot_idx] = self._create_context() + + return self.slot_contexts[slot_idx], {slot_idx: 0} diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py index 83ec8f09b439..c6ab5ae0a5e9 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py @@ -21,126 +21,126 @@ @dataclass class S2SStreamingState: - """ - State for streaming speech generation. - - This dataclass stores streaming tensors and counters used during - incremental generation. It keeps initialization metadata so it can be - reset to a clean state on demand. 
- """ - # Initialization metadata (required) - device: torch.device - dtype: torch.dtype - max_len: int - num_audio_codebooks: int - output_sample_rate: int - - # Runtime tensors (initialized in __post_init__) - audio_buffer: torch.Tensor = field(init=False) - - # Accumulated text output - output_text_str: str = "" - output_text_tokens: List[str] = field(default_factory=list) - # Accumulated ASR text output - output_asr_text_str: str = "" - output_asr_text_tokens: List[str] = field(default_factory=list) - # Accumulated words with timings - output_words: List[Word] = field(default_factory=list) - # Final token tensors saved from the context before it is destroyed. - # Used for post-hoc tokens_to_str / tokens_to_str_raw conversion. - final_gen_text: Optional[torch.Tensor] = None - final_gen_asr_text: Optional[torch.Tensor] = None - final_total_frames: int = 0 - - def __post_init__(self) -> None: - """Allocate tensors lazily based on provided metadata.""" - with torch.no_grad(): - # Empty 2D buffer: shape (1, 0). Will be appended over time. 
- self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) - - def reset(self) -> None: - """Reset all tensors and counters to their initial state.""" - with torch.no_grad(): - self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) - self.output_text_str = "" - self.output_text_tokens.clear() - self.output_asr_text_str = "" - self.output_asr_text_tokens.clear() - self.output_words.clear() - self.final_gen_text = None - self.final_gen_asr_text = None - self.final_total_frames = 0 - - def update_state(self, processed_frames: torch.Tensor, output_text_tokens: Any = None, output_text: str | None = None, output_asr_text: str | None = None) -> None: - """Append new audio to the right of the buffer; token/text args are accepted for API compatibility.""" - if processed_frames is None: - return - if not isinstance(processed_frames, torch.Tensor): - raise TypeError("processed_frames must be a torch.Tensor") - with torch.no_grad(): - # Ensure 2D [1, T] layout by flattening extra dims - append_tensor = processed_frames - if append_tensor.dim() > 1: - append_tensor = append_tensor.reshape(1, -1) - elif append_tensor.dim() == 1: - append_tensor = append_tensor.unsqueeze(0) - prior_samples = int(self.audio_buffer.shape[-1]) - appended_samples = int(append_tensor.shape[-1]) - self.audio_buffer = torch.cat([self.audio_buffer, append_tensor.to(self.device, dtype=self.dtype)], dim=-1) - - # Accumulate text output if provided and create a Word with naive timing - if isinstance(output_text, str) and output_text: - self.output_text_tokens.append(output_text) - # Directly concatenate - spacing is already handled by tokenizer (Ġ → space) - self.output_text_str += output_text - try: - if appended_samples > 0 and self.output_sample_rate > 0: - start_t = float(prior_samples) / float(self.output_sample_rate) - end_t = float(prior_samples + appended_samples) / float(self.output_sample_rate) - self.output_words.append(Word(text=output_text, 
start=start_t, end=end_t, conf=1.0)) - except Exception: - pass - - if isinstance(output_asr_text, str) and output_asr_text: - self.output_asr_text_tokens.append(output_asr_text) - self.output_asr_text_str += output_asr_text - - @property - def speech_frames(self) -> List[torch.Tensor]: - """Backward-compatible view for code expecting a list of chunks.""" - return [self.audio_buffer] - - def get_output_text(self) -> str: - """Return accumulated text as a single string.""" - return self.output_text_str - - def get_output_asr_text(self) -> str: - """Return accumulated ASR text as a single string.""" - return self.output_asr_text_str - - def get_output_words(self) -> List[Word]: - """Return accumulated words with timings.""" - return list(self.output_words) - - def save_token_tensors(self, gen_text: torch.Tensor, gen_asr_text: torch.Tensor, total_frames: int, - gen_function_text: torch.Tensor = None) -> None: - """Snapshot the full token-ID tensors from the context before it is destroyed.""" - with torch.no_grad(): - self.final_gen_text = gen_text[:, :total_frames].clone().cpu() - self.final_gen_asr_text = gen_asr_text[:, :total_frames].clone().cpu() - self.final_total_frames = total_frames - self.final_gen_function_text = ( - gen_function_text[:, :total_frames].clone().cpu() - if gen_function_text is not None else None - ) - - def get_token_tensors(self) -> Optional[tuple]: - """Return (gen_text, gen_asr_text, total_frames[, gen_function_text]) or None if not saved.""" - if self.final_gen_text is None: - return None - return self.final_gen_text, self.final_gen_asr_text, self.final_total_frames, getattr(self, 'final_gen_function_text', None) - - def cleanup_after_response(self) -> None: - """Clear transient audio; keep token workspaces allocated.""" - with torch.no_grad(): - self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) + """ + State for streaming speech generation. 
+ + This dataclass stores streaming tensors and counters used during + incremental generation. It keeps initialization metadata so it can be + reset to a clean state on demand. + """ + # Initialization metadata (required) + device: torch.device + dtype: torch.dtype + max_len: int + num_audio_codebooks: int + output_sample_rate: int + + # Runtime tensors (initialized in __post_init__) + audio_buffer: torch.Tensor = field(init=False) + + # Accumulated text output + output_text_str: str = "" + output_text_tokens: List[str] = field(default_factory=list) + # Accumulated ASR text output + output_asr_text_str: str = "" + output_asr_text_tokens: List[str] = field(default_factory=list) + # Accumulated words with timings + output_words: List[Word] = field(default_factory=list) + # Final token tensors saved from the context before it is destroyed. + # Used for post-hoc tokens_to_str / tokens_to_str_raw conversion. + final_gen_text: Optional[torch.Tensor] = None + final_gen_asr_text: Optional[torch.Tensor] = None + final_total_frames: int = 0 + + def __post_init__(self) -> None: + """Allocate tensors lazily based on provided metadata.""" + with torch.no_grad(): + # Empty 2D buffer: shape (1, 0). Will be appended over time. 
+ self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) + + def reset(self) -> None: + """Reset all tensors and counters to their initial state.""" + with torch.no_grad(): + self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) + self.output_text_str = "" + self.output_text_tokens.clear() + self.output_asr_text_str = "" + self.output_asr_text_tokens.clear() + self.output_words.clear() + self.final_gen_text = None + self.final_gen_asr_text = None + self.final_total_frames = 0 + + def update_state(self, processed_frames: torch.Tensor, output_text_tokens: Any = None, output_text: str | None = None, output_asr_text: str | None = None) -> None: + """Append new audio to the right of the buffer; token/text args are accepted for API compatibility.""" + if processed_frames is None: + return + if not isinstance(processed_frames, torch.Tensor): + raise TypeError("processed_frames must be a torch.Tensor") + with torch.no_grad(): + # Ensure 2D [1, T] layout by flattening extra dims + append_tensor = processed_frames + if append_tensor.dim() > 1: + append_tensor = append_tensor.reshape(1, -1) + elif append_tensor.dim() == 1: + append_tensor = append_tensor.unsqueeze(0) + prior_samples = int(self.audio_buffer.shape[-1]) + appended_samples = int(append_tensor.shape[-1]) + self.audio_buffer = torch.cat([self.audio_buffer, append_tensor.to(self.device, dtype=self.dtype)], dim=-1) + + # Accumulate text output if provided and create a Word with naive timing + if isinstance(output_text, str) and output_text: + self.output_text_tokens.append(output_text) + # Directly concatenate - spacing is already handled by tokenizer (Ġ → space) + self.output_text_str += output_text + try: + if appended_samples > 0 and self.output_sample_rate > 0: + start_t = float(prior_samples) / float(self.output_sample_rate) + end_t = float(prior_samples + appended_samples) / float(self.output_sample_rate) + self.output_words.append(Word(text=output_text, 
start=start_t, end=end_t, conf=1.0)) + except Exception: + pass + + if isinstance(output_asr_text, str) and output_asr_text: + self.output_asr_text_tokens.append(output_asr_text) + self.output_asr_text_str += output_asr_text + + @property + def speech_frames(self) -> List[torch.Tensor]: + """Backward-compatible view for code expecting a list of chunks.""" + return [self.audio_buffer] + + def get_output_text(self) -> str: + """Return accumulated text as a single string.""" + return self.output_text_str + + def get_output_asr_text(self) -> str: + """Return accumulated ASR text as a single string.""" + return self.output_asr_text_str + + def get_output_words(self) -> List[Word]: + """Return accumulated words with timings.""" + return list(self.output_words) + + def save_token_tensors(self, gen_text: torch.Tensor, gen_asr_text: torch.Tensor, total_frames: int, + gen_function_text: torch.Tensor = None) -> None: + """Snapshot the full token-ID tensors from the context before it is destroyed.""" + with torch.no_grad(): + self.final_gen_text = gen_text[:, :total_frames].clone().cpu() + self.final_gen_asr_text = gen_asr_text[:, :total_frames].clone().cpu() + self.final_total_frames = total_frames + self.final_gen_function_text = ( + gen_function_text[:, :total_frames].clone().cpu() + if gen_function_text is not None else None + ) + + def get_token_tensors(self) -> Optional[tuple]: + """Return (gen_text, gen_asr_text, total_frames[, gen_function_text]) or None if not saved.""" + if self.final_gen_text is None: + return None + return self.final_gen_text, self.final_gen_asr_text, self.final_total_frames, getattr(self, 'final_gen_function_text', None) + + def cleanup_after_response(self) -> None: + """Clear transient audio; keep token workspaces allocated.""" + with torch.no_grad(): + self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) From a485dccca5a8be579a5ccc02b55750a8b60b361b Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 20 Mar 2026 
18:08:39 +0000 Subject: [PATCH 12/40] remove hardcoded env var, simple tidy: remove dead code atc Signed-off-by: Elena Rastorgueva --- .../inference/model_wrappers/model_factory.py | 12 +- .../nemotron_voicechat_inference_wrapper.py | 130 +++--------------- .../pipelines/streaming_s2s_pipeline.py | 4 +- .../inference/utils/pipeline_utils.py | 25 ++++ 4 files changed, 52 insertions(+), 119 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index cdf9d2ab613a..770005bbf2eb 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -38,7 +38,7 @@ import os import torch from transformers import DynamicCache -from dataclasses import dataclass +from dataclasses import dataclass, fields from nemo.utils import logging @@ -530,7 +530,7 @@ async def _process_inputs_to_outputs( return ans - def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'VLLMModel': + def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'VllmLLMModel': """ Move model to specified device or convert to specified dtype. 
@@ -543,7 +543,7 @@ def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'VLLMModel': pass return self - def eval(self) -> 'VLLMModel': + def eval(self) -> 'VllmLLMModel': """Set model to evaluation mode (vLLM is always in eval mode).""" return self @@ -961,12 +961,12 @@ def _extract_special_token_ids_from_nemo(model) -> Set[int]: return special_ids - def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'NativeModelInterface': + def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'NativeModel': """Move underlying model to device or convert dtype.""" self.model = self.model.to(device_or_dtype) return self - def eval(self) -> 'NativeModelInterface': + def eval(self) -> 'NativeModel': """Set underlying model to eval mode.""" self.model.eval() return self @@ -986,7 +986,7 @@ def __getattr__(self, name: str): Delegate attribute access to the underlying model. This allows transparent access to model attributes like - perception, tokenizer, etc.ß + perception, tokenizer, etc. """ # Avoid infinite recursion for special attributes if name in ('model', '__dict__', '__class__'): diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 4ad77f8be83c..02e8c15cc6a7 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -12,34 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch -from omegaconf import OmegaConf, DictConfig -import time -import re +import gc import os -import sys -import torchaudio -import functools -from dataclasses import dataclass +import time +import types from typing import Optional, Tuple -from nemo.utils import logging -import gc -import types +import torch +import torchaudio +from omegaconf import OmegaConf, DictConfig +from nemo.utils import logging from transformers import DynamicCache - -# Set environment variables (use existing env vars if set, otherwise use defaults) -_default_cache = "/tmp/cache" -os.environ.setdefault("HF_HOME", _default_cache) -os.environ.setdefault("TORCH_HOME", _default_cache) -os.environ.setdefault("NEMO_CACHE_DIR", _default_cache) -os.environ.setdefault("NEMO_NLP_TMP", os.path.join(_default_cache, "nemo_nlp_tmp")) - from nemo.collections.speechlm2.models.nemotron_voicechat import NemotronVoiceChat - -from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache @@ -48,40 +34,6 @@ PerceptionCacheState, PerceptionCacheManager, ) -from nemo.collections.speechlm2.inference.utils.pipeline_utils import clean_pred_text - - -def tokens_to_str_raw(tokens: torch.Tensor, lengths: torch.Tensor, tokenizer, pad_id: int) -> list: - """ - Convert token IDs to text strings, preserving ALL special tokens including (pad token). - - Unlike tokens_to_str, this function uses ids_to_tokens which preserves special tokens, - and does NOT filter out any tokens (including pad tokens like ). 
- - Args: - tokens: Token IDs tensor (B, T) - lengths: Length of each sequence (B,) - tokenizer: Tokenizer for decoding - pad_id: Pad token ID (not used for filtering in raw mode, kept for API compatibility) - - Returns: - List of decoded text strings with ALL special tokens preserved (including ) - """ - ans = [] - for hyp_ids, hyp_len in zip(tokens.cpu(), lengths.cpu()): - hyp_ids = hyp_ids[:hyp_len] - # Do NOT filter out any tokens - keep everything including pad tokens () - hyp_ids_list = hyp_ids.tolist() - - # Use ids_to_tokens which preserves special tokens like - toks = tokenizer.ids_to_tokens(hyp_ids_list) - - # Only replace 'Ġ' with space for proper word boundaries, keep all special tokens - toks = [tok.replace('Ġ', ' ') for tok in toks] - - ans.append("".join(toks)) - return ans - # --- Configuration --- @@ -121,14 +73,6 @@ def __init__(self, model_cfg: DictConfig): if not isinstance(model_cfg, DictConfig): model_cfg = OmegaConf.create(model_cfg) - - logging.info(f"pythonpath: {sys.path}") - - - logging.info(f"before setting - torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}") - logging.info(f"before setting - torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}") - logging.info(f"before setting - torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}") - torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True torch.set_float32_matmul_precision("medium") @@ -162,10 +106,6 @@ def __init__(self, model_cfg: DictConfig): "Inference will also be slower." 
) - logging.info(f"torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}") - logging.info(f"torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}") - logging.info(f"torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}") - self.model_cfg = model_cfg self.model_path = model_cfg.get("model_path") @@ -206,6 +146,7 @@ def __init__(self, model_cfg: DictConfig): logging.info(f"Decode audio: {self.decode_audio}") logging.info(f"Engine type: {model_cfg.get('engine_type', 'native')}") logging.info(f"Sampling - top_p: {model_cfg.get('top_p', 1.0)}, repetition_penalty: {model_cfg.get('repetition_penalty', 1.0)}, temperature: {model_cfg.get('temperature', 1.0)}") + logging.info(f"Float32 matmul precision: {torch.get_float32_matmul_precision()}") logging.info("=" * 70) # Cached TTS helpers populated during initialization/warmup @@ -214,7 +155,6 @@ def __init__(self, model_cfg: DictConfig): self.first_tts_code_input = None self.first_tts_past_key_values_input = None - self.model = None self.model_llm_interface = None self.tokenizer = None @@ -267,9 +207,6 @@ def __init__(self, model_cfg: DictConfig): logging.info("NemotronVoicechatInferenceWrapper initialized successfully.") - logging.info(f"{self.model.stt_model.perception.encoder._cfg = }") - logging.info(f"{self.model.stt_model.perception.encoder.streaming_cfg = }") - @staticmethod def _resolve_dtype(compute_dtype): if isinstance(compute_dtype, torch.dtype): @@ -330,14 +267,13 @@ def _samples_per_audio_output_frame(self): def _initialize_model(self): """Initialize the NemotronVoiceChat model from an HF checkpoint.""" logging.info("Initializing model structure...") - start_DuplexS2S_init = time.time() + start_model_init = time.time() self.model = NemotronVoiceChat.from_pretrained( self.model_path, local_files_only=True, ) - logging.info(f"Time taken to initialize NemotronVoiceChat: {time.time() - start_DuplexS2S_init} seconds") - logging.info(" Model structure 
initialized") + logging.info(f"NemotronVoiceChat initialized in {time.time() - start_model_init:.1f}s") if self.use_vllm_eartts: # Use object.__setattr__ to bypass PyTorch's module registration @@ -374,10 +310,9 @@ def _initialize_model(self): self.model.to(self.device) self.model.eval() - # Convert only the S2S components to the configured dtype, not the TTS model - logging.info(f"Converting S2S components to {self.dtype} (keeping TTS in float32)...") - if self.model.stt_model.llm is not None: - self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype) + # Convert some S2S components to the configured dtype + logging.info(f"Converting some S2S components to {self.dtype} (keeping perception & TTS in float32)...") + self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype) self.model.stt_model.lm_head = self.model.stt_model.lm_head.to(self.dtype) self.model.stt_model.embed_tokens = self.model.stt_model.embed_tokens.to(self.dtype) self.model.stt_model.asr_head = self.model.stt_model.asr_head.to(self.dtype) @@ -385,14 +320,6 @@ def _initialize_model(self): if self.model.stt_model.function_head is not None: self.model.stt_model.function_head = self.model.stt_model.function_head.to(self.dtype) logging.info("function_head converted to %s", self.dtype) - #self.model.stt_model.perception = self.model.stt_model.perception.to(self.dtype) - logging.info("S2S components converted, TTS kept in float32") - logging.info("new update, perception also is kept in float32") - - # commenting this out to avoid error when try vllm tts - # and anyway - when sticking to "native", saw no difference in output - # with and without this call - #self.model.on_train_epoch_start() # torch.compile for native TTS backbone use_tts_torch_compile = bool(self.model_cfg.get("use_tts_torch_compile", False)) @@ -414,29 +341,21 @@ def _initialize_model(self): self.tokenizer = self.model.stt_model.tokenizer - - # allow overrides/additions from the self.model_cfg of 
nemotron_voicechat_inference_wrapper, - # into the model cfg that is read from config.json of the model. - # Specifically, this is so that we can specify inference_pad_boost, ... etc. - for key in ( + # Allow overrides from wrapper config into the model config (e.g. logit boosts). + _BOOST_KEYS = ( "inference_pad_boost", "inference_bos_boost", "inference_eos_boost", "inference_user_pad_boost", "inference_user_bos_boost", "inference_user_eos_boost", - ): + ) + for key in _BOOST_KEYS: val = self.model_cfg.get(key, None) if val is not None: OmegaConf.update(self.model.stt_model.cfg, key, val) - - # Print inference boost values - logging.info(f"inference_eos_boost: {self.model.stt_model.cfg.get('inference_eos_boost', None)}") - logging.info(f"inference_bos_boost: {self.model.stt_model.cfg.get('inference_bos_boost', None)}") - logging.info(f"inference_pad_boost: {self.model.stt_model.cfg.get('inference_pad_boost', None)}") - logging.info(f"inference_user_pad_boost: {self.model.stt_model.cfg.get('inference_user_pad_boost', None)}") - logging.info(f"inference_user_bos_boost: {self.model.stt_model.cfg.get('inference_user_bos_boost', None)}") - logging.info(f"inference_user_eos_boost: {self.model.stt_model.cfg.get('inference_user_eos_boost', None)}") + boost_values = {k: self.model.stt_model.cfg.get(k, None) for k in _BOOST_KEYS} + logging.info(f"Inference logit boosts: {boost_values}") # Wrap model with appropriate interface (Native or vLLM) if self.use_vllm_llm: @@ -670,7 +589,6 @@ def _prepare_tts_initial_state(self): self.first_tts_code_input = code.detach().clone() self.first_tts_past_key_values_input = self._clone_cache(outputs.past_key_values) - logging.info("TTS warmup state prepared") def create_decode_state(self, max_len: int): @@ -1137,13 +1055,3 @@ def _maybe_apply_forced_turn_taking(self, t, gen_text, gen_asr): gen_text[batch_idx, t] = self.model.stt_model.text_eos_id logging.info(f"Forced turn-taking at frame {t}: inserted agent EOS (reason: user started 
speaking)") -def main(): - raise RuntimeError( - "This module cannot be called directly. " - "Use examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py instead." - ) - - -if __name__ == "__main__": - sys.exit(main()) - diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index b7cbb06a34cb..fec34de1c5e8 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -31,11 +31,11 @@ from nemo.collections.asr.inference.utils.progressbar import ProgressBar from nemo.collections.speechlm2.inference.pipelines.s2s_pipeline_interface import S2SPipelineInterface from nemo.collections.speechlm2.inference.streaming.state.s2s_state import S2SStreamingState -from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import NemotronVoicechatInferenceWrapper, tokens_to_str_raw +from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import NemotronVoicechatInferenceWrapper from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from nemo.collections.speechlm2.inference.streaming.state.s2s_context_manager import S2SContextManager from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions -from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput +from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput, tokens_to_str_raw from nemo.utils import logging diff --git a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py index 4ad5634172f2..d61e2998d5ee 100644 --- a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py +++ b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py @@ -20,6 +20,31 
@@ from nemo.collections.asr.inference.utils.text_segment import Word +def tokens_to_str_raw(tokens: torch.Tensor, lengths: torch.Tensor, tokenizer, pad_id: int) -> list: + """Convert token IDs to text strings, preserving ALL special tokens including pad tokens. + + Unlike ``tokens_to_str``, this function uses ``ids_to_tokens`` which preserves + special tokens and does NOT filter out any tokens (including pad tokens like + ````). + + Args: + tokens: Token IDs tensor (B, T). + lengths: Length of each sequence (B,). + tokenizer: Tokenizer for decoding. + pad_id: Pad token ID (kept for API compatibility with ``tokens_to_str``). + + Returns: + List of decoded text strings with ALL special tokens preserved. + """ + ans = [] + for hyp_ids, hyp_len in zip(tokens.cpu(), lengths.cpu()): + hyp_ids = hyp_ids[:hyp_len] + toks = tokenizer.ids_to_tokens(hyp_ids.tolist()) + toks = [tok.replace('Ġ', ' ') for tok in toks] + ans.append("".join(toks)) + return ans + + def clean_pred_text(text: str) -> str: """Clean prediction text by removing special markers, timestamps, punctuation, and lowercasing. 
From dd889877305a631f642daca1ae1ae93465cc2141 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 20 Mar 2026 18:44:40 +0000 Subject: [PATCH 13/40] always use codec cache => remove use_codec_cache flag and codec_token_history_size parameter Signed-off-by: Elena Rastorgueva --- .../conf/s2s_streaming.yaml | 3 - .../voicechat/1/infer_streaming.py | 2 - .../nemotron_voicechat_parity_harness.py | 4 - .../nemotron_voicechat_inference_wrapper.py | 99 ++++--------------- .../pipelines/streaming_s2s_pipeline.py | 1 - .../streaming/state/s2s_context_manager.py | 4 - 6 files changed, 18 insertions(+), 95 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml index 9f31b11bcd25..e7b0ed572320 100644 --- a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -58,11 +58,8 @@ s2s: # ======================== # Inference settings # ======================== - codec_token_history_size: 60 # Sliding-window buffer size; ignored when use_codec_cache is true use_perception_cache: true # Enable cache-aware streaming for perception encoder use_perception_cudagraph: true # Enable CUDA graph-accelerated perception encoder - use_codec_cache: true # Incremental codec decode to remove clicking sounds and wasted computation - # (when true, codec_token_history_size is unused) use_llm_cache: true # Use KV cache for the STT LLM (DynamicCache or HybridMambaAttentionDynamicCache) # TTS speedup flags (default to false; enable to speed up native inference) diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py index 7fdb449f0bec..be9c2e44c568 100644 --- 
a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py +++ b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py @@ -49,7 +49,6 @@ def _resolve_env_overrides(self, cfg): s2s.engine_type -> S2S_ENGINE_TYPE (default: native) s2s.system_prompt -> S2S_SYSTEM_PROMPT (default: none) s2s.tts_system_prompt -> S2S_TTS_SYSTEM_PROMPT (default: none) - s2s.use_codec_cache -> S2S_USE_CODEC_CACHE (default: true) streaming.chunk_size_in_secs -> S2S_CHUNK_SIZE_IN_SECS (default: 0.08) streaming.buffer_size_in_secs -> S2S_BUFFER_SIZE_IN_SECS (default: 5.6) """ @@ -62,7 +61,6 @@ def _resolve_env_overrides(self, cfg): "s2s.engine_type": ("S2S_ENGINE_TYPE", "native"), "s2s.system_prompt": ("S2S_SYSTEM_PROMPT", None), "s2s.tts_system_prompt": ("S2S_TTS_SYSTEM_PROMPT", None), - "s2s.use_codec_cache": ("S2S_USE_CODEC_CACHE", True), "streaming.chunk_size_in_secs": ("S2S_CHUNK_SIZE_IN_SECS", 0.08), "streaming.buffer_size_in_secs": ("S2S_BUFFER_SIZE_IN_SECS", 5.6), } diff --git a/examples/speechlm2/nemotron_voicechat_parity_harness.py b/examples/speechlm2/nemotron_voicechat_parity_harness.py index db7083b3fb53..0673cd7a881d 100644 --- a/examples/speechlm2/nemotron_voicechat_parity_harness.py +++ b/examples/speechlm2/nemotron_voicechat_parity_harness.py @@ -526,7 +526,6 @@ def run_parity_harness(args) -> dict[str, Any]: "s2s.use_perception_cache": args.use_perception_cache, "s2s.use_perception_cudagraph": args.use_perception_cudagraph, "s2s.use_llm_cache": args.use_llm_cache, - "s2s.use_codec_cache": args.use_codec_cache, "s2s.deterministic": args.deterministic, "s2s.top_p": args.top_p, "s2s.repetition_penalty": args.repetition_penalty, @@ -540,7 +539,6 @@ def run_parity_harness(args) -> dict[str, Any]: "s2s.use_perception_cache": False, "s2s.use_perception_cudagraph": False, "s2s.use_llm_cache": False, - "s2s.use_codec_cache": False, "s2s.deterministic": True, "s2s.top_p": 1.0, "s2s.repetition_penalty": 
1.0, @@ -644,7 +642,6 @@ def run_parity_harness(args) -> dict[str, Any]: "engine_type": inference_cfg.s2s.get("engine_type"), "use_perception_cache": bool(inference_cfg.s2s.get("use_perception_cache", False)), "use_llm_cache": bool(inference_cfg.s2s.get("use_llm_cache", False)), - "use_codec_cache": bool(inference_cfg.s2s.get("use_codec_cache", False)), "deterministic": bool(inference_cfg.s2s.get("deterministic", False)), "model_dtypes": _collect_model_dtypes(wrapper), "comparison": _compare_outputs(offline, incremental), @@ -715,7 +712,6 @@ def build_argparser() -> argparse.ArgumentParser: _bool_arg(parser, "--use_perception_cache", "Override perception cache usage.") _bool_arg(parser, "--use_perception_cudagraph", "Override perception CUDA-graph usage.") _bool_arg(parser, "--use_llm_cache", "Override LLM cache usage.") - _bool_arg(parser, "--use_codec_cache", "Override codec cache usage.") _bool_arg(parser, "--deterministic", "Override deterministic mode.") parser.add_argument("--top_p", type=float, default=None, help="Override top-p.") parser.add_argument("--repetition_penalty", type=float, default=None, help="Override repetition penalty.") diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 02e8c15cc6a7..338c44e8cd1f 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -47,14 +47,6 @@ TTS_SAMPLE_RATE = 22050 -# Default hyper-parameters that can be overridden via `model_cfg` -DEFAULT_BUFFER_SIZE_FRAMES = 71 -DEFAULT_NUM_FRAMES_PER_CHUNK = 1 -# Only used when use_codec_cache=False (sliding-window fallback). -# Ignored when the codec streaming cache is enabled. 
-DEFAULT_CODEC_TOKEN_HISTORY_SIZE = 600 - - class NemotronVoicechatInferenceWrapper: """ Inference wrapper for NemotronVoiceChat models. @@ -113,13 +105,6 @@ def __init__(self, model_cfg: DictConfig): raise ValueError("`model_cfg.model_path` must be provided.") self.decode_audio = bool(model_cfg.get("decode_audio", True)) - # Number of past codec tokens kept in the sliding-window decode buffer. - # Only used when use_codec_cache=False (the fallback path). When the - # codec cache is enabled, context is maintained incrementally inside - # CausalConv1dCache and this value is ignored. - self.codec_token_history_size = int( - model_cfg.get("codec_token_history_size", DEFAULT_CODEC_TOKEN_HISTORY_SIZE) - ) self.speaker_reference = model_cfg.get("speaker_reference") self.speaker_name = model_cfg.get("speaker_name", None) @@ -172,21 +157,6 @@ def __init__(self, model_cfg: DictConfig): self.repetition_penalty = float(model_cfg.get("repetition_penalty", 1.0)) self.temperature = float(model_cfg.get("temperature", 1.0)) - # Codec streaming cache: decode only new tokens each step using the - # codec's CausalConv1dCache, which maintains ConvNeXt and ISTFT state - # across calls for sample-continuous audio. When enabled, the - # codec_token_history_size parameter and audio_toks_buffer are unused. - # When disabled, falls back to the sliding-window decode that re-decodes - # codec_token_history_size tokens each step and extracts the tail. - self.use_codec_cache = bool(model_cfg.get("use_codec_cache", True)) - if self.use_codec_cache and self.decode_audio: - configured_history = model_cfg.get("codec_token_history_size", None) - if configured_history is not None: - logging.info( - f"use_codec_cache is enabled — codec_token_history_size ({configured_history}) " - f"will be ignored (context is maintained incrementally by the codec cache)." - ) - # LLM KV cache: when enabled, uses DynamicCache (standard) or # HybridMambaAttentionDynamicCache (Nemotron) for incremental decoding. 
# When disabled, falls back to full-history reprocessing each step. @@ -524,20 +494,11 @@ def _create_llm_cache(self): def _create_codec_state(self, max_len: int): if not self.decode_audio or not hasattr(self.model, "tts_model"): - return None, None, None - - audio_toks_buffer = None - codec_cache = None - if self.use_codec_cache: - codec_cache = CausalConv1dCache() - elif self.codec_token_history_size > 0: - silence_tokens = self.model.tts_model.codec_silence_tokens.detach().clone() - audio_toks_buffer = silence_tokens.view(1, 1, -1).expand( - 1, self.codec_token_history_size, -1 - ).contiguous().to(self.device) + return None, None + codec_cache = CausalConv1dCache() subword_mask = torch.ones((1, max_len), device=self.device, dtype=torch.bool) - return audio_toks_buffer, subword_mask, codec_cache + return subword_mask, codec_cache def _prepare_tts_initial_state(self): if not self.decode_audio: @@ -594,7 +555,7 @@ def _prepare_tts_initial_state(self): def create_decode_state(self, max_len: int): gen_text, gen_asr_text, gen_function_text = self._create_generation_workspace(max_len) llm_cache = self._create_llm_cache() - audio_toks_buffer, subword_mask, codec_cache = self._create_codec_state(max_len) + subword_mask, codec_cache = self._create_codec_state(max_len) perception_cache = None if self.use_perception_cache and self.perception_cache_mgr is not None: perception_cache = self.perception_cache_mgr.get_initial_state(batch_size=1) @@ -610,7 +571,6 @@ def create_decode_state(self, max_len: int): "gen_text": gen_text, "gen_asr_text": gen_asr_text, "gen_function_text": gen_function_text, - "audio_toks_buffer": audio_toks_buffer, "input_embeds_history": [], "dynamic_cache": llm_cache, "past_key_values": past_key_values, @@ -626,7 +586,6 @@ def infer_one_step(self, num_frames_per_chunk, frame_idx, gen_text, - audio_toks_buffer, input_embeds_history, dynamic_cache, past_key_values=None, @@ -856,12 +815,8 @@ def infer_one_step(self, logging.info(f"Time taken for 
tts_model: {time_tts_model:.3f}s") new_codes_for_decode.append(code.clone()) - # Update sliding-window buffer (only needed for fallback decode when codec_cache is off) - if audio_toks_buffer is not None: - audio_toks_buffer = torch.cat([audio_toks_buffer[:, 1:], code], dim=1) - # now that we've saved audio_toks_buffer for audio decoding purposes, - # we can potentially overwrite the audio token with silence tokens (for feeding to the audio token predictor) + # Potentially overwrite the audio token with silence tokens (for feeding to the audio token predictor) if self.model.cfg.get('inference_force_speech_silence_on_eos', None): silence_codes = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand(code.shape) code = torch.where( @@ -877,43 +832,26 @@ def infer_one_step(self, start_time_decode = time.time() with fp32_precision(), torch.no_grad(): - if codec_cache is not None and new_codes_for_decode: - # Incremental decode: feed only the num_frames_per_chunk new tokens - # to the codec. CausalConv1dCache maintains all necessary ConvNeXt - # and ISTFT overlap state from prior calls, so no history buffer - # is needed — this replaces the sliding-window approach entirely. 
- new_codes_tensor = torch.cat(new_codes_for_decode, dim=1) - if hasattr(self.model.tts_model, '_control_codes'): - from nemo.collections.speechlm2.models.duplex_ear_tts import replace_control_speech_codes - new_codes_tensor = replace_control_speech_codes( - new_codes_tensor, - self.model.tts_model._control_codes, - getattr(self.model.tts_model, 'codec_silence_tokens', None), - ) - new_code_len = torch.tensor( - [new_codes_tensor.shape[1]], dtype=torch.long, device=self.device - ) - decoded_audio_new, _ = self.model.tts_model.audio_codec.decode( - new_codes_tensor, new_code_len, cache=codec_cache, + new_codes_tensor = torch.cat(new_codes_for_decode, dim=1) + if hasattr(self.model.tts_model, '_control_codes'): + from nemo.collections.speechlm2.models.duplex_ear_tts import replace_control_speech_codes + new_codes_tensor = replace_control_speech_codes( + new_codes_tensor, + self.model.tts_model._control_codes, + getattr(self.model.tts_model, 'codec_silence_tokens', None), ) - logging.debug(f" Incremental decode: {new_codes_tensor.shape[1]} new tokens -> {decoded_audio_new.shape}") - else: - # Fallback: full-buffer sliding-window decode (original behavior) - len_audio_toks_buffer = torch.tensor( - [self.codec_token_history_size], dtype=torch.long, device=self.device - ) - decoded_audio, decoded_audio_len = self.model.tts_model.audio_codec.decode( - audio_toks_buffer, len_audio_toks_buffer, - ) - decoded_audio_new = decoded_audio[:, :, -samples_per_audio_output_frame * num_frames_per_chunk:] - logging.debug(f" Sliding-window decode: extracted {decoded_audio_new.shape} from {decoded_audio.shape}") + new_code_len = torch.tensor( + [new_codes_tensor.shape[1]], dtype=torch.long, device=self.device + ) + decoded_audio_new, _ = self.model.tts_model.audio_codec.decode( + new_codes_tensor, new_code_len, cache=codec_cache, + ) torch.cuda.synchronize() time_audio_codec = time.time() - start_time_decode logging.info(f"Time taken for audio_codec: {time_audio_codec:.3f}s") else: - 
audio_toks_buffer = None decoded_audio_new = None time_tts_model = 0 time_audio_codec = 0 @@ -952,7 +890,6 @@ def infer_one_step(self, result = { 'predicted_text_tokens': predicted_tokens, 'asr_predicted_text_tokens': asr_predicted_tokens, - 'audio_toks_buffer': audio_toks_buffer, 'decoded_audio_new': decoded_audio_new, 'predicted_text_strs': predicted_text_strs, 'asr_predicted_text_strs': asr_predicted_text_strs, diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index fec34de1c5e8..7d5ea6fce7b3 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -232,7 +232,6 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_p num_frames_per_chunk=self.num_frames_per_chunk, frame_idx=context.frame_idx, gen_text=context.gen_text, - audio_toks_buffer=context.audio_toks_buffer, input_embeds_history=context.input_embeds_history, dynamic_cache=context.dynamic_cache, past_key_values=context.past_key_values, diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py index 0cf1ce8633a7..ab4a49e0ce92 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -29,7 +29,6 @@ class StreamingDecodeState: gen_text: torch.Tensor gen_asr_text: torch.Tensor gen_function_text: Optional[torch.Tensor] - audio_toks_buffer: Optional[torch.Tensor] input_embeds_history: List[torch.Tensor] dynamic_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache past_key_values: Any @@ -76,7 +75,6 @@ def _create_context(self) -> StreamingDecodeState: gen_text=decode_state["gen_text"], gen_asr_text=decode_state["gen_asr_text"], 
gen_function_text=decode_state["gen_function_text"], - audio_toks_buffer=decode_state["audio_toks_buffer"], input_embeds_history=decode_state["input_embeds_history"], dynamic_cache=decode_state["dynamic_cache"], past_key_values=decode_state["past_key_values"], @@ -177,8 +175,6 @@ def update_context( if step_result.get("dynamic_cache") is not None: context.dynamic_cache = step_result["dynamic_cache"] - if "audio_toks_buffer" in step_result: - context.audio_toks_buffer = step_result["audio_toks_buffer"] if "input_embeds_history" in step_result: context.input_embeds_history = step_result["input_embeds_history"] if "past_key_values" in step_result: From adc42f715b2011b6aa636472d866f0a90f4a4b46 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 20 Mar 2026 18:52:24 +0000 Subject: [PATCH 14/40] remove newlines in logs Signed-off-by: Elena Rastorgueva --- .../nemotron_voicechat_inference_wrapper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 338c44e8cd1f..f4e31ba61fa1 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -265,7 +265,7 @@ def _initialize_model(self): # If using vLLM for LLM, delete native LLM BEFORE moving to device to save memory if self.use_vllm_llm: - logging.info("\nDeleting native LLM before GPU transfer (will use vLLM instead)...") + logging.info("Deleting native LLM before GPU transfer (will use vLLM instead)...") if hasattr(self.model.stt_model, 'llm') and self.model.stt_model.llm is not None: # Delete all submodules of LLM to free memory for name, child in list(self.model.stt_model.llm.named_children()): @@ -329,7 +329,7 @@ def _initialize_model(self): # Wrap model with 
appropriate interface (Native or vLLM) if self.use_vllm_llm: - logging.info("\nWrapping model with VllmLLMModel interface...") + logging.info("Wrapping model with VllmLLMModel interface...") if self.vllm_llm_config is None: raise ValueError("vllm_llm_config must be provided when engine_type contains 'vllm_llm'") @@ -378,7 +378,7 @@ def _initialize_model(self): logging.info("VllmLLMModel interface created") else: - logging.info("\nWrapping model with NativeModel interface...") + logging.info("Wrapping model with NativeModel interface...") self.model_llm_interface = create_model( model=self.model, engine_type="native", @@ -392,7 +392,7 @@ def _initialize_model(self): if hasattr(self.model, 'tts_model'): self.target_fps = self.model.tts_model.target_fps self.target_sample_rate = self.model.tts_model.target_sample_rate - logging.info(f"\nTTS model initialized: target_fps={self.target_fps}, sample_rate={self.target_sample_rate}") + logging.info(f"TTS model initialized: target_fps={self.target_fps}, sample_rate={self.target_sample_rate}") if self.decode_audio: self._prepare_tts_initial_state() else: From 8babc04870d3bcfd015cf377e2cf07c9a7e4f90a Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Mon, 23 Mar 2026 22:46:50 +0000 Subject: [PATCH 15/40] further tidying: pass StreamingDecodeState directly, return InferenceStepResult etc Signed-off-by: Elena Rastorgueva --- .../conf/s2s_streaming.yaml | 24 +- .../s2s_streaming_infer.py | 195 +-------- .../voicechat/1/infer_streaming.py | 6 +- .../speechlm2/inference/__init__.py | 9 + .../inference/model_wrappers/decode_state.py | 80 ++++ .../nemotron_voicechat_inference_wrapper.py | 369 +++++++++--------- .../pipelines/streaming_s2s_pipeline.py | 57 +-- .../streaming/state/s2s_context_manager.py | 90 +---- .../inference/streaming/state/s2s_state.py | 184 ++++----- .../speechlm2/inference/utils/audio_data.py | 183 +++++++++ 10 files changed, 595 insertions(+), 602 deletions(-) create mode 100644 
nemo/collections/speechlm2/inference/model_wrappers/decode_state.py create mode 100644 nemo/collections/speechlm2/inference/utils/audio_data.py diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml index e7b0ed572320..ac50445ef97c 100644 --- a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -66,12 +66,6 @@ s2s: use_tts_torch_compile: false # Compile TTS backbone with torch.compile (mode='default') use_tts_subword_cache: false # Cache CharAwareSubwordEncoder embeddings (skip backbone for repeated tokens) - # Deterministic inference (native engine only). Ensures identical results across - # runs by disabling FlashAttention and forcing deterministic CUDA algorithms. - # Trade-offs: slower inference, might be worse results than non-deterministic mode, since - # non-deterministic mode was used in training. - deterministic: false - # sampling parameters. if set all to 1.0, it will be greedy decoding. top_p: 0.5 repetition_penalty: 1.1 @@ -89,12 +83,18 @@ s2s: tts_system_prompt: null # TTS system prompt - conditions TTS generation style # Requires a checkpoint trained with individual TTS prompts - - -# ======================== -# Pipeline settings -# ======================== -matmul_precision: medium # Matrix multiplication precision: highest, high, medium + # ======================== + # Precision & determinism + # ======================== + # These defaults match what was used during model training. + matmul_precision: medium # Matrix multiplication precision: highest, high, medium + allow_tf32: true # Allow TF32 for cuDNN and CUDA matmul (Ampere+ GPUs). + # Set to false for stricter float32 precision. + # Deterministic inference (native engine only). Ensures identical results across + # runs by disabling FlashAttention and forcing deterministic CUDA algorithms. 
+ # Trade-offs: slower inference, might produce worse results than non-deterministic mode, + # since non-deterministic mode was used in training. + deterministic: false streaming: input_sample_rate: 16000 # Audio sample rate in Hz diff --git a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py index 693e3ac6c5ad..b521a5fc318f 100644 --- a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -19,193 +19,27 @@ python s2s_streaming_infer.py \ audio_file=/path/to/audio_or_directory \ s2s.model_path=/path/to/eartts_ckpt \ - s2s.llm_checkpoint_path=/path/to/llm_ckpt \ s2s.speaker_reference=/path/to/speaker.wav \ streaming.chunk_size_in_secs=0.08 \ streaming.buffer_size_in_secs=5.6 """ -import json -import os -import re from time import time -from typing import List, Optional import hydra -import soundfile as sf +import torch from jiwer import wer as compute_wer -from nemo.collections.common.parts.preprocessing.manifest import get_full_path -from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder -from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions -from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput, clean_pred_text -from nemo.utils import logging from omegaconf import DictConfig -import torch - - -def prepare_audio_data( - audio_file: str, - default_system_prompt: str | None = None, - sort_by_duration: bool = True, -) -> tuple[List[str], List[S2SRequestOptions], List[str | None]]: - """ - Get audio filepaths and per-stream options from a folder, single file, or manifest. 
- - When the input is a JSON manifest, each line may contain: - {"audio_filepath": "clip.wav", "text": "...", "system_prompt": "..."} - If ``system_prompt`` is absent on a line, *default_system_prompt* is used. - - Returns: - (filepaths, options, ground_truths) -- parallel lists of audio paths, - per-stream options, and ground-truth texts (None when unavailable). - """ - audio_file = audio_file.strip() - if not os.path.isabs(audio_file): - audio_file = os.path.abspath(audio_file) - - options: List[S2SRequestOptions] = [] - ground_truths: List[str | None] = [] - - if os.path.isdir(audio_file): - filepaths = [os.path.join(audio_file, x) for x in os.listdir(audio_file) if x.endswith(".wav")] - options = [S2SRequestOptions(system_prompt=default_system_prompt) for _ in filepaths] - ground_truths = [None] * len(filepaths) - elif audio_file.endswith(".wav"): - filepaths = [audio_file] - options = [S2SRequestOptions(system_prompt=default_system_prompt)] - ground_truths = [None] - elif audio_file.endswith((".json", ".jsonl")): - samples = [] - with open(audio_file, 'r') as f: - for line in f.readlines(): - if line.strip(): - samples.append(json.loads(line)) - filepaths = [get_full_path(entry["audio_filepath"], audio_file) for entry in samples] - options = [ - S2SRequestOptions( - system_prompt=entry.get("system_prompt", default_system_prompt), - ) - for entry in samples - ] - ground_truths = [entry.get("text", None) for entry in samples] - else: - raise ValueError(f"audio_file `{audio_file}` needs to be a folder, audio file, or manifest file") - - if sort_by_duration: - durations = [sf.SoundFile(fp).frames for fp in filepaths] - order = sorted(range(len(filepaths)), key=lambda i: durations[i]) - filepaths = [filepaths[i] for i in order] - options = [options[i] for i in order] - ground_truths = [ground_truths[i] for i in order] - - return filepaths, options, ground_truths - - -def calculate_duration(audio_filepaths: List[str]) -> float: - """Calculate the total duration of 
the audio files in seconds.""" - total_dur = 0 - for audio_filepath in audio_filepaths: - sound = sf.SoundFile(audio_filepath) - total_dur += sound.frames / sound.samplerate - return total_dur - -def calculate_padded_duration( - audio_filepaths: List[str], - pad_audio_to_sec: float | None = None, - pad_silence_ratio: float | None = None, - pad_audio_by_sec: float | None = None, -) -> float: - """Calculate total duration including padding for RTFX reporting.""" - total = 0.0 - for fp in audio_filepaths: - sound = sf.SoundFile(fp) - orig = sound.frames / sound.samplerate - if pad_audio_to_sec is not None: - total += max(orig, pad_audio_to_sec) - elif pad_silence_ratio is not None: - total += orig * (1 + pad_silence_ratio) - elif pad_audio_by_sec is not None: - total += orig + pad_audio_by_sec - else: - total += orig - return total - - -def dump_output( - audio_filepaths: List[str], - output: PipelineOutput, - output_dir: str, - options: List[S2SRequestOptions], - ground_truths: List[str | None], -) -> None: - """ - Dump inference results to output_processed.json and output_raw.json. - - output_processed.json uses the canonical S2S processed-output schema - (timestamps in pred_text via <|t|> / <$t$>). - - output_raw.json preserves all tokens including (pad tokens). - - CTM files are still written for per-word audio-sample-based timing. 
- - Args: - audio_filepaths: List of audio file paths - output: Pipeline output - output_dir: Directory for all output files - options: Per-stream request options (carries the system prompt) - ground_truths: Ground-truth texts (None when unavailable) - """ - output_processed_path = os.path.join(output_dir, "output_processed.json") - output_raw_path = os.path.join(output_dir, "output_raw.json") - output_ctm_dir = os.path.join(output_dir, "ctm") - - os.makedirs(output_ctm_dir, exist_ok=True) - - asr_texts_ts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths) - texts_ts = output.texts_with_timestamps or [""] * len(audio_filepaths) - raw_texts = output.raw_texts or [""] * len(audio_filepaths) - raw_asr_texts = output.raw_asr_texts or [""] * len(audio_filepaths) - - with open(output_processed_path, 'w') as f_proc, open(output_raw_path, 'w') as f_raw: - for audio_filepath, words, opts, gt, pred_text_ts, pred_src_text_ts, pred_text_raw, pred_src_text_raw in zip( - audio_filepaths, output.words, options, ground_truths, - texts_ts, asr_texts_ts, raw_texts, raw_asr_texts, - ): - stem = os.path.splitext(os.path.basename(audio_filepath))[0] - ctm_filepath = os.path.abspath(os.path.join(output_ctm_dir, f"{stem}.ctm")) - with open(ctm_filepath, 'w') as ctm_fout: - for word in words: - ctm_line = f"A {round(word.start, 2)} {round(word.duration, 2)} {word.text} {word.conf}" - ctm_fout.write(f"{stem} {ctm_line}\n") - - pred_audio_path = os.path.join(output_dir, "wav", f"{stem}.wav") - - record_processed = { - "id": stem, - "target_text": "", - "pred_audio": pred_audio_path, - "src_text": gt or "", - "pred_src_text": pred_src_text_ts or "", - "pred_text": pred_text_ts or "", - "system_prompt": opts.system_prompt or "", - } - json.dump(record_processed, f_proc, ensure_ascii=False) - f_proc.write('\n') - f_proc.flush() - - record_raw = { - "id": stem, - "target_text": "", - "pred_audio": pred_audio_path, - "src_text": gt or "", - "pred_src_text": pred_src_text_raw or 
"", - "pred_text": pred_text_raw or "", - "system_prompt": opts.system_prompt or "", - } - json.dump(record_raw, f_raw, ensure_ascii=False) - f_raw.write('\n') - f_raw.flush() +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.utils.audio_data import ( + calculate_duration, + calculate_padded_duration, + dump_output, + prepare_audio_data, +) +from nemo.collections.speechlm2.inference.utils.pipeline_utils import clean_pred_text +from nemo.utils import logging @hydra.main(config_path="./conf", config_name="s2s_streaming", version_base=None) @@ -216,11 +50,6 @@ def main(cfg: DictConfig): ) logging.info(f"Found {len(audio_filepaths)} audio files to generate") - # Set matmul precision - matmul_precision = cfg.get("matmul_precision", "high") - torch.set_float32_matmul_precision(matmul_precision) - logging.info(f"Using matmul precision: {matmul_precision}") - pipeline = S2SPipelineBuilder.build_pipeline(cfg) start = time() @@ -239,8 +68,7 @@ def main(cfg: DictConfig): rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf') logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)") - # Compute WER when ground-truth texts are available. 
- # Use asr_texts_with_timestamps (from tokens_to_str with full post-processing) + # Compute WER when ground-truth texts are available asr_texts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths) wer_scores = [] for gt, asr_text in zip(ground_truths, asr_texts): @@ -257,7 +85,6 @@ def main(cfg: DictConfig): f"min={min(wer_scores):.4f}, max={max(wer_scores):.4f}" ) - # Dump the transcriptions and CTMs output_dir = cfg.get("output_dir", "./generated") dump_output(audio_filepaths, output, output_dir, options, ground_truths) logging.info(f"Transcriptions written to {output_dir}/output_processed.json and {output_dir}/output_raw.json") diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py index be9c2e44c568..7f9ee3aba227 100644 --- a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py +++ b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py @@ -262,8 +262,8 @@ def get_generations(self, frames: List[Frame]) -> List[Tuple]: state = self.pipeline.get_or_create_state(stream_id) audio = state.audio_buffer - full_text = state.get_output_text() - full_asr_text = state.get_output_asr_text() + full_text = state.output_text_str + full_asr_text = state.output_asr_text_str if stream_id not in self.text_positions: self.text_positions[stream_id] = 0 @@ -279,7 +279,7 @@ def get_generations(self, frames: List[Frame]) -> List[Tuple]: generations.append((audio, incremental_text, incremental_asr_text)) - state.cleanup_after_response() + state.clear_audio_buffer() if frame.is_last: self.pipeline.delete_state(stream_id) diff --git a/nemo/collections/speechlm2/inference/__init__.py b/nemo/collections/speechlm2/inference/__init__.py index 9e3fb699d9f6..03abea1da74d 100644 --- a/nemo/collections/speechlm2/inference/__init__.py +++ 
b/nemo/collections/speechlm2/inference/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.model_wrappers.decode_state import ( + InferenceStepResult, + StreamingDecodeState, +) +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions +from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput diff --git a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py new file mode 100644 index 000000000000..9f34a35a55b2 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State and result types for streaming S2S inference. + +These dataclasses define the *model wrapper's* interface contract: + +* :class:`StreamingDecodeState` — mutable per-stream decode state + (KV caches, token workspaces, perception/codec caches). 
Created by + the wrapper, mutated in-place by ``infer_one_step``, and held between + steps by the pipeline's context manager. + +* :class:`InferenceStepResult` — immutable per-step outputs returned + by ``infer_one_step`` (predicted tokens, text strings, audio). + +Defined here (in ``model_wrappers/``) because the wrapper is the +component that knows what state it needs. The context manager and +pipeline import from here. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, List, Optional, TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import PerceptionCacheState + + +@dataclass +class StreamingDecodeState: + """Per-stream model-level decode state for streaming S2S inference. + + Holds KV caches, token workspaces, perception cache, and codec state + that persist across inference steps within a single stream. + """ + + frame_idx: int + gen_text: torch.Tensor + gen_asr_text: torch.Tensor + gen_function_text: Optional[torch.Tensor] + input_embeds_history: List[torch.Tensor] + llm_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache + tts_past_key_values: Any + tts_code: Optional[torch.Tensor] + subword_mask: Optional[torch.Tensor] + perception_cache: Optional["PerceptionCacheState"] = None + tts_codec_cache: Any = None + llm_cache_position_offset: int = 0 + + +@dataclass +class InferenceStepResult: + """Output from a single ``infer_one_step`` call. + + State mutations (caches, token workspaces, frame_idx) are applied + in-place on :class:`StreamingDecodeState`. This dataclass carries + only the per-step *outputs* needed by the pipeline. 
+ """ + + predicted_text_tokens: torch.Tensor + asr_predicted_text_tokens: torch.Tensor + predicted_text_strs: list[str] + asr_predicted_text_strs: list[str] + decoded_audio: Optional[torch.Tensor] = None + function_predicted_text_tokens: Optional[torch.Tensor] = None + debug: Optional[dict] = None diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index f4e31ba61fa1..f04e438bb375 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import gc import os import time @@ -34,6 +35,10 @@ PerceptionCacheState, PerceptionCacheManager, ) +from nemo.collections.speechlm2.inference.model_wrappers.decode_state import ( + InferenceStepResult, + StreamingDecodeState, +) # --- Configuration --- @@ -65,9 +70,12 @@ def __init__(self, model_cfg: DictConfig): if not isinstance(model_cfg, DictConfig): model_cfg = OmegaConf.create(model_cfg) - torch.backends.cudnn.allow_tf32 = True - torch.backends.cuda.matmul.allow_tf32 = True - torch.set_float32_matmul_precision("medium") + # Precision settings (applied here so they take effect before model loading) + allow_tf32 = bool(model_cfg.get("allow_tf32", True)) + torch.backends.cudnn.allow_tf32 = allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + matmul_precision = str(model_cfg.get("matmul_precision", "medium")) + torch.set_float32_matmul_precision(matmul_precision) self._deterministic = bool(model_cfg.get("deterministic", False)) if self._deterministic: @@ -78,26 +86,13 @@ def __init__(self, model_cfg: DictConfig): "CUDA kernels (PagedAttention, FlashAttention) that do not support deterministic mode. 
" f"Got engine_type='{engine_type}'. Use engine_type='native' for deterministic inference." ) - - # Required by torch.use_deterministic_algorithms for cuBLAS reproducibility os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - torch.manual_seed(0) torch.cuda.manual_seed_all(0) torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.use_deterministic_algorithms(True, warn_only=False) - logging.info("Deterministic mode ENABLED") - logging.info(f" CUBLAS_WORKSPACE_CONFIG={os.environ.get('CUBLAS_WORKSPACE_CONFIG')}") - logging.info(f" flash_sdp enabled: {torch.backends.cuda.flash_sdp_enabled()}") - logging.info(f" mem_efficient_sdp enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}") - logging.info( - " NOTE: deterministic mode uses different CUDA kernels (e.g. math SDPA instead of " - "FlashAttention), so results may differ slightly from non-deterministic mode. " - "Inference will also be slower." - ) - self.model_cfg = model_cfg self.model_path = model_cfg.get("model_path") @@ -131,7 +126,8 @@ def __init__(self, model_cfg: DictConfig): logging.info(f"Decode audio: {self.decode_audio}") logging.info(f"Engine type: {model_cfg.get('engine_type', 'native')}") logging.info(f"Sampling - top_p: {model_cfg.get('top_p', 1.0)}, repetition_penalty: {model_cfg.get('repetition_penalty', 1.0)}, temperature: {model_cfg.get('temperature', 1.0)}") - logging.info(f"Float32 matmul precision: {torch.get_float32_matmul_precision()}") + logging.info(f"Precision (configured): matmul_precision={matmul_precision}, allow_tf32={allow_tf32}, deterministic={self._deterministic}") + logging.info(f"Precision (effective): float32_matmul_precision={torch.get_float32_matmul_precision()}, cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}, cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}") logging.info("=" * 70) # Cached TTS helpers populated during initialization/warmup @@ -466,7 +462,6 @@ def _clone_cache(self, cache): if 
isinstance(cache, dict): return {k: self._clone_cache(v) for k, v in cache.items()} if hasattr(cache, '__dict__'): - import copy return copy.deepcopy(cache) return cache @@ -475,7 +470,7 @@ def _build_prompt_token_ids(self, system_prompt: str | None) -> list[int]: return [] return [self.tokenizer.bos_id] + self.tokenizer.text_to_ids(system_prompt) + [self.tokenizer.eos_id] - def _create_generation_workspace(self, max_len: int): + def _init_token_buffers(self, max_len: int): stt_model = self.model.stt_model gen_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) gen_asr_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) @@ -552,128 +547,99 @@ def _prepare_tts_initial_state(self): logging.info("TTS warmup state prepared") - def create_decode_state(self, max_len: int): - gen_text, gen_asr_text, gen_function_text = self._create_generation_workspace(max_len) + def create_decode_state(self, max_len: int) -> StreamingDecodeState: + gen_text, gen_asr_text, gen_function_text = self._init_token_buffers(max_len) llm_cache = self._create_llm_cache() - subword_mask, codec_cache = self._create_codec_state(max_len) + subword_mask, tts_codec_cache = self._create_codec_state(max_len) perception_cache = None if self.use_perception_cache and self.perception_cache_mgr is not None: perception_cache = self.perception_cache_mgr.get_initial_state(batch_size=1) - past_key_values = None - code = None + tts_past_key_values = None + tts_code = None if self.decode_audio and self.first_tts_code_input is not None: - past_key_values = self._clone_cache(self.first_tts_past_key_values_input) - code = self.first_tts_code_input.detach().clone() - - return { - "frame_idx": 0, - "gen_text": gen_text, - "gen_asr_text": gen_asr_text, - "gen_function_text": gen_function_text, - "input_embeds_history": [], - "dynamic_cache": llm_cache, - "past_key_values": past_key_values, - "code": code, - "subword_mask": subword_mask, - 
"perception_cache": perception_cache, - "codec_cache": codec_cache, - "cache_position_offset": 0, - } - - def infer_one_step(self, - audio_input, - num_frames_per_chunk, - frame_idx, - gen_text, - input_embeds_history, - dynamic_cache, - past_key_values=None, - code=None, - subword_mask=None, - gen_asr_text=None, - gen_function_text=None, - request_id: Optional[str] = None, - perception_cache: Optional[PerceptionCacheState] = None, - has_prompt: bool = False, - codec_cache=None, - cache_position_offset: int = 0, - return_debug: bool = False): - - # Set up effective request ID for vLLM streaming + tts_past_key_values = self._clone_cache(self.first_tts_past_key_values_input) + tts_code = self.first_tts_code_input.detach().clone() + + return StreamingDecodeState( + frame_idx=0, + gen_text=gen_text, + gen_asr_text=gen_asr_text, + gen_function_text=gen_function_text, + input_embeds_history=[], + llm_cache=llm_cache, + tts_past_key_values=tts_past_key_values, + tts_code=tts_code, + subword_mask=subword_mask, + perception_cache=perception_cache, + tts_codec_cache=tts_codec_cache, + llm_cache_position_offset=0, + ) + + def infer_one_step( + self, + audio_input: torch.Tensor, + num_frames_per_chunk: int, + state: StreamingDecodeState, + *, + request_id: Optional[str] = None, + has_prompt: bool = False, + return_debug: bool = False, + ) -> InferenceStepResult: + """Run one streaming inference step: perception -> LLM -> TTS -> audio decode. + + All mutable decode state (caches, gen_text, gen_asr_text, code, etc.) is + updated **in-place** on *state*. The returned :class:`InferenceStepResult` + carries only per-step outputs needed by the pipeline. 
+ """ effective_request_id = request_id or self.request_id + frame_idx = state.frame_idx start_time_one_step = time.time() - use_cache = dynamic_cache is not None - batch_size = gen_text.shape[0] + use_cache = state.llm_cache is not None + batch_size = state.gen_text.shape[0] - predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) - asr_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) - function_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=gen_text.dtype, device=gen_text.device) + predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) + asr_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) + function_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) debug_text_logits = [] debug_asr_logits = [] debug_input_embeds = [] selected_frame_indices = [] - # Do "perception" step outside the for-loop - start_perception = time.time() - - if self.use_perception_cache and perception_cache is not None and perception_cache.is_initialized(): - # Cache-aware perception - source_encoded, perception_cache = self.perception_cache_mgr.step( - audio_input=audio_input, - frame_idx=frame_idx, - num_frames_per_chunk=num_frames_per_chunk, - perception_cache=perception_cache, - ) - else: - # Standard perception (full buffer processing) - buffer_len = torch.tensor([audio_input.shape[1]], dtype=torch.long, device=self.device) - source_encoded, _, _ = self.model.stt_model.perception( - input_signal=audio_input, - input_signal_length=buffer_len, - return_encoder_emb=True, - ) - - torch.cuda.synchronize() - time_perception = time.time() - start_perception - logging.info(f"Time taken for perception: {time_perception:.3f}s") - 
source_encoded = source_encoded.to(self.dtype) + # --- Stage 1: Perception --- + source_encoded, state.perception_cache = self._run_perception( + audio_input, frame_idx, num_frames_per_chunk, state.perception_cache, + ) total_encoded_frames = source_encoded.shape[1] - # Determine embedding position based on whether we're using cache - if self.use_perception_cache and perception_cache is not None and perception_cache.is_initialized(): - # With cache: we get exactly num_frames_per_chunk output frames - # Use all of them directly - embedding_position = 0 - newest_frame_index = total_encoded_frames - 1 + if self.use_perception_cache and state.perception_cache is not None and state.perception_cache.is_initialized(): + # With cache: we get exactly num_frames_per_chunk output frames — use all directly base_frame_index = 0 else: # Without cache: Use the second-to-last encoded frame (-2) as the "newest" frame embedding. - # This is because the model's expects the chunk sizes to be size 10ms, 80ms, 80ms, 80ms, ...., + # This is because the model expects the chunk sizes to be size 10ms, 80ms, 80ms, 80ms, ...., # but we pass in always 80ms, 80ms, 80ms.... # e.g. # (1) if we pass in just one 80ms chunk -> the model treats it as 10ms, then 70ms with 10ms silence padding at the end. # (2) if we pass 80ms, 80ms -> the model treats it as 10ms, 80ms, 70ms with 10ms silence padding at the end. # => we do not want to use the final embedding due to containing silence padding. We want to use the second-to-last embedding. 
- embedding_position = -2 - newest_frame_index = total_encoded_frames + embedding_position - base_frame_index = newest_frame_index - (num_frames_per_chunk - 1) - base_frame_index = max(base_frame_index, 0) + newest_frame_index = total_encoded_frames - 2 + base_frame_index = max(newest_frame_index - (num_frames_per_chunk - 1), 0) + # --- Stage 2: Per-frame generation loop --- new_input_embeds = [] new_codes_for_decode = [] for frame_offset in range(num_frames_per_chunk): current_frame_idx = frame_idx + frame_offset - current_frame_index = base_frame_index + frame_offset - current_frame_index = min(current_frame_index, total_encoded_frames - 1) + current_frame_index = min(base_frame_index + frame_offset, total_encoded_frames - 1) selected_frame_indices.append(current_frame_index) current_frame_embedding = source_encoded[:, current_frame_index:current_frame_index + 1, :] current_input_emb = current_frame_embedding.clone() current_input_emb *= self.model.stt_model.cfg.get("duplex_user_channel_weight", 1.0) - has_fc = gen_function_text is not None + has_fc = state.gen_function_text is not None if current_frame_idx == 0 and not has_prompt: # Only add BOS if there's no prompt (BOS is already in prompt's position 0) @@ -689,7 +655,6 @@ def infer_one_step(self, current_input_emb += self.model.stt_model.embed_tokens(fc_pad_token).to(dtype=self.dtype) elif current_frame_idx == 0 and has_prompt: # With prompt: first audio frame uses pad embedding (like offline_inference) - # gen_text[:, -1] from prompt positions is pad_id pad_id = self.model.stt_model.text_pad_id pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) pad_emb = self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) @@ -701,14 +666,14 @@ def infer_one_step(self, else: # t > 0: add embeddings from model's own predictions at t-1 last_token_emb = self.model.stt_model.embed_tokens( - gen_text[:, current_frame_idx - 1] + state.gen_text[:, current_frame_idx - 1] ) * 
self.model.stt_model.cfg.get("duplex_text_channel_weight", 1.0) last_asr_token_emb = self.model.stt_model.embed_asr_tokens( - gen_asr_text[:, current_frame_idx - 1] + state.gen_asr_text[:, current_frame_idx - 1] ) * self.model.stt_model.cfg.get("duplex_asr_text_weight", 1.0) current_input_emb += last_token_emb + last_asr_token_emb if has_fc: - last_fc_token_emb = self.model.stt_model.embed_tokens(gen_function_text[:, current_frame_idx - 1]) + last_fc_token_emb = self.model.stt_model.embed_tokens(state.gen_function_text[:, current_frame_idx - 1]) current_input_emb += last_fc_token_emb.to(dtype=self.dtype) if return_debug: debug_input_embeds.append(current_input_emb.detach().cpu()) @@ -717,33 +682,32 @@ def infer_one_step(self, if use_cache or self.use_vllm_llm: if self.use_vllm_llm: - # vLLM requires request_id ans = self.model_llm_interface( current_input_emb, request_id=effective_request_id, - generated_tokens=gen_text, + generated_tokens=state.gen_text, current_step=current_frame_idx ) else: cache_pos = torch.tensor( - [cache_position_offset + frame_offset], device=self.device + [state.llm_cache_position_offset + frame_offset], device=self.device ) ans = self.model_llm_interface( current_input_emb, - cache=dynamic_cache, + cache=state.llm_cache, cache_position=cache_pos, - generated_tokens=gen_text, + generated_tokens=state.gen_text, current_step=current_frame_idx, return_logits=return_debug, ) - dynamic_cache = ans["cache"] + state.llm_cache = ans["cache"] else: new_input_embeds.append(current_input_emb) - full_input_embeds = torch.cat(input_embeds_history + new_input_embeds, dim=1) + full_input_embeds = torch.cat(state.input_embeds_history + new_input_embeds, dim=1) ans = self.model_llm_interface( full_input_embeds, cache=None, - generated_tokens=gen_text, + generated_tokens=state.gen_text, current_step=current_frame_idx, return_logits=return_debug, ) @@ -759,35 +723,33 @@ def infer_one_step(self, if return_debug and "asr_logits" in ans and ans["asr_logits"] is 
not None: debug_asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) - gen_text[:, current_frame_idx] = predicted_token + state.gen_text[:, current_frame_idx] = predicted_token predicted_tokens[:, frame_offset] = predicted_token - gen_asr_text[:, current_frame_idx] = asr_predicted_token + state.gen_asr_text[:, current_frame_idx] = asr_predicted_token asr_predicted_tokens[:, frame_offset] = asr_predicted_token if "function_predicted_token" in ans: function_predicted_tokens[:, frame_offset] = ans["function_predicted_token"] - if gen_function_text is not None: - gen_function_text[:, current_frame_idx] = ans["function_predicted_token"] + if state.gen_function_text is not None: + state.gen_function_text[:, current_frame_idx] = ans["function_predicted_token"] # Apply forced turn taking based on ASR results - self._maybe_apply_forced_turn_taking(current_frame_idx, gen_text, gen_asr_text) + self._maybe_apply_forced_turn_taking(current_frame_idx, state.gen_text, state.gen_asr_text) # Update predicted_tokens with any changes made by forced turn taking - predicted_tokens[:, frame_offset] = gen_text[:, current_frame_idx] + predicted_tokens[:, frame_offset] = state.gen_text[:, current_frame_idx] if self.decode_audio: - current_subword_id = gen_text[:, current_frame_idx].unsqueeze(-1) + current_subword_id = state.gen_text[:, current_frame_idx].unsqueeze(-1) - # do one step inference on Duplex TTS model if current_frame_idx == 0: if self.first_context_subword_id is None: raise RuntimeError("first_context_subword_id is not initialized. 
Ensure TTS warmup ran successfully.") prev_subword_id = self.first_context_subword_id else: - prev_subword_id = gen_text[:, current_frame_idx-1].unsqueeze(-1) + prev_subword_id = state.gen_text[:, current_frame_idx-1].unsqueeze(-1) - # create subword_mask - current_subword_mask = subword_mask[:, current_frame_idx].unsqueeze(-1) + current_subword_mask = state.subword_mask[:, current_frame_idx].unsqueeze(-1) if self.generation_config is None: raise RuntimeError("generation_config is not initialized. Ensure TTS warmup ran successfully.") @@ -797,8 +759,8 @@ def infer_one_step(self, "current_subword_id": current_subword_id, "prev_subword_id": prev_subword_id, "current_subword_mask": current_subword_mask, - "prev_audio_tokens": code, - "past_key_values": past_key_values, + "prev_audio_tokens": state.tts_code, + "past_key_values": state.tts_past_key_values, "guidance_enabled": True, "generation_config": self.generation_config, "ignore_eos_flag_stop": True, @@ -806,29 +768,27 @@ def infer_one_step(self, if self.use_vllm_eartts: inputs["request_id"] = effective_request_id - code, past_key_values = self.model.tts_model.infer_codes_one_step( - **inputs - ) + state.tts_code, state.tts_past_key_values = self.model.tts_model.infer_codes_one_step(**inputs) torch.cuda.synchronize() time_tts_model = time.time() - start_tts_model logging.info(f"Time taken for tts_model: {time_tts_model:.3f}s") - new_codes_for_decode.append(code.clone()) + new_codes_for_decode.append(state.tts_code.clone()) # Potentially overwrite the audio token with silence tokens (for feeding to the audio token predictor) if self.model.cfg.get('inference_force_speech_silence_on_eos', None): - silence_codes = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand(code.shape) - code = torch.where( + silence_codes = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand(state.tts_code.shape) + state.tts_code = torch.where( current_subword_id.unsqueeze(-1) == self.model.tts_model.text_eos_id, 
silence_codes, - code, + state.tts_code, ) - # exit for-loop & do audio decoding non-autoregressively (if decode_audio is True) + # --- Stage 3: Audio decode --- + decoded_audio_new = None if self.decode_audio: - samples_per_audio_output_frame = self._samples_per_audio_output_frame() - logging.debug(f"\nDecoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") + logging.info(f"\nDecoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") start_time_decode = time.time() with fp32_precision(), torch.no_grad(): @@ -844,75 +804,102 @@ def infer_one_step(self, [new_codes_tensor.shape[1]], dtype=torch.long, device=self.device ) decoded_audio_new, _ = self.model.tts_model.audio_codec.decode( - new_codes_tensor, new_code_len, cache=codec_cache, + new_codes_tensor, new_code_len, cache=state.tts_codec_cache, ) torch.cuda.synchronize() time_audio_codec = time.time() - start_time_decode logging.info(f"Time taken for audio_codec: {time_audio_codec:.3f}s") - else: - decoded_audio_new = None - time_tts_model = 0 - time_audio_codec = 0 - - # Convert new text tokens to string via tokens_to_text (convert_tokens_to_string) - # so byte-level BPE is decoded properly (e.g. "é" → "é") and leading spaces - # from Ġ-prefixed tokens are preserved for correct concatenation of incremental - # chunks: " Musée" + " National" → " Musée National". - # NOTE: multi-byte UTF-8 characters whose BPE tokens span two frames will show - # as replacement chars (�) because each frame is decoded independently. A proper - # fix would require an incremental UTF-8 decoder that buffers incomplete trailing - # bytes across frames. 
- predicted_text_strs = [] - for predicted_tok_ids_b in predicted_tokens: - predicted_tok_ids_b = predicted_tok_ids_b.tolist() - predicted_toks_b = self.tokenizer.ids_to_tokens(predicted_tok_ids_b) - predicted_toks_b = [tok for tok in predicted_toks_b if tok != ''] - predicted_text_strs.append(self.tokenizer.tokens_to_text(predicted_toks_b)) - - # convert new ASR tokens to string - asr_predicted_text_strs = [] - for asr_predicted_tok_ids_b in asr_predicted_tokens: - asr_predicted_tok_ids_b = asr_predicted_tok_ids_b.tolist() - asr_predicted_toks_b = self.tokenizer.ids_to_tokens(asr_predicted_tok_ids_b) - asr_predicted_toks_b = [tok for tok in asr_predicted_toks_b if tok != ''] - asr_predicted_text_strs.append(self.tokenizer.tokens_to_text(asr_predicted_toks_b)) - - logging.info(f'frame {frame_idx}: USER\'s asr_predicted_text_strs: {asr_predicted_text_strs}') - logging.info(f'frame {frame_idx}: --------------------------------AGENT\'s predicted_text_strs: {predicted_text_strs}') + # --- Stage 4: Token -> string conversion --- + predicted_text_strs = self._tokens_to_strings(predicted_tokens) + asr_predicted_text_strs = self._tokens_to_strings(asr_predicted_tokens) - torch.cuda.synchronize() + logging.info(f'frame {frame_idx}: USER asr: {asr_predicted_text_strs}') + logging.info(f'frame {frame_idx}: AGENT txt: {predicted_text_strs}') + + # --- Update remaining state fields --- + if not use_cache: + state.input_embeds_history = state.input_embeds_history + new_input_embeds + if use_cache: + state.llm_cache_position_offset += num_frames_per_chunk + torch.cuda.synchronize() time_for_one_step = time.time() - start_time_one_step logging.info(f'frame {frame_idx}: Time taken for one step: {time_for_one_step:.3f}s') - result = { - 'predicted_text_tokens': predicted_tokens, - 'asr_predicted_text_tokens': asr_predicted_tokens, - 'decoded_audio_new': decoded_audio_new, - 'predicted_text_strs': predicted_text_strs, - 'asr_predicted_text_strs': asr_predicted_text_strs, - 
'input_embeds_history': input_embeds_history + new_input_embeds if not use_cache else input_embeds_history, - 'dynamic_cache': dynamic_cache if use_cache else None, - 'past_key_values': past_key_values, - 'code': code, - 'perception_cache': perception_cache, - 'codec_cache': codec_cache, - 'cache_position_offset': cache_position_offset + num_frames_per_chunk if use_cache else cache_position_offset, - } - if self.model.stt_model.function_head is not None: - result['function_predicted_text_tokens'] = function_predicted_tokens + debug = None if return_debug: - result["debug"] = { + debug = { "source_encoded": source_encoded.detach().cpu(), "selected_frame_indices": selected_frame_indices, "input_embeds": torch.cat(debug_input_embeds, dim=1) if debug_input_embeds else None, - "gen_text": gen_text.detach().cpu(), - "gen_asr": gen_asr_text.detach().cpu() if gen_asr_text is not None else None, + "gen_text": state.gen_text.detach().cpu(), + "gen_asr": state.gen_asr_text.detach().cpu() if state.gen_asr_text is not None else None, "text_logits": torch.stack(debug_text_logits, dim=1) if debug_text_logits else None, "asr_logits": torch.stack(debug_asr_logits, dim=1) if debug_asr_logits else None, } + + func_tokens = function_predicted_tokens if self.model.stt_model.function_head is not None else None + return InferenceStepResult( + predicted_text_tokens=predicted_tokens, + asr_predicted_text_tokens=asr_predicted_tokens, + predicted_text_strs=predicted_text_strs, + asr_predicted_text_strs=asr_predicted_text_strs, + decoded_audio=decoded_audio_new, + function_predicted_text_tokens=func_tokens, + debug=debug, + ) + + def _run_perception( + self, + audio_input: torch.Tensor, + frame_idx: int, + num_frames_per_chunk: int, + perception_cache: Optional[PerceptionCacheState], + ) -> Tuple[torch.Tensor, Optional[PerceptionCacheState]]: + """Run the perception encoder and return (source_encoded, updated_cache).""" + start_perception = time.time() + + if self.use_perception_cache and 
perception_cache is not None and perception_cache.is_initialized(): + source_encoded, perception_cache = self.perception_cache_mgr.step( + audio_input=audio_input, + frame_idx=frame_idx, + num_frames_per_chunk=num_frames_per_chunk, + perception_cache=perception_cache, + ) + else: + buffer_len = torch.tensor([audio_input.shape[1]], dtype=torch.long, device=self.device) + source_encoded, _, _ = self.model.stt_model.perception( + input_signal=audio_input, + input_signal_length=buffer_len, + return_encoder_emb=True, + ) + + torch.cuda.synchronize() + time_perception = time.time() - start_perception + logging.info(f"Time taken for perception: {time_perception:.3f}s") + source_encoded = source_encoded.to(self.dtype) + return source_encoded, perception_cache + + def _tokens_to_strings(self, token_ids: torch.Tensor) -> list[str]: + """Convert a [B, T] tensor of token IDs to a list of strings. + + Uses tokens_to_text (convert_tokens_to_string) so byte-level BPE is + decoded properly (e.g. "é" -> "é") and leading spaces from + Ġ-prefixed tokens are preserved for correct concatenation of + incremental chunks: " Musée" + " National" -> " Musée National". + + NOTE: multi-byte UTF-8 characters whose BPE tokens span two frames + will show as replacement chars (U+FFFD) because each frame is decoded + independently. 
+ """ + result = [] + for tok_ids_b in token_ids: + tok_ids_b = tok_ids_b.tolist() + toks = self.tokenizer.ids_to_tokens(tok_ids_b) + toks = [t for t in toks if t != ''] + result.append(self.tokenizer.tokens_to_text(toks)) return result def abort_request(self, request_id: Optional[str]) -> bool: diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 7d5ea6fce7b3..d5a1aa6b6a3e 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -141,16 +141,12 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper # ------------------------------------------------------------------ def create_state(self) -> S2SStreamingState: """Create new empty state.""" - num_audio_codebooks = getattr(self.s2s_model.model, "_num_codebooks", 1) - dtype = getattr(self.s2s_model, "compute_dtype", torch.float32) - state = S2SStreamingState( + dtype = getattr(self.s2s_model, "dtype", torch.float32) + return S2SStreamingState( device=self.device, dtype=dtype, - max_len=self.max_len, - num_audio_codebooks=num_audio_codebooks, output_sample_rate=self.output_sample_rate, ) - return state # ------------------------------------------------------------------ @@ -178,7 +174,7 @@ def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[ if isinstance(candidate, str) and candidate: asr_piece = candidate - state.update_state(sample_audio, output_text=piece, output_asr_text=asr_piece) + state.append_step_output(sample_audio, text=piece, asr_text=asr_piece) def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_paddings: List[int], ready_feats: List[bool]): @@ -230,28 +226,17 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_p result = self.s2s_model.infer_one_step( 
audio_input=audio_buffer, num_frames_per_chunk=self.num_frames_per_chunk, - frame_idx=context.frame_idx, - gen_text=context.gen_text, - input_embeds_history=context.input_embeds_history, - dynamic_cache=context.dynamic_cache, - past_key_values=context.past_key_values, - code=context.code, - subword_mask=context.subword_mask, - gen_asr_text=context.gen_asr_text, - gen_function_text=context.gen_function_text, + state=context, request_id=request_id, - perception_cache=context.perception_cache, has_prompt=has_prompt, - codec_cache=context.codec_cache, - cache_position_offset=context.cache_position_offset, return_debug=self.collect_debug, ) - if self.collect_debug and "debug" in result: + if self.collect_debug and result.debug is not None: state = self.get_or_create_state(stream_ids[0]) if not hasattr(state, "debug_steps"): state.debug_steps = [] - state.debug_steps.append(result["debug"]) + state.debug_steps.append(result.debug) # Persist updated cache & clean finished streams self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk) @@ -280,7 +265,7 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_p # It will be cleaned up in close_session() # Log audio and attach text to state - self.log_output(frames, result["decoded_audio_new"], ready_feats, result["predicted_text_strs"], result.get("asr_predicted_text_strs")) + self.log_output(frames, result.decoded_audio, ready_feats, result.predicted_text_strs, result.asr_predicted_text_strs) def prefill_for_new_stream(self, stream_id: int, system_prompt: str | None = None) -> bool: """Prepare the pipeline for a new stream by resetting context and prefilling the system prompt. 
@@ -402,7 +387,7 @@ def _finalize_and_save_finished_streams( if hasattr(state, "finalize"): state.finalize() # Concatenate emitted chunks and squeeze (B=1,C=1) to mono waveform - generated_audio = torch.cat(state.speech_frames, dim=-1) + generated_audio = state.audio_buffer # Ensure 1D mono waveform and float32 dtype for soundfile if generated_audio.dim() == 3 and generated_audio.size(0) == 1 and generated_audio.size(1) == 1: generated_audio = generated_audio.squeeze(0).squeeze(0) @@ -452,7 +437,7 @@ def _finalize_and_save_finished_streams( sf.write(stereo_path, stereo.detach().cpu().numpy(), self.output_sample_rate) # Save accumulated text - text_out = state.get_output_text() if hasattr(state, "get_output_text") else "" + text_out = state.output_text_str if isinstance(text_out, str): try: with open(os.path.join(txt_dir, f"{base}.txt"), "w", encoding="utf-8") as f: @@ -461,7 +446,7 @@ def _finalize_and_save_finished_streams( pass # Save accumulated ASR text - asr_text_out = state.get_output_asr_text() if hasattr(state, "get_output_asr_text") else "" + asr_text_out = state.output_asr_text_str if isinstance(asr_text_out, str) and asr_text_out: try: with open(os.path.join(txt_dir, f"{base}_asr.txt"), "w", encoding="utf-8") as f: @@ -607,15 +592,11 @@ def run( for idx in range(len(audio_filepaths)): state = self.get_or_create_state(idx) - text_value = state.get_output_text() if hasattr(state, "get_output_text") else "" - if not text_value: - text_value = saved_paths_by_stream.get(idx, "") + text_value = state.output_text_str or saved_paths_by_stream.get(idx, "") texts.append(text_value) audio_paths.append(saved_paths_by_stream.get(idx)) - per_stream_words = state.get_output_words() if hasattr(state, "get_output_words") else [] - words.append(per_stream_words) - asr_text_value = state.get_output_asr_text() if hasattr(state, "get_output_asr_text") else "" - asr_texts.append(asr_text_value) + words.append(list(state.output_words)) + 
asr_texts.append(state.output_asr_text_str) token_data = state.get_token_tensors() if token_data is not None: @@ -690,14 +671,14 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non Note on TTS prefill codes: The TTS prefill generates output codes, but these should NOT be used - to initialize context.code for inference. The batch approach uses + to initialize context.tts_code for inference. The batch approach uses first_tts_code_input (INPUT codes from speaker reference) instead. Using prefill OUTPUT codes causes audio quality issues (mumbling). Returns: Optional[torch.Tensor]: The TTS prefill output codes if vLLM EarTTS prefill happened, None otherwise. These are returned for logging/debugging but - should NOT be used to update context.code. + should NOT be used to update context.tts_code. """ request_id = self._request_id_for_stream(stream_id) engine_type = getattr(self.s2s_model, "engine_type", "native") @@ -758,11 +739,11 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non else: context, _ = self.context_manager.get_context([stream_id]) - if context.dynamic_cache is not None: + if context.llm_cache is not None: # Native cache mode: process prompt through LLM to update KV cache with torch.no_grad(): cache_pos = torch.arange(prompt_len, device=self.s2s_model.device) - llm_cache = context.dynamic_cache + llm_cache = context.llm_cache ans = self.s2s_model.model_llm_interface( prompt_embedded, cache=llm_cache, @@ -770,8 +751,8 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non generated_tokens=None, current_step=0 ) - context.dynamic_cache = ans.get("cache", llm_cache) - context.cache_position_offset = prompt_len + context.llm_cache = ans.get("cache", llm_cache) + context.llm_cache_position_offset = prompt_len logging.info(f"System prompt processed, cache updated ({prompt_len} tokens, offset={prompt_len})") else: for t in range(prompt_len): diff --git 
a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py index ab4a49e0ce92..9ea360f8ff1f 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -12,31 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from queue import Queue -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple import torch from nemo.utils import logging -if TYPE_CHECKING: - from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import PerceptionCacheState - - -@dataclass -class StreamingDecodeState: - frame_idx: int - gen_text: torch.Tensor - gen_asr_text: torch.Tensor - gen_function_text: Optional[torch.Tensor] - input_embeds_history: List[torch.Tensor] - dynamic_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache - past_key_values: Any - code: Optional[torch.Tensor] - subword_mask: Optional[torch.Tensor] - perception_cache: Optional["PerceptionCacheState"] = None - codec_cache: Any = None - cache_position_offset: int = 0 +from nemo.collections.speechlm2.inference.model_wrappers.decode_state import ( + InferenceStepResult, + StreamingDecodeState, +) class S2SContextManager: @@ -69,21 +54,7 @@ def _create_context(self) -> StreamingDecodeState: """Allocate a fresh context backed by the realtime inference model.""" if not hasattr(self.s2s_model, "create_decode_state"): raise RuntimeError("s2s_model must provide create_decode_state(max_len)") - decode_state = self.s2s_model.create_decode_state(self.max_len) - return StreamingDecodeState( - frame_idx=decode_state["frame_idx"], - gen_text=decode_state["gen_text"], - gen_asr_text=decode_state["gen_asr_text"], - 
gen_function_text=decode_state["gen_function_text"], - input_embeds_history=decode_state["input_embeds_history"], - dynamic_cache=decode_state["dynamic_cache"], - past_key_values=decode_state["past_key_values"], - code=decode_state["code"], - subword_mask=decode_state["subword_mask"], - perception_cache=decode_state["perception_cache"], - codec_cache=decode_state["codec_cache"], - cache_position_offset=decode_state["cache_position_offset"], - ) + return self.s2s_model.create_decode_state(self.max_len) def _ensure_slot(self, stream_id: int) -> int: if stream_id not in self.streamidx2slotidx: @@ -121,10 +92,17 @@ def reset_slot(self, slot_idx: int) -> None: def update_context( self, stream_ids: List[int], - step_result: Dict[str, Any], + step_result: InferenceStepResult, num_frames: int, ) -> None: - """Persist model outputs back into the cached context.""" + """Advance frame counter and set subword mask after an inference step. + + All cache and tensor mutations (dynamic_cache, past_key_values, code, + perception_cache, codec_cache, gen_text, gen_asr_text, etc.) are + already applied in-place on the ``StreamingDecodeState`` by + ``infer_one_step``. This method only bumps ``frame_idx`` and marks + the subword mask for the newly generated frames. + """ if len(stream_ids) == 0: return if len(stream_ids) != 1: @@ -147,48 +125,10 @@ def update_context( "Context maximum length exceeded. Consider increasing `streaming.max_len` in the configuration." 
) - predicted_tokens = step_result.get("predicted_text_tokens") - if predicted_tokens is not None: - if predicted_tokens.dim() == 1: - token_slice = predicted_tokens.unsqueeze(0) - else: - token_slice = predicted_tokens[0:1] - context.gen_text[:, start_idx:end_idx] = token_slice.to(context.gen_text.device) - - asr_predicted_tokens = step_result.get("asr_predicted_text_tokens") - if asr_predicted_tokens is not None: - if asr_predicted_tokens.dim() == 1: - asr_token_slice = asr_predicted_tokens.unsqueeze(0) - else: - asr_token_slice = asr_predicted_tokens[0:1] - context.gen_asr_text[:, start_idx:end_idx] = asr_token_slice.to(context.gen_asr_text.device) - - func_predicted_tokens = step_result.get("function_predicted_text_tokens") - if func_predicted_tokens is not None and context.gen_function_text is not None: - if func_predicted_tokens.dim() == 1: - func_token_slice = func_predicted_tokens.unsqueeze(0) - else: - func_token_slice = func_predicted_tokens[0:1] - context.gen_function_text[:, start_idx:end_idx] = func_token_slice.to(context.gen_function_text.device) - context.frame_idx = end_idx - if step_result.get("dynamic_cache") is not None: - context.dynamic_cache = step_result["dynamic_cache"] - if "input_embeds_history" in step_result: - context.input_embeds_history = step_result["input_embeds_history"] - if "past_key_values" in step_result: - context.past_key_values = step_result["past_key_values"] - if "code" in step_result: - context.code = step_result["code"] if context.subword_mask is not None: context.subword_mask[:, start_idx:end_idx] = True - if "perception_cache" in step_result and step_result["perception_cache"] is not None: - context.perception_cache = step_result["perception_cache"] - if "codec_cache" in step_result and step_result["codec_cache"] is not None: - context.codec_cache = step_result["codec_cache"] - if "cache_position_offset" in step_result: - context.cache_position_offset = step_result["cache_position_offset"] def reset_slots(self, 
stream_ids: List[int], eos_flags: List[bool]) -> None: """Release contexts for streams that signalled end-of-stream.""" diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py index c6ab5ae0a5e9..7ddb6c1f0f0b 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py @@ -13,134 +13,120 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import List, Any, Optional +from typing import List, Optional import torch from nemo.collections.asr.inference.utils.text_segment import Word +from nemo.utils import logging @dataclass class S2SStreamingState: - """ - State for streaming speech generation. + """Pipeline-level output accumulator for a single S2S stream. + + Collects generated audio samples, text strings, and word timings + produced by each inference step. Kept alive by the pipeline's + ``_state_pool`` until ``close_session()`` so the final + :class:`PipelineOutput` can be assembled. - This dataclass stores streaming tensors and counters used during - incremental generation. It keeps initialization metadata so it can be - reset to a clean state on demand. + This is *not* the model-level decode state (KV caches, token + workspaces) -- that is :class:`StreamingDecodeState` in + ``model_wrappers/decode_state.py``. 
""" - # Initialization metadata (required) + + # Required init metadata device: torch.device dtype: torch.dtype - max_len: int - num_audio_codebooks: int output_sample_rate: int - # Runtime tensors (initialized in __post_init__) + # Growing audio buffer — shape (1, T), appended each step audio_buffer: torch.Tensor = field(init=False) - # Accumulated text output + # Accumulated agent response text (built incrementally per step) output_text_str: str = "" - output_text_tokens: List[str] = field(default_factory=list) - # Accumulated ASR text output + # Accumulated ASR (user) text output_asr_text_str: str = "" - output_asr_text_tokens: List[str] = field(default_factory=list) - # Accumulated words with timings + # Word-level timings for the agent response output_words: List[Word] = field(default_factory=list) - # Final token tensors saved from the context before it is destroyed. + + # Snapshots of full token-ID tensors, saved from StreamingDecodeState + # before the decode context is destroyed at end-of-stream. # Used for post-hoc tokens_to_str / tokens_to_str_raw conversion. final_gen_text: Optional[torch.Tensor] = None final_gen_asr_text: Optional[torch.Tensor] = None + final_gen_function_text: Optional[torch.Tensor] = None final_total_frames: int = 0 def __post_init__(self) -> None: - """Allocate tensors lazily based on provided metadata.""" - with torch.no_grad(): - # Empty 2D buffer: shape (1, 0). Will be appended over time. - self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) + # Depends on self.device and self.dtype, so can't be a field default. 
+ self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) def reset(self) -> None: - """Reset all tensors and counters to their initial state.""" - with torch.no_grad(): - self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) - self.output_text_str = "" - self.output_text_tokens.clear() - self.output_asr_text_str = "" - self.output_asr_text_tokens.clear() - self.output_words.clear() - self.final_gen_text = None - self.final_gen_asr_text = None - self.final_total_frames = 0 - - def update_state(self, processed_frames: torch.Tensor, output_text_tokens: Any = None, output_text: str | None = None, output_asr_text: str | None = None) -> None: - """Append new audio to the right of the buffer; token/text args are accepted for API compatibility.""" - if processed_frames is None: + """Reset all accumulated outputs to initial state.""" + self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) + self.output_text_str = "" + self.output_asr_text_str = "" + self.output_words.clear() + self.final_gen_text = None + self.final_gen_asr_text = None + self.final_gen_function_text = None + self.final_total_frames = 0 + + def append_step_output( + self, + audio: torch.Tensor, + text: str | None = None, + asr_text: str | None = None, + ) -> None: + """Append generated audio and optional text from one inference step.""" + if audio is None: return - if not isinstance(processed_frames, torch.Tensor): - raise TypeError("processed_frames must be a torch.Tensor") - with torch.no_grad(): - # Ensure 2D [1, T] layout by flattening extra dims - append_tensor = processed_frames - if append_tensor.dim() > 1: - append_tensor = append_tensor.reshape(1, -1) - elif append_tensor.dim() == 1: - append_tensor = append_tensor.unsqueeze(0) - prior_samples = int(self.audio_buffer.shape[-1]) - appended_samples = int(append_tensor.shape[-1]) - self.audio_buffer = torch.cat([self.audio_buffer, append_tensor.to(self.device, dtype=self.dtype)], 
dim=-1) - - # Accumulate text output if provided and create a Word with naive timing - if isinstance(output_text, str) and output_text: - self.output_text_tokens.append(output_text) - # Directly concatenate - spacing is already handled by tokenizer (Ġ → space) - self.output_text_str += output_text - try: - if appended_samples > 0 and self.output_sample_rate > 0: - start_t = float(prior_samples) / float(self.output_sample_rate) - end_t = float(prior_samples + appended_samples) / float(self.output_sample_rate) - self.output_words.append(Word(text=output_text, start=start_t, end=end_t, conf=1.0)) - except Exception: - pass - - if isinstance(output_asr_text, str) and output_asr_text: - self.output_asr_text_tokens.append(output_asr_text) - self.output_asr_text_str += output_asr_text - - @property - def speech_frames(self) -> List[torch.Tensor]: - """Backward-compatible view for code expecting a list of chunks.""" - return [self.audio_buffer] - - def get_output_text(self) -> str: - """Return accumulated text as a single string.""" - return self.output_text_str - - def get_output_asr_text(self) -> str: - """Return accumulated ASR text as a single string.""" - return self.output_asr_text_str - - def get_output_words(self) -> List[Word]: - """Return accumulated words with timings.""" - return list(self.output_words) - - def save_token_tensors(self, gen_text: torch.Tensor, gen_asr_text: torch.Tensor, total_frames: int, - gen_function_text: torch.Tensor = None) -> None: - """Snapshot the full token-ID tensors from the context before it is destroyed.""" - with torch.no_grad(): - self.final_gen_text = gen_text[:, :total_frames].clone().cpu() - self.final_gen_asr_text = gen_asr_text[:, :total_frames].clone().cpu() - self.final_total_frames = total_frames - self.final_gen_function_text = ( - gen_function_text[:, :total_frames].clone().cpu() - if gen_function_text is not None else None - ) + if not isinstance(audio, torch.Tensor): + raise TypeError("audio must be a torch.Tensor") 
+ + append_tensor = audio + if append_tensor.dim() > 1: + append_tensor = append_tensor.reshape(1, -1) + elif append_tensor.dim() == 1: + append_tensor = append_tensor.unsqueeze(0) + prior_samples = int(self.audio_buffer.shape[-1]) + appended_samples = int(append_tensor.shape[-1]) + self.audio_buffer = torch.cat( + [self.audio_buffer, append_tensor.to(self.device, dtype=self.dtype)], dim=-1 + ) + + if isinstance(text, str) and text: + self.output_text_str += text + if appended_samples > 0 and self.output_sample_rate > 0: + start_t = float(prior_samples) / float(self.output_sample_rate) + end_t = float(prior_samples + appended_samples) / float(self.output_sample_rate) + self.output_words.append(Word(text=text, start=start_t, end=end_t, conf=1.0)) + + if isinstance(asr_text, str) and asr_text: + self.output_asr_text_str += asr_text + + def save_token_tensors( + self, + gen_text: torch.Tensor, + gen_asr_text: torch.Tensor, + total_frames: int, + gen_function_text: Optional[torch.Tensor] = None, + ) -> None: + """Snapshot the full token-ID tensors from the decode context before it is destroyed.""" + self.final_gen_text = gen_text[:, :total_frames].clone().cpu() + self.final_gen_asr_text = gen_asr_text[:, :total_frames].clone().cpu() + self.final_total_frames = total_frames + self.final_gen_function_text = ( + gen_function_text[:, :total_frames].clone().cpu() + if gen_function_text is not None else None + ) def get_token_tensors(self) -> Optional[tuple]: - """Return (gen_text, gen_asr_text, total_frames[, gen_function_text]) or None if not saved.""" + """Return (gen_text, gen_asr_text, total_frames, gen_function_text) or None if not saved.""" if self.final_gen_text is None: return None - return self.final_gen_text, self.final_gen_asr_text, self.final_total_frames, getattr(self, 'final_gen_function_text', None) + return self.final_gen_text, self.final_gen_asr_text, self.final_total_frames, self.final_gen_function_text - def cleanup_after_response(self) -> None: - 
"""Clear transient audio; keep token workspaces allocated.""" - with torch.no_grad(): - self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) + def clear_audio_buffer(self) -> None: + """Clear the audio buffer (e.g. after sending audio to a client).""" + self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype) diff --git a/nemo/collections/speechlm2/inference/utils/audio_data.py b/nemo/collections/speechlm2/inference/utils/audio_data.py new file mode 100644 index 000000000000..c73f6bdd6266 --- /dev/null +++ b/nemo/collections/speechlm2/inference/utils/audio_data.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Audio data loading and output serialization for S2S inference scripts.""" + +from __future__ import annotations + +import json +import os +from typing import List + +import soundfile as sf + +from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions +from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput + + +def prepare_audio_data( + audio_file: str, + default_system_prompt: str | None = None, + sort_by_duration: bool = True, +) -> tuple[List[str], List[S2SRequestOptions], List[str | None]]: + """Load audio filepaths and per-stream options from a folder, single file, or manifest. + + When the input is a JSON manifest, each line may contain:: + + {"audio_filepath": "clip.wav", "text": "...", "system_prompt": "..."} + + If ``system_prompt`` is absent on a line, *default_system_prompt* is used. + + Returns: + ``(filepaths, options, ground_truths)`` -- parallel lists of audio paths, + per-stream request options, and ground-truth texts (``None`` when unavailable). 
+ """ + audio_file = audio_file.strip() + if not os.path.isabs(audio_file): + audio_file = os.path.abspath(audio_file) + + options: List[S2SRequestOptions] = [] + ground_truths: List[str | None] = [] + + if os.path.isdir(audio_file): + filepaths = [os.path.join(audio_file, x) for x in os.listdir(audio_file) if x.endswith(".wav")] + options = [S2SRequestOptions(system_prompt=default_system_prompt) for _ in filepaths] + ground_truths = [None] * len(filepaths) + elif audio_file.endswith(".wav"): + filepaths = [audio_file] + options = [S2SRequestOptions(system_prompt=default_system_prompt)] + ground_truths = [None] + elif audio_file.endswith((".json", ".jsonl")): + samples = [] + with open(audio_file, 'r') as f: + for line in f.readlines(): + if line.strip(): + samples.append(json.loads(line)) + filepaths = [get_full_path(entry["audio_filepath"], audio_file) for entry in samples] + options = [ + S2SRequestOptions( + system_prompt=entry.get("system_prompt", default_system_prompt), + ) + for entry in samples + ] + ground_truths = [entry.get("text", None) for entry in samples] + else: + raise ValueError(f"audio_file `{audio_file}` needs to be a folder, audio file, or manifest file") + + if sort_by_duration: + durations = [sf.SoundFile(fp).frames for fp in filepaths] + order = sorted(range(len(filepaths)), key=lambda i: durations[i]) + filepaths = [filepaths[i] for i in order] + options = [options[i] for i in order] + ground_truths = [ground_truths[i] for i in order] + + return filepaths, options, ground_truths + + +def calculate_duration(audio_filepaths: List[str]) -> float: + """Calculate total duration of the given audio files in seconds.""" + total_dur = 0 + for audio_filepath in audio_filepaths: + sound = sf.SoundFile(audio_filepath) + total_dur += sound.frames / sound.samplerate + return total_dur + + +def calculate_padded_duration( + audio_filepaths: List[str], + pad_audio_to_sec: float | None = None, + pad_silence_ratio: float | None = None, + pad_audio_by_sec: 
float | None = None, +) -> float: + """Calculate total duration including silence padding for RTFX reporting.""" + total = 0.0 + for fp in audio_filepaths: + sound = sf.SoundFile(fp) + orig = sound.frames / sound.samplerate + if pad_audio_to_sec is not None: + total += max(orig, pad_audio_to_sec) + elif pad_silence_ratio is not None: + total += orig * (1 + pad_silence_ratio) + elif pad_audio_by_sec is not None: + total += orig + pad_audio_by_sec + else: + total += orig + return total + + +def dump_output( + audio_filepaths: List[str], + output: PipelineOutput, + output_dir: str, + options: List[S2SRequestOptions], + ground_truths: List[str | None], +) -> None: + """Dump inference results to output_processed.json, output_raw.json, and per-file CTM. + + ``output_processed.json`` uses the canonical S2S processed-output schema + (timestamps in pred_text via ``<|t|>`` / ``<$t$>``). + + ``output_raw.json`` preserves all tokens including ```` (pad tokens). + """ + output_processed_path = os.path.join(output_dir, "output_processed.json") + output_raw_path = os.path.join(output_dir, "output_raw.json") + output_ctm_dir = os.path.join(output_dir, "ctm") + + os.makedirs(output_ctm_dir, exist_ok=True) + + asr_texts_ts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths) + texts_ts = output.texts_with_timestamps or [""] * len(audio_filepaths) + raw_texts = output.raw_texts or [""] * len(audio_filepaths) + raw_asr_texts = output.raw_asr_texts or [""] * len(audio_filepaths) + + with open(output_processed_path, 'w') as f_proc, open(output_raw_path, 'w') as f_raw: + for audio_filepath, words, opts, gt, pred_text_ts, pred_src_text_ts, pred_text_raw, pred_src_text_raw in zip( + audio_filepaths, output.words, options, ground_truths, + texts_ts, asr_texts_ts, raw_texts, raw_asr_texts, + ): + stem = os.path.splitext(os.path.basename(audio_filepath))[0] + ctm_filepath = os.path.abspath(os.path.join(output_ctm_dir, f"{stem}.ctm")) + with open(ctm_filepath, 'w') as ctm_fout: 
+ for word in words: + ctm_line = f"A {round(word.start, 2)} {round(word.duration, 2)} {word.text} {word.conf}" + ctm_fout.write(f"{stem} {ctm_line}\n") + + pred_audio_path = os.path.join(output_dir, "wav", f"{stem}.wav") + + record_processed = { + "id": stem, + "target_text": "", + "pred_audio": pred_audio_path, + "src_text": gt or "", + "pred_src_text": pred_src_text_ts or "", + "pred_text": pred_text_ts or "", + "system_prompt": opts.system_prompt or "", + } + json.dump(record_processed, f_proc, ensure_ascii=False) + f_proc.write('\n') + f_proc.flush() + + record_raw = { + "id": stem, + "target_text": "", + "pred_audio": pred_audio_path, + "src_text": gt or "", + "pred_src_text": pred_src_text_raw or "", + "pred_text": pred_text_raw or "", + "system_prompt": opts.system_prompt or "", + } + json.dump(record_raw, f_raw, ensure_ascii=False) + f_raw.write('\n') + f_raw.flush() From 02855897c27d9b0cf6b9257dde98a5501eab279a Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Tue, 24 Mar 2026 22:40:58 +0000 Subject: [PATCH 16/40] Add pytest-based offline vs. 
incremental inference parity test with logit comparison Signed-off-by: Elena Rastorgueva --- .../nemotron_voicechat_parity_harness.py | 747 ------------------ .../pipelines/streaming_s2s_pipeline.py | 4 +- .../speechlm2/models/nemotron_voicechat.py | 62 +- .../test_offline_incremental_parity.py | 596 ++++++++++++++ 4 files changed, 652 insertions(+), 757 deletions(-) delete mode 100644 examples/speechlm2/nemotron_voicechat_parity_harness.py create mode 100644 tests/collections/speechlm2/test_offline_incremental_parity.py diff --git a/examples/speechlm2/nemotron_voicechat_parity_harness.py b/examples/speechlm2/nemotron_voicechat_parity_harness.py deleted file mode 100644 index 0673cd7a881d..000000000000 --- a/examples/speechlm2/nemotron_voicechat_parity_harness.py +++ /dev/null @@ -1,747 +0,0 @@ -from __future__ import annotations - -import argparse -import json -import math -import os -from pathlib import Path -import tempfile -from typing import Any - -import librosa -import soundfile as sf -import torch -from omegaconf import MISSING, OmegaConf - -from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder -from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import ( - FRAME_SIZE_SAMPLES, - NemotronVoicechatInferenceWrapper, -) -from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions -def _bool_arg(parser: argparse.ArgumentParser, name: str, help_text: str) -> None: - parser.add_argument(name, action=argparse.BooleanOptionalAction, default=None, help=help_text) - - -def _default_s2s_streaming_config_path() -> str: - repo_root = Path(__file__).resolve().parents[2] - return str(repo_root / "examples" / "speechlm2" / "nemo_inference_pipelines" / "conf" / "s2s_streaming.yaml") - - -def _load_s2s_inference_config(config_path: str | None = None): - path = config_path or _default_s2s_streaming_config_path() - cfg = OmegaConf.load(path) - for 
key, value in { - "audio_file": "", - "output_dir": "./generated", - "s2s.model_path": None, - "s2s.llm_checkpoint_path": None, - "s2s.decode_audio": True, - "s2s.engine_type": "native", - "s2s.system_prompt": None, - "streaming.chunk_size_in_secs": FRAME_SIZE_SAMPLES / 16000.0, - "streaming.buffer_size_in_secs": 71 * (FRAME_SIZE_SAMPLES / 16000.0), - }.items(): - if OmegaConf.select(cfg, key, default=MISSING) is MISSING: - OmegaConf.update(cfg, key, value, force_add=True) - return cfg - - -def _apply_inference_overrides(cfg, overrides: dict[str, Any]): - for key, value in overrides.items(): - if value is not None: - OmegaConf.update(cfg, key, value, force_add=True) - return cfg - - -def _load_audio_tensor(audio_path: str, sample_rate: int, device: torch.device, dtype: torch.dtype): - audio_np, _ = librosa.load(audio_path, sr=sample_rate) - audio = torch.tensor(audio_np, device=device, dtype=dtype).unsqueeze(0) - audio_lens = torch.tensor([audio.shape[1]], device=device, dtype=torch.long) - return audio, audio_lens - - -def _build_prompt_token_ids(tokenizer, system_prompt: str | None) -> list[int]: - if not system_prompt or not system_prompt.strip(): - return [] - return [tokenizer.bos_id] + tokenizer.text_to_ids(system_prompt) + [tokenizer.eos_id] - - -def _resolve_num_frames_per_chunk(args, total_frames: int) -> int: - if args.num_frames_per_chunk is not None: - value = int(args.num_frames_per_chunk) - elif args.chunk_size_in_secs is not None: - value = int(round(float(args.chunk_size_in_secs) / (FRAME_SIZE_SAMPLES / 16000.0))) - else: - value = total_frames - - if value < 1: - raise ValueError(f"num_frames_per_chunk must be >= 1, got {value}") - return value - - -def _apply_deterministic_runtime_settings(enabled: bool) -> None: - if not enabled: - return - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - torch.backends.cuda.enable_flash_sdp(False) - 
torch.backends.cuda.enable_mem_efficient_sdp(False) - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.set_float32_matmul_precision("medium") - torch.use_deterministic_algorithms(True, warn_only=False) - - - -def _compute_min_buffer_frames(wrapper, num_frames_per_chunk: int) -> int: - att_context_size = wrapper.model.stt_model.perception.encoder._cfg.att_context_size - if wrapper.use_perception_cache: - return num_frames_per_chunk * (att_context_size[1] + 1) + 2 - return att_context_size[0] + att_context_size[1] + 1 - - -def _compute_min_buffer_frames_from_cfg(cfg, num_frames_per_chunk: int) -> int: - att_context_size = cfg.streaming.get("att_context_size", [70, 0]) - if cfg.s2s.get("use_perception_cache", False): - return num_frames_per_chunk * (att_context_size[1] + 1) + 2 - return att_context_size[0] + att_context_size[1] + 1 - - -def _first_diff(a: torch.Tensor, b: torch.Tensor) -> int | None: - a = a.detach().cpu() - b = b.detach().cpu() - if a.shape != b.shape: - return 0 - diff = (a != b).flatten() - if not diff.any(): - return None - return int(diff.nonzero(as_tuple=False)[0].item()) - - -def _prefix_compare(a: torch.Tensor, b: torch.Tensor) -> tuple[int | None, bool | None, int | None]: - if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor): - return None, None, None - a = a.detach().cpu() - b = b.detach().cpu() - if a.dim() != b.dim(): - return None, None, None - - prefix_len = min(a.shape[-1], b.shape[-1]) - if prefix_len == 0: - return 0, True, None - - a_prefix = a[..., :prefix_len] - b_prefix = b[..., :prefix_len] - if torch.equal(a_prefix, b_prefix): - return prefix_len, True, None - - diff = (a_prefix != b_prefix).flatten() - first_diff = int(diff.nonzero(as_tuple=False)[0].item()) - return prefix_len, False, first_diff - - -def _prefix_tensor_diff(a: torch.Tensor | None, b: torch.Tensor | None) -> dict[str, Any] | None: - if not isinstance(a, torch.Tensor) or not isinstance(b, 
torch.Tensor): - return None - a = a.detach().cpu() - b = b.detach().cpu() - if a.dim() != b.dim(): - return None - prefix_len = min(a.shape[1], b.shape[1]) if a.dim() >= 2 else min(a.shape[0], b.shape[0]) - if prefix_len <= 0: - return {"prefix_len": 0, "match": True, "max_abs_diff": 0.0, "mean_abs_diff": 0.0} - if a.dim() == 2: - a_prefix = a[:, :prefix_len] - b_prefix = b[:, :prefix_len] - else: - a_prefix = a[:, :prefix_len, ...] - b_prefix = b[:, :prefix_len, ...] - diff = (a_prefix - b_prefix).abs() - reduce_dims = tuple(i for i in range(diff.dim()) if i != 1) - if reduce_dims: - per_step_max = diff.amax(dim=reduce_dims) - else: - per_step_max = diff - first_step_diff_index = None - differing_steps = (per_step_max > 0).nonzero(as_tuple=False) - if differing_steps.numel() > 0: - first_step_diff_index = int(differing_steps[0].item()) - return { - "prefix_len": prefix_len, - "match": bool(torch.equal(a_prefix, b_prefix)), - "max_abs_diff": float(diff.max().item()), - "mean_abs_diff": float(diff.mean().item()), - "first_step_diff_index": first_step_diff_index, - } - - -def _dtype_name(value) -> str | None: - if value is None: - return None - if isinstance(value, torch.Tensor): - return str(value.dtype) - if isinstance(value, torch.dtype): - return str(value) - return str(value) - - -def _module_param_dtype(module) -> str | None: - if module is None: - return None - try: - return str(next(module.parameters()).dtype) - except StopIteration: - return None - except Exception: - return None - - -def _collect_model_dtypes(wrapper: NemotronVoicechatInferenceWrapper) -> dict[str, Any]: - stt_model = wrapper.model.stt_model - return { - "wrapper_dtype": _dtype_name(wrapper.dtype), - "llm_dtype": _module_param_dtype(getattr(stt_model, "llm", None)), - "lm_head_dtype": _module_param_dtype(getattr(stt_model, "lm_head", None)), - "asr_head_dtype": _module_param_dtype(getattr(stt_model, "asr_head", None)), - "embed_tokens_dtype": _module_param_dtype(getattr(stt_model, 
"embed_tokens", None)), - "embed_asr_tokens_dtype": _module_param_dtype(getattr(stt_model, "embed_asr_tokens", None)), - "perception_dtype": _module_param_dtype(getattr(stt_model, "perception", None)), - "tts_dtype": _module_param_dtype(getattr(wrapper.model, "tts_model", None)), - } - - -def _tensor_summary_diff(a: torch.Tensor | None, b: torch.Tensor | None) -> dict[str, Any] | None: - if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor): - return None - if a.shape != b.shape: - return {"shape_a": list(a.shape), "shape_b": list(b.shape), "match": False} - diff = (a - b).abs() - return { - "shape": list(a.shape), - "match": bool(torch.equal(a, b)), - "max_abs_diff": float(diff.max().item()), - "mean_abs_diff": float(diff.mean().item()), - } - - -def _step_component_diagnostics( - wrapper: NemotronVoicechatInferenceWrapper, - offline_debug: dict[str, Any], - incremental_debug: dict[str, Any], -) -> dict[str, Any] | None: - input_embed_diff = _prefix_tensor_diff(offline_debug.get("input_embeds"), incremental_debug.get("input_embeds")) - if input_embed_diff is None: - return {"status": "missing_input_embed_diff"} - step_idx = input_embed_diff.get("first_step_diff_index") - if step_idx is None: - return {"status": "no_input_embed_drift"} - if step_idx == 0: - return {"first_step_diff_index": 0, "note": "Drift starts at step 0; component breakdown not specialized."} - - offline_source = offline_debug.get("source_encoded") - incremental_source = incremental_debug.get("source_encoded") - selected_indices = incremental_debug.get("selected_frame_indices") or [] - offline_tokens = offline_debug.get("gen_text") - offline_asr = offline_debug.get("gen_asr") - incremental_tokens = incremental_debug.get("gen_text") - incremental_asr = incremental_debug.get("gen_asr") - required = { - "offline_source_encoded": offline_source, - "incremental_source_encoded": incremental_source, - "offline_gen_text": offline_tokens, - "offline_gen_asr": offline_asr, - 
"incremental_gen_text": incremental_tokens, - "incremental_gen_asr": incremental_asr, - } - missing = [name for name, value in required.items() if not isinstance(value, torch.Tensor)] - if missing: - return { - "status": "missing_tensors", - "first_step_diff_index": step_idx, - "missing": missing, - } - if step_idx >= len(selected_indices): - return { - "status": "selected_index_out_of_range", - "first_step_diff_index": step_idx, - "selected_frame_indices_len": len(selected_indices), - } - - stt_model = wrapper.model.stt_model - source_frame_offline = offline_source[:, step_idx : step_idx + 1, :] - source_frame_incremental = incremental_source[:, selected_indices[step_idx] : selected_indices[step_idx] + 1, :] - prev_offline_text = offline_tokens[:, step_idx - 1] - prev_incremental_text = incremental_tokens[:, step_idx - 1] - prev_offline_asr = offline_asr[:, step_idx - 1] - prev_incremental_asr = incremental_asr[:, step_idx - 1] - - offline_text_emb = stt_model.embed_tokens(prev_offline_text.to(wrapper.device)).detach().cpu() - incremental_text_emb = stt_model.embed_tokens(prev_incremental_text.to(wrapper.device)).detach().cpu() - offline_asr_emb = stt_model.embed_asr_tokens(prev_offline_asr.to(wrapper.device)).detach().cpu() - incremental_asr_emb = stt_model.embed_asr_tokens(prev_incremental_asr.to(wrapper.device)).detach().cpu() - - text_weight = stt_model.cfg.get("duplex_text_channel_weight", 1.0) - asr_weight = stt_model.cfg.get("duplex_asr_text_weight", 1.0) - offline_last_emb = offline_text_emb * text_weight + offline_asr_emb * asr_weight - incremental_last_emb = incremental_text_emb * text_weight + incremental_asr_emb * asr_weight - - offline_input = offline_debug["input_embeds"][:, step_idx : step_idx + 1, :] - incremental_input = incremental_debug["input_embeds"][:, step_idx : step_idx + 1, :] - - offline_style = source_frame_incremental.detach().cpu().clone() - offline_style += incremental_last_emb.unsqueeze(1) - incremental_style = 
source_frame_incremental.detach().cpu().clone() - incremental_style += (incremental_text_emb * text_weight).unsqueeze(1) - incremental_style += (incremental_asr_emb * asr_weight).unsqueeze(1) - - return { - "status": "ok", - "first_step_diff_index": step_idx, - "selected_frame_index": selected_indices[step_idx], - "prev_text_token_equal": bool(torch.equal(prev_offline_text.cpu(), prev_incremental_text.cpu())), - "prev_asr_token_equal": bool(torch.equal(prev_offline_asr.cpu(), prev_incremental_asr.cpu())), - "source_frame_diff": _tensor_summary_diff(source_frame_offline.cpu(), source_frame_incremental.cpu()), - "text_embedding_diff": _tensor_summary_diff(offline_text_emb, incremental_text_emb), - "asr_embedding_diff": _tensor_summary_diff(offline_asr_emb, incremental_asr_emb), - "last_emb_diff": _tensor_summary_diff(offline_last_emb, incremental_last_emb), - "offline_input_vs_incremental_input": _tensor_summary_diff(offline_input.cpu(), incremental_input.cpu()), - "offline_input_vs_offline_style_rebuild": _tensor_summary_diff(offline_input.cpu(), offline_style), - "incremental_input_vs_offline_style_rebuild": _tensor_summary_diff(incremental_input.cpu(), offline_style), - "offline_input_vs_incremental_style_rebuild": _tensor_summary_diff(offline_input.cpu(), incremental_style), - "incremental_input_vs_incremental_style_rebuild": _tensor_summary_diff(incremental_input.cpu(), incremental_style), - } - - -def _compare_debug_outputs(offline_debug: dict[str, Any] | None, incremental_debug: dict[str, Any] | None) -> dict[str, Any] | None: - if offline_debug is None or incremental_debug is None: - return None - - offline_encoder = offline_debug.get("source_encoded") - incremental_encoder = incremental_debug.get("source_encoded") - selected_indices = incremental_debug.get("selected_frame_indices") or [] - selected_incremental = None - selected_prefix = None - if isinstance(incremental_encoder, torch.Tensor) and selected_indices: - selected_incremental = 
incremental_encoder[:, selected_indices, :] - if isinstance(offline_encoder, torch.Tensor) and selected_incremental is not None: - prefix_len = min(offline_encoder.shape[1], selected_incremental.shape[1]) - selected_prefix = _prefix_tensor_diff( - offline_encoder[:, :prefix_len, :], - selected_incremental[:, :prefix_len, :], - ) - - report = { - "offline_tensor_dtypes": { - "source_encoded": _dtype_name(offline_debug.get("source_encoded")), - "input_embeds": _dtype_name(offline_debug.get("input_embeds")), - "text_logits": _dtype_name(offline_debug.get("text_logits")), - "asr_logits": _dtype_name(offline_debug.get("asr_logits")), - }, - "incremental_tensor_dtypes": { - "source_encoded": _dtype_name(incremental_debug.get("source_encoded")), - "input_embeds": _dtype_name(incremental_debug.get("input_embeds")), - "text_logits": _dtype_name(incremental_debug.get("text_logits")), - "asr_logits": _dtype_name(incremental_debug.get("asr_logits")), - }, - "offline_source_encoded_shape": list(offline_encoder.shape) if isinstance(offline_encoder, torch.Tensor) else None, - "incremental_source_encoded_shape": list(incremental_encoder.shape) if isinstance(incremental_encoder, torch.Tensor) else None, - "incremental_selected_frame_indices": selected_indices, - "selected_encoder_prefix": selected_prefix, - "offline_input_embeds_shape": list(offline_debug["input_embeds"].shape) - if isinstance(offline_debug.get("input_embeds"), torch.Tensor) - else None, - "incremental_input_embeds_shape": list(incremental_debug["input_embeds"].shape) - if isinstance(incremental_debug.get("input_embeds"), torch.Tensor) - else None, - "input_embeds_prefix": _prefix_tensor_diff(offline_debug.get("input_embeds"), incremental_debug.get("input_embeds")), - "offline_text_logits_shape": list(offline_debug["text_logits"].shape) - if isinstance(offline_debug.get("text_logits"), torch.Tensor) - else None, - "incremental_text_logits_shape": list(incremental_debug["text_logits"].shape) - if 
isinstance(incremental_debug.get("text_logits"), torch.Tensor) - else None, - "text_logits_prefix": _prefix_tensor_diff(offline_debug.get("text_logits"), incremental_debug.get("text_logits")), - "offline_asr_logits_shape": list(offline_debug["asr_logits"].shape) - if isinstance(offline_debug.get("asr_logits"), torch.Tensor) - else None, - "incremental_asr_logits_shape": list(incremental_debug["asr_logits"].shape) - if isinstance(incremental_debug.get("asr_logits"), torch.Tensor) - else None, - "asr_logits_prefix": _prefix_tensor_diff(offline_debug.get("asr_logits"), incremental_debug.get("asr_logits")), - } - report["step_component_diagnostics"] = None - return report - - -def _compare_outputs(offline: dict[str, Any], incremental: dict[str, Any]) -> dict[str, Any]: - offline_tokens = offline.get("tokens_text") - incremental_tokens = incremental.get("tokens_text") - offline_asr = offline.get("tokens_text_src") - incremental_asr = incremental.get("asr_tokens") - token_prefix_len, token_prefix_match, token_prefix_first_diff = _prefix_compare(offline_tokens, incremental_tokens) - asr_prefix_len, asr_prefix_match, asr_prefix_first_diff = _prefix_compare(offline_asr, incremental_asr) - - token_match = ( - isinstance(offline_tokens, torch.Tensor) - and isinstance(incremental_tokens, torch.Tensor) - and offline_tokens.shape == incremental_tokens.shape - and torch.equal(offline_tokens.detach().cpu(), incremental_tokens.detach().cpu()) - ) - asr_token_match = None - if isinstance(offline_asr, torch.Tensor) and isinstance(incremental_asr, torch.Tensor): - asr_token_match = offline_asr.shape == incremental_asr.shape and torch.equal( - offline_asr.detach().cpu(), incremental_asr.detach().cpu() - ) - - offline_audio_len = offline.get("audio_len") - incremental_audio = incremental.get("audio") - audio_sample_count_equal = None - if offline_audio_len is not None and incremental_audio is not None: - expected = int(offline_audio_len[0].item()) - got = 
int(incremental_audio.shape[-1]) - audio_sample_count_equal = expected == got - - report = { - "offline_text": offline.get("text", [""])[0], - "incremental_text": incremental.get("text", [""])[0], - "text_equal": offline.get("text", [""])[0] == incremental.get("text", [""])[0], - "offline_asr_text": (offline.get("src_text") or [""])[0] if offline.get("src_text") is not None else None, - "incremental_asr_text": (incremental.get("asr_text") or [""])[0] if incremental.get("asr_text") is not None else None, - "asr_text_equal": ( - offline.get("src_text") is not None - and incremental.get("asr_text") is not None - and offline["src_text"][0] == incremental["asr_text"][0] - ), - "offline_token_shape": list(offline_tokens.shape) if isinstance(offline_tokens, torch.Tensor) else None, - "incremental_token_shape": list(incremental_tokens.shape) if isinstance(incremental_tokens, torch.Tensor) else None, - "token_match": token_match, - "token_first_diff_index": _first_diff(offline_tokens, incremental_tokens) - if isinstance(offline_tokens, torch.Tensor) and isinstance(incremental_tokens, torch.Tensor) - else None, - "token_prefix_len": token_prefix_len, - "token_prefix_match": token_prefix_match, - "token_prefix_first_diff_index": token_prefix_first_diff, - "offline_asr_token_shape": list(offline_asr.shape) if isinstance(offline_asr, torch.Tensor) else None, - "incremental_asr_token_shape": list(incremental_asr.shape) if isinstance(incremental_asr, torch.Tensor) else None, - "asr_token_match": asr_token_match, - "asr_token_first_diff_index": _first_diff(offline_asr, incremental_asr) - if isinstance(offline_asr, torch.Tensor) and isinstance(incremental_asr, torch.Tensor) - else None, - "asr_token_prefix_len": asr_prefix_len, - "asr_token_prefix_match": asr_prefix_match, - "asr_token_prefix_first_diff_index": asr_prefix_first_diff, - "audio_sample_count_equal": audio_sample_count_equal, - } - return report - - -def _merge_incremental_debug_steps(steps: list[dict[str, Any]]) -> 
dict[str, Any]: - """Merge per-step debug dicts from the pipeline into a single dict matching offline debug format.""" - if not steps: - return {} - all_source_encoded = [s["source_encoded"] for s in steps if s.get("source_encoded") is not None] - all_input_embeds = [s["input_embeds"] for s in steps if s.get("input_embeds") is not None] - all_text_logits = [s["text_logits"] for s in steps if s.get("text_logits") is not None] - all_asr_logits = [s["asr_logits"] for s in steps if s.get("asr_logits") is not None] - all_gen_text = [s["gen_text"] for s in steps if s.get("gen_text") is not None] - all_gen_asr = [s["gen_asr"] for s in steps if s.get("gen_asr") is not None] - selected_frame_indices = [] - for s in steps: - selected_frame_indices.extend(s.get("selected_frame_indices", [])) - return { - "source_encoded": all_source_encoded[-1] if all_source_encoded else None, - "input_embeds": torch.cat(all_input_embeds, dim=1) if all_input_embeds else None, - "gen_text": all_gen_text[-1] if all_gen_text else None, - "gen_asr": all_gen_asr[-1] if all_gen_asr else None, - "text_logits": torch.cat(all_text_logits, dim=1) if all_text_logits else None, - "asr_logits": torch.cat(all_asr_logits, dim=1) if all_asr_logits else None, - "selected_frame_indices": selected_frame_indices, - } - - -def _collect_offline_debug( - wrapper: NemotronVoicechatInferenceWrapper, - audio: torch.Tensor, - audio_lens: torch.Tensor, - prompt_tokens: torch.Tensor | None, - prompt_token_lens: torch.Tensor | None, -) -> dict[str, Any]: - buffer_len = audio_lens.to(device=wrapper.device, dtype=torch.long) - source_encoded, _, _ = wrapper.model.stt_model.perception( - input_signal=audio, - input_signal_length=buffer_len, - return_encoder_emb=True, - ) - source_encoded = source_encoded.to(wrapper.dtype) - - inference_state = wrapper.model.stt_model.streaming_inference._init_inference( - audio, - audio_lens, - 0, - prompt_tokens, - prompt_token_lens, - ) - ans, inference_state = 
wrapper.model.stt_model.streaming_inference._step_zero(inference_state) - text_logits = [ans["text_logits"][:, -1].detach().cpu()] - asr_logits = [ans["asr_logits"][:, -1].detach().cpu()] if "asr_logits" in ans else [] - T = inference_state["T"] - for t in range(1, T): - ans = wrapper.model.stt_model.streaming_inference._step_inference(t, inference_state, ans) - text_logits.append(ans["text_logits"][:, -1].detach().cpu()) - if "asr_logits" in ans: - asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) - - return { - "source_encoded": source_encoded.detach().cpu(), - "input_embeds": inference_state["input_embeds"].detach().cpu(), - "gen_text": inference_state["gen_text"].detach().cpu(), - "gen_asr": inference_state["gen_asr"].detach().cpu() if inference_state.get("gen_asr") is not None else None, - "text_logits": torch.stack(text_logits, dim=1), - "asr_logits": torch.stack(asr_logits, dim=1) if asr_logits else None, - } - - -def run_parity_harness(args) -> dict[str, Any]: - inference_cfg = _load_s2s_inference_config(args.config_path) - - if args.strict_runtime_parity and args.tts_system_prompt: - raise ValueError( - "Strict offline/incremental parity does not currently support `tts_system_prompt`, " - "because offline_inference has no equivalent string prompt API for TTS conditioning." 
- ) - - overrides = { - "s2s.model_path": args.model_path, - "s2s.llm_checkpoint_path": args.llm_checkpoint_path, - "s2s.speaker_reference": args.speaker_reference, - "s2s.speaker_name": args.speaker_name, - "s2s.compute_dtype": args.compute_dtype, - "s2s.decode_audio": args.decode_audio, - "s2s.system_prompt": args.system_prompt, - "s2s.tts_system_prompt": args.tts_system_prompt, - "s2s.engine_type": args.engine_type, - "s2s.use_perception_cache": args.use_perception_cache, - "s2s.use_perception_cudagraph": args.use_perception_cudagraph, - "s2s.use_llm_cache": args.use_llm_cache, - "s2s.deterministic": args.deterministic, - "s2s.top_p": args.top_p, - "s2s.repetition_penalty": args.repetition_penalty, - "s2s.temperature": args.temperature, - } - - if args.strict_runtime_parity: - strict_defaults = { - "s2s.engine_type": "native", - "s2s.compute_dtype": "float32", - "s2s.use_perception_cache": False, - "s2s.use_perception_cudagraph": False, - "s2s.use_llm_cache": False, - "s2s.deterministic": True, - "s2s.top_p": 1.0, - "s2s.repetition_penalty": 1.0, - "s2s.temperature": 0.0, - } - for key, value in strict_defaults.items(): - overrides[key] = value if overrides.get(key) is None else overrides[key] - - inference_cfg = _apply_inference_overrides(inference_cfg, overrides) - _apply_deterministic_runtime_settings(bool(inference_cfg.s2s.get("deterministic", False))) - - input_sample_rate = int(inference_cfg.streaming.get("input_sample_rate", 16000)) - audio_np, _ = librosa.load(args.audio_path, sr=input_sample_rate) - total_samples = len(audio_np) - total_frames = int(math.ceil(total_samples / FRAME_SIZE_SAMPLES)) - num_frames_per_chunk = _resolve_num_frames_per_chunk(args, total_frames) - chunk_size_in_secs = num_frames_per_chunk * (FRAME_SIZE_SAMPLES / float(input_sample_rate)) - buffer_size_frames = max(num_frames_per_chunk, _compute_min_buffer_frames_from_cfg(inference_cfg, num_frames_per_chunk)) - - with tempfile.TemporaryDirectory(prefix="voicechat-parity-") as 
tmpdir: - inference_cfg = _apply_inference_overrides( - inference_cfg, - { - "output_dir": tmpdir, - "streaming.chunk_size_in_secs": chunk_size_in_secs, - "streaming.buffer_size_in_secs": buffer_size_frames * (FRAME_SIZE_SAMPLES / float(input_sample_rate)), - }, - ) - pipeline = S2SPipelineBuilder.build_pipeline(inference_cfg) - do_collect_debug = args.collect_debug if args.collect_debug is not None else bool(args.strict_runtime_parity) - pipeline.collect_debug = do_collect_debug - wrapper = pipeline.s2s_model - - audio, audio_lens = _load_audio_tensor( - args.audio_path, - sample_rate=wrapper.model.source_sample_rate, - device=wrapper.device, - dtype=wrapper.dtype, - ) - - prompt_tokens = None - prompt_token_lens = None - if inference_cfg.s2s.get("system_prompt"): - prompt_token_ids = _build_prompt_token_ids(wrapper.tokenizer, inference_cfg.s2s.system_prompt) - prompt_tokens = torch.tensor(prompt_token_ids, device=wrapper.device, dtype=torch.long).unsqueeze(0) - prompt_token_lens = torch.tensor([len(prompt_token_ids)], device=wrapper.device, dtype=torch.long) - - if wrapper.speaker_name is not None: - OmegaConf.update(wrapper.model.cfg, "inference_speaker_name", wrapper.speaker_name, force_add=True) - elif wrapper.speaker_reference: - OmegaConf.update(wrapper.model.cfg, "inference_speaker_reference", wrapper.speaker_reference, force_add=True) - - offline = wrapper.model.offline_inference( - input_signal=audio, - input_signal_lens=audio_lens, - prompt_tokens=prompt_tokens, - prompt_token_lens=prompt_token_lens, - decode_audio=bool(inference_cfg.s2s.get("decode_audio", True)), - ) - offline_debug = _collect_offline_debug( - wrapper, - audio=audio, - audio_lens=audio_lens, - prompt_tokens=prompt_tokens, - prompt_token_lens=prompt_token_lens, - ) - pipeline_output = pipeline.run( - [args.audio_path], - options=[S2SRequestOptions(system_prompt=inference_cfg.s2s.get("system_prompt"))], - ) - - incremental_audio = None - incremental_audio_path = None - 
audio_sample_count_equal = None - if getattr(pipeline_output, "audio_filepaths", None): - incremental_audio_path = pipeline_output.audio_filepaths[0] - if incremental_audio_path: - incremental_audio, _ = sf.read(incremental_audio_path) - incremental_audio = torch.tensor(incremental_audio).reshape(1, -1) - incremental = { - "text": [pipeline_output.texts_with_timestamps[0] if pipeline_output.texts_with_timestamps else pipeline_output.texts[0]], - "asr_text": [pipeline_output.asr_texts_with_timestamps[0] if pipeline_output.asr_texts_with_timestamps else pipeline_output.asr_texts[0]], - "tokens_text": pipeline_output.token_texts[0] if pipeline_output.token_texts else None, - "asr_tokens": pipeline_output.token_asr_texts[0] if pipeline_output.token_asr_texts else None, - "audio": incremental_audio, - } - - incremental_debug = None - if pipeline_output.debug_data and pipeline_output.debug_data[0]: - incremental_debug = _merge_incremental_debug_steps(pipeline_output.debug_data[0]) - - debug_comparison = _compare_debug_outputs(offline_debug, incremental_debug) if incremental_debug else None - - report = { - "audio_path": args.audio_path, - "total_samples": int(total_samples), - "total_frames": total_frames, - "num_frames_per_chunk": num_frames_per_chunk, - "buffer_size_frames": buffer_size_frames, - "strict_runtime_parity": bool(args.strict_runtime_parity), - "engine_type": inference_cfg.s2s.get("engine_type"), - "use_perception_cache": bool(inference_cfg.s2s.get("use_perception_cache", False)), - "use_llm_cache": bool(inference_cfg.s2s.get("use_llm_cache", False)), - "deterministic": bool(inference_cfg.s2s.get("deterministic", False)), - "model_dtypes": _collect_model_dtypes(wrapper), - "comparison": _compare_outputs(offline, incremental), - "debug_comparison": debug_comparison, - "debug": { - "incremental_mode": "pipeline", - "incremental_audio_filepath": incremental_audio_path, - "offline_debug": offline_debug, - "incremental_debug": incremental_debug, - }, - } - - if 
args.output_json: - output_path = Path(args.output_json) - output_path.parent.mkdir(parents=True, exist_ok=True) - serializable_report = {k: v for k, v in report.items() if k != "debug"} - with output_path.open("w", encoding="utf-8") as f: - json.dump(serializable_report, f, indent=2, ensure_ascii=False) - - if args.strict_runtime_parity: - comparison = report["comparison"] - failed = [] - if not comparison["text_equal"]: - failed.append("text") - if comparison["token_match"] is False: - failed.append("tokens") - if comparison["asr_token_match"] is False: - failed.append("asr_tokens") - if comparison["audio_sample_count_equal"] is False: - failed.append("audio_length") - if failed: - raise AssertionError( - "Offline/incremental parity failed for: " - + ", ".join(failed) - + f". Report: {json.dumps({k: v for k, v in report.items() if k != 'debug'}, ensure_ascii=False)}" - ) - - return report - - -def build_argparser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="Compare Nemotron VoiceChat offline inference against incremental decoding with one full-audio chunk." 
- ) - parser.add_argument("--model_path", type=str, required=True, help="Path to S2S/TTS checkpoint directory.") - parser.add_argument("--llm_checkpoint_path", type=str, required=True, help="Path to LLM/perception checkpoint directory.") - parser.add_argument("--audio_path", type=str, required=True, help="Audio file to compare across both paths.") - parser.add_argument("--speaker_reference", type=str, default=None, help="Speaker reference audio path.") - parser.add_argument("--speaker_name", type=str, default=None, help="Registered speaker name.") - parser.add_argument("--config_path", type=str, default=None, help="Optional path to s2s_streaming.yaml.") - parser.add_argument("--system_prompt", type=str, default=None, help="Optional system prompt.") - parser.add_argument("--tts_system_prompt", type=str, default=None, help="Optional TTS system prompt.") - parser.add_argument( - "--num_frames_per_chunk", - type=int, - default=None, - help="Override incremental chunk size in 80ms frames. If unset, defaults to full audio length.", - ) - parser.add_argument( - "--chunk_size_in_secs", - type=float, - default=None, - help="Override incremental chunk size in seconds. If set, converted to 80ms frames. 
If unset, defaults to full audio length.", - ) - parser.add_argument("--engine_type", type=str, default=None, help="Override engine type.") - parser.add_argument("--compute_dtype", type=str, default=None, help="Override compute dtype (for example: float32, bfloat16).") - _bool_arg(parser, "--decode_audio", "Whether to decode waveform outputs.") - _bool_arg(parser, "--use_perception_cache", "Override perception cache usage.") - _bool_arg(parser, "--use_perception_cudagraph", "Override perception CUDA-graph usage.") - _bool_arg(parser, "--use_llm_cache", "Override LLM cache usage.") - _bool_arg(parser, "--deterministic", "Override deterministic mode.") - parser.add_argument("--top_p", type=float, default=None, help="Override top-p.") - parser.add_argument("--repetition_penalty", type=float, default=None, help="Override repetition penalty.") - parser.add_argument("--temperature", type=float, default=None, help="Override temperature.") - parser.add_argument("--output_json", type=str, default=None, help="Optional JSON report path.") - parser.add_argument( - "--strict_runtime_parity", - action=argparse.BooleanOptionalAction, - default=True, - help=( - "When enabled, force a strict native/deterministic parity profile and raise if text/token/audio-length " - "comparisons differ." - ), - ) - _bool_arg( - parser, - "--collect_debug", - "Collect per-step encoder outputs and logits for comparison. 
" - "Stores tensors on CPU each step; disable for long audio to avoid OOM.", - ) - return parser - - -def main() -> int: - args = build_argparser().parse_args() - report = run_parity_harness(args) - printable = {k: v for k, v in report.items() if k != "debug"} - print(json.dumps(printable, indent=2, ensure_ascii=False)) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index d5a1aa6b6a3e..6f931643b1a6 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -158,8 +158,8 @@ def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[ if not ready_feats[idx]: continue state = self.get_or_create_state(frame.stream_id) - # audio_wave is [B, S]; take sample idx - sample_audio = audio_wave[idx:idx+1, ...] + # audio_wave is [B, S]; take sample idx (None when decode_audio=False) + sample_audio = audio_wave[idx:idx+1, ...] if audio_wave is not None else None # Determine text piece for this index piece = None if text_pieces and idx < len(text_pieces): diff --git a/nemo/collections/speechlm2/models/nemotron_voicechat.py b/nemo/collections/speechlm2/models/nemotron_voicechat.py index 84a3af2e96df..0e9e7d2d8d5b 100644 --- a/nemo/collections/speechlm2/models/nemotron_voicechat.py +++ b/nemo/collections/speechlm2/models/nemotron_voicechat.py @@ -361,16 +361,35 @@ def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): if missing_keys: logging.warning(f"{len(missing_keys)} keys in checkpoint not found in model") - # Fail if any parameters/buffers are still on meta device — this means - # the checkpoint is missing weights the model requires. 
- meta_remaining = [n for n, p in self.named_parameters() if p.is_meta] - meta_remaining += [n for n, b in self.named_buffers() if b.is_meta] - if meta_remaining: + # Fail if any *parameters* are still on meta device — those genuinely + # need weights from the checkpoint and their absence is an error. + meta_params = [n for n, p in self.named_parameters() if p.is_meta] + if meta_params: raise RuntimeError( - f"{len(meta_remaining)} tensors still on meta device after checkpoint load " - f"(missing from checkpoint): {meta_remaining[:20]}" + f"{len(meta_params)} parameters still on meta device after checkpoint load " + f"(missing from checkpoint): {meta_params[:20]}" ) + # Buffers on meta device are typically non-persistent computed values + # (e.g. rotary_emb.inv_freq registered with persistent=False) that + # save_pretrained / safetensors intentionally omit. Reinitialize + # them by moving the owning module to CPU then back, which triggers + # the buffer's factory function. + meta_buffers = [(n, b) for n, b in self.named_buffers() if b.is_meta] + if meta_buffers: + logging.info( + f"Reinitializing {len(meta_buffers)} non-persistent meta buffer(s): " + f"{[n for n, _ in meta_buffers[:10]]}" + ) + for buf_name, buf in meta_buffers: + parts = buf_name.split(".") + module = self + for part in parts[:-1]: + module = getattr(module, part) + module.to_empty(device="cpu") + module.to(device="cpu") + logging.info("Meta buffers reinitialised on CPU") + gc.collect() def training_step(self, batch: dict, batch_idx: int): @@ -527,6 +546,7 @@ def offline_inference( incremental_audio_decoding: bool = False, generation_config: dict = None, guidance_enabled: bool = True, + return_logits: bool = False, ) -> dict[str, torch.Tensor]: """ Runs full offline duplex speech-to-speech inference. @@ -575,6 +595,12 @@ def offline_inference( guidance_enabled (bool, optional): Enables classifier-free guidance. 
+ return_logits (bool, optional): + When True, collect per-step text and ASR logits and + include them in the returned dict as ``"text_logits"`` + (B, T, V_text) and ``"asr_logits"`` (B, T, V_asr). + Useful for parity testing against incremental inference. + Returns: dict[str, torch.Tensor]: @@ -598,6 +624,12 @@ def offline_inference( Tensor (B,) — waveform lengths in samples (if decode_audio=True). + • "text_logits" (only when return_logits=True): + Tensor (B, T, V_text) — per-step text head logits. + + • "asr_logits" (only when return_logits=True): + Tensor (B, T, V_asr) — per-step ASR head logits. + Notes: • Uses streaming inference backend of DuplexSTTModel. • Uses autoregressive codec generation from DuplexEARTTS. @@ -615,6 +647,10 @@ def offline_inference( B = inference_state["B"] T = inference_state["T"] + if return_logits: + _text_logits = [ans["text_logits"][:, -1].detach().cpu()] + _asr_logits = [ans["asr_logits"][:, -1].detach().cpu()] if "asr_logits" in ans else [] + # if speaker_name is provided uses it, if not uses the speaker_audio provided, if speaker_audio is None load it from inference_speaker_reference if speaker_audio is None: speaker_name = self.cfg.get("inference_speaker_name", None) @@ -662,7 +698,12 @@ def offline_inference( # Autoregressive loop for t in range(1, T): # do one step inference on Duplex STT model - _ = self.stt_model.streaming_inference._step_inference(t, inference_state, ans) + ans = self.stt_model.streaming_inference._step_inference(t, inference_state, ans) + + if return_logits: + _text_logits.append(ans["text_logits"][:, -1].detach().cpu()) + if "asr_logits" in ans: + _asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) # do one step inference on Duplex TTS model # current subword id is always seem @@ -719,6 +760,11 @@ def offline_inference( ans["audio"] = audio_pred.squeeze(1) ans["audio_len"] = audio_pred_len + if return_logits: + ans["text_logits"] = torch.stack(_text_logits, dim=1) + if _asr_logits: + 
ans["asr_logits"] = torch.stack(_asr_logits, dim=1) + return ans def load_state_dict(self, state_dict, strict: bool = True): diff --git a/tests/collections/speechlm2/test_offline_incremental_parity.py b/tests/collections/speechlm2/test_offline_incremental_parity.py new file mode 100644 index 000000000000..5c6cd66e148e --- /dev/null +++ b/tests/collections/speechlm2/test_offline_incremental_parity.py @@ -0,0 +1,596 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline vs. incremental inference parity tests for NemotronVoiceChat. + +``test_parity_tiny_model`` checks offline-vs-incremental parity using a +tiny model with random weights (no checkpoint needed, requires only a GPU). + +``test_parity_real_checkpoint`` does the same on a real exported checkpoint +and is skipped unless +the following environment variables point to a real exported checkpoint:: + + PARITY_CHECKPOINT_PATH=/path/to/exported/checkpoint + PARITY_AUDIO_PATH=/path/to/test.wav + PARITY_SPEAKER_NAME= # optional + +Run from the NeMo repo root (use ``-s`` to see live progress):: + + # unit tests only + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_offline_incremental_parity.py -v -s + + # include integration test + PARITY_CHECKPOINT_PATH=... PARITY_AUDIO_PATH=... 
\\ + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_offline_incremental_parity.py -v -s +""" + +from __future__ import annotations + +import json +import math +import os +import time +from typing import Any + +import numpy as np +import pytest +import soundfile as sf +import torch +from omegaconf import OmegaConf + +from nemo.utils import logging + +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import ( + FRAME_SIZE_SAMPLES, + SAMPLE_RATE, +) +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions +from nemo.collections.speechlm2.models import NemotronVoiceChat + +# --------------------------------------------------------------------------- +# Comparison helpers +# --------------------------------------------------------------------------- + + +def compare_logits( + a: torch.Tensor, + b: torch.Tensor, +) -> dict[str, Any]: + """Prefix-aware comparison of two ``(B, T, V)`` logit tensors.""" + a = a.detach().cpu().float() + b = b.detach().cpu().float() + prefix_len = min(a.shape[1], b.shape[1]) + if prefix_len == 0: + return {"prefix_len": 0, "match": True, "max_abs_diff": 0.0, "mean_abs_diff": 0.0} + ap, bp = a[:, :prefix_len], b[:, :prefix_len] + diff = (ap - bp).abs() + reduce_dims = tuple(i for i in range(diff.dim()) if i != 1) + per_step_max = diff.amax(dim=reduce_dims) if reduce_dims else diff + first_diff_step = None + nonzero = (per_step_max > 0).nonzero(as_tuple=False) + if nonzero.numel() > 0: + first_diff_step = int(nonzero[0].item()) + return { + "prefix_len": prefix_len, + "match": bool(torch.equal(ap, bp)), + "max_abs_diff": float(diff.max()), + "mean_abs_diff": float(diff.mean()), + "first_diff_step": first_diff_step, + } + + +def 
compare_tokens( + a: torch.Tensor | None, + b: torch.Tensor | None, +) -> dict[str, Any]: + """Compare two ``(B, T)`` token tensors with prefix-aware logic.""" + if a is None or b is None: + return {"match": None, "note": "one or both tensors missing"} + a, b = a.detach().cpu(), b.detach().cpu() + prefix_len = min(a.shape[-1], b.shape[-1]) + if prefix_len == 0: + return {"prefix_len": 0, "match": True} + ap, bp = a[..., :prefix_len], b[..., :prefix_len] + match = bool(torch.equal(ap, bp)) + first_diff_index = None + if not match: + diffs = (ap != bp).flatten().nonzero(as_tuple=False) + if diffs.numel() > 0: + first_diff_index = int(diffs[0].item()) + return {"prefix_len": prefix_len, "match": match, "first_diff_index": first_diff_index} + + +def _merge_incremental_debug_steps(steps: list[dict[str, Any]]) -> dict[str, Any]: + """Merge per-step debug dicts from the pipeline into a single dict.""" + if not steps: + return {} + all_text_logits = [s["text_logits"] for s in steps if s.get("text_logits") is not None] + all_asr_logits = [s["asr_logits"] for s in steps if s.get("asr_logits") is not None] + return { + "text_logits": torch.cat(all_text_logits, dim=1) if all_text_logits else None, + "asr_logits": torch.cat(all_asr_logits, dim=1) if all_asr_logits else None, + } + + +def run_parity_check( + pipeline: StreamingS2SPipeline, + audio_path: str, + *, + system_prompt: str | None = None, +) -> dict[str, Any]: + """Run offline and pipeline-based incremental inference, return comparison. + + The offline path calls :meth:`NemotronVoiceChat.offline_inference` + (the same code path used by ``nemotron_voicechat_eval.py``). + The incremental path runs through :meth:`StreamingS2SPipeline.run`, + which is the real production code path (buffering, framing, prefill). + + Only STT-level tokens and logits are compared; TTS is irrelevant for + the core parity invariant. 
+ """ + import librosa + + wrapper = pipeline.s2s_model + t0 = time.time() + logging.info("=" * 60) + logging.info("PARITY CHECK -- offline vs. incremental (pipeline)") + logging.info("=" * 60) + logging.info(f"Audio : {audio_path}") + logging.info(f"Prompt: {system_prompt or '(none)'}") + + logging.info("Loading audio ...") + audio_np, _ = librosa.load(audio_path, sr=SAMPLE_RATE) + total_frames = math.ceil(len(audio_np) / FRAME_SIZE_SAMPLES) + padded_len = total_frames * FRAME_SIZE_SAMPLES + audio = torch.tensor(audio_np, device=wrapper.device, dtype=wrapper.dtype).unsqueeze(0) + if audio.shape[1] < padded_len: + audio = torch.nn.functional.pad(audio, (0, padded_len - audio.shape[1])) + audio_lens = torch.tensor([audio.shape[1]], device=wrapper.device, dtype=torch.long) + logging.info(f" {len(audio_np)} samples (padded to {audio.shape[1]}), {total_frames} frames, {len(audio_np)/SAMPLE_RATE:.2f}s") + + # -- prompt tokens for offline -- + prompt_tokens = prompt_token_lens = None + if system_prompt: + tok = wrapper.tokenizer + ids = [tok.bos_id] + tok.text_to_ids(system_prompt) + [tok.eos_id] + prompt_tokens = torch.tensor(ids, device=wrapper.device, dtype=torch.long).unsqueeze(0) + prompt_token_lens = torch.tensor([len(ids)], device=wrapper.device, dtype=torch.long) + + # -- ensure speaker info for offline_inference's mandatory TTS init -- + if wrapper.speaker_name is not None: + OmegaConf.update(wrapper.model.cfg, "inference_speaker_name", wrapper.speaker_name, force_add=True) + elif wrapper.speaker_reference: + OmegaConf.update(wrapper.model.cfg, "inference_speaker_reference", wrapper.speaker_reference, force_add=True) + speaker_kw: dict[str, Any] = {} + cfg = wrapper.model.cfg + if not cfg.get("inference_speaker_name", None) and not cfg.get("inference_speaker_reference", None): + speaker_kw["speaker_audio"] = torch.randn(1, 22050, device=wrapper.device) + speaker_kw["speaker_audio_lens"] = torch.tensor([22050], device=wrapper.device, dtype=torch.long) + + # ---- 
Offline path (eval.py code path) ---- + logging.info("Running offline_inference (with return_logits=True) ...") + t_offline = time.time() + offline = wrapper.model.offline_inference( + input_signal=audio, + input_signal_lens=audio_lens, + prompt_tokens=prompt_tokens, + prompt_token_lens=prompt_token_lens, + decode_audio=False, + return_logits=True, + **speaker_kw, + ) + logging.info(f" offline_inference done in {time.time() - t_offline:.2f}s") + + # ---- Incremental path (pipeline.run) ---- + logging.info("Running incremental inference (pipeline.run) ...") + t_incr = time.time() + pipeline.collect_debug = True + pipeline_output = pipeline.run( + [audio_path], + options=[S2SRequestOptions(system_prompt=system_prompt)], + ) + logging.info(f" pipeline.run done in {time.time() - t_incr:.2f}s") + + # Extract tokens and logits from pipeline output + inc_tokens = pipeline_output.token_texts[0] if pipeline_output.token_texts else None + inc_asr_tokens = pipeline_output.token_asr_texts[0] if pipeline_output.token_asr_texts else None + + incremental_debug = {} + if pipeline_output.debug_data and pipeline_output.debug_data[0]: + incremental_debug = _merge_incremental_debug_steps(pipeline_output.debug_data[0]) + + # ---- Build comparison report ---- + logging.info("Comparing outputs ...") + report: dict[str, Any] = { + "audio_path": audio_path, + "total_frames": total_frames, + "system_prompt": system_prompt, + "offline_text": offline.get("text", [""])[0], + "token_comparison": compare_tokens(offline.get("tokens_text"), inc_tokens), + "asr_token_comparison": compare_tokens(offline.get("tokens_text_src"), inc_asr_tokens), + } + + off_tl = offline.get("text_logits") + inc_tl = incremental_debug.get("text_logits") + if off_tl is not None and inc_tl is not None: + report["text_logit_comparison"] = compare_logits(off_tl, inc_tl) + + off_al = offline.get("asr_logits") + inc_al = incremental_debug.get("asr_logits") + if off_al is not None and inc_al is not None: + 
report["asr_logit_comparison"] = compare_logits(off_al, inc_al) + + # ---- Summary ---- + elapsed = time.time() - t0 + logging.info("-" * 60) + logging.info("PARITY CHECK SUMMARY") + logging.info("-" * 60) + tc = report.get("token_comparison", {}) + ac = report.get("asr_token_comparison", {}) + logging.info(f" Text tokens match : {tc.get('match')} (prefix_len={tc.get('prefix_len')})") + logging.info(f" ASR tokens match : {ac.get('match')} (prefix_len={ac.get('prefix_len')})") + for tag, key in [("Text logits", "text_logit_comparison"), ("ASR logits", "asr_logit_comparison")]: + lc = report.get(key, {}) + if lc: + logging.info( + f" {tag:16s}: match={lc.get('match')}, " + f"max_abs_diff={lc.get('max_abs_diff', 0):.6e}, " + f"mean_abs_diff={lc.get('mean_abs_diff', 0):.6e}" + ) + logging.info(f" Total time: {elapsed:.2f}s") + logging.info("=" * 60) + + return report + + +def assert_parity( + report: dict[str, Any], + *, + strict: bool = True, + atol: float = 0.0, +) -> None: + """Raise :class:`AssertionError` if parity checks in *report* fail. + + Args: + report: Dict returned by :func:`run_parity_check`. + strict: When ``True``, also assert logit-level match. + atol: Absolute tolerance for logit comparison (only used when *strict*). 
+ """ + failures: list[str] = [] + + tc = report.get("token_comparison", {}) + if tc.get("match") is False: + failures.append(f"text tokens diverge at index {tc.get('first_diff_index')}") + + ac = report.get("asr_token_comparison", {}) + if ac.get("match") is False: + failures.append(f"ASR tokens diverge at index {ac.get('first_diff_index')}") + + if strict: + for key in ("text_logit_comparison", "asr_logit_comparison"): + lc = report.get(key, {}) + if lc and lc.get("match") is False: + max_diff = lc.get("max_abs_diff", float("inf")) + if max_diff > atol: + failures.append( + f"{key}: max_abs_diff={max_diff:.6e} > atol={atol:.6e}, " + f"first diff at step {lc.get('first_diff_step')}" + ) + + if failures: + detail = json.dumps( + {k: v for k, v in report.items() if k != "debug"}, + indent=2, + default=str, + ) + raise AssertionError( + "Offline/incremental parity failed:\n - " + + "\n - ".join(failures) + + f"\n\nFull report:\n{detail}" + ) + + +# --------------------------------------------------------------------------- +# Tiny-model configuration (derived from test_nemotron_voicechat.py) +# --------------------------------------------------------------------------- + +_pretrained_llm = "TinyLlama/TinyLlama_v1.1" +if os.path.exists("/home/TestData/speechlm/pretrained_models"): + _pretrained_llm = "/home/TestData/speechlm/pretrained_models/TinyLlama--TinyLlama_v1.1" + + +def _tiny_voicechat_config() -> dict: + """Return a minimal NemotronVoiceChat config with random weights.""" + return { + "model": { + "scoring_asr": "stt_en_fastconformer_transducer_large", + "stt": { + "model": { + "pretrained_llm": _pretrained_llm, + "pretrained_weights": False, + "predict_user_text": True, + "audio_loss_weight": 1, + "text_loss_weight": 3, + "source_sample_rate": 16000, + "validation_save_path": "/tmp/test_parity_stt_logs", + "perception": { + "_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule", + "preprocessor": { + "_target_": 
"nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", + "features": 80, + }, + "encoder": { + "_target_": "nemo.collections.asr.modules.ConformerEncoder", + "feat_in": 80, + "d_model": 512, + "n_heads": 8, + "n_layers": 1, + "subsampling_factor": 8, + }, + "modality_adapter": { + "_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector", + "d_model": 512, + }, + "output_dim": 2048, + }, + "optimizer": {"_target_": "torch.optim.AdamW"}, + }, + "data": {"source_sample_rate": 16000}, + "exp_manager": {"explicit_log_dir": "/tmp/test_parity_stt_logs"}, + }, + "speech_generation": { + "model": { + "pretrained_lm_name": _pretrained_llm, + "pretrained_ae_dir": None, + "pretrained_tts_model": None, + "scoring_asr": "stt_en_fastconformer_transducer_large", + "freeze_params": [r"^audio_codec\..+$", r"^embed_tokens\..+$"], + "bos_token": "", + "eos_token": "", + "pad_token": "", + "audio_codec_run_dtype": "float32", + "prevent_freeze_params": [], + "audio_save_path": "", + "inference_guidance_scale": 0.5, + "inference_noise_scale": 0.8, + "inference_top_p_or_k": 0.8, + "inference_guidance_enabled": False, + "subword_mask_exactly_as_eartts": False, + "context_hidden_mask_exactly_as_eartts": False, + "optimizer": { + "_target_": "torch.optim.AdamW", + "lr": 4e-5, + "betas": [0.9, 0.98], + "weight_decay": 0, + "foreach": True, + }, + "lr_scheduler": { + "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", + "warmup_steps": 2500, + "min_lr": 1e-6, + "max_steps": 100_000_000, + }, + "codec_config": { + "latent_size": 512, + "n_fft": 16, + "hop_length": 4, + "base_hidden_size": 384, + "channel_mult": [1, 2, 4], + "rates": [7, 7, 9], + "num_blocks": 3, + "kernel_size": 7, + "groups": 1, + "codebook_size": 1024, + "num_quantizers": 31, + "wav_to_token_ratio": 1764, + }, + "tts_config": { + "use_gated_fusion_for_text_audio": True, + "disable_eos_prediction": True, + "use_bos_eos_emb": True, + "use_subword_flag_emb": True, + 
"num_delay_speech_tokens": 2, + "backbone_type": "gemma3_text", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "attention_dropout": 0.1, + "use_cache": False, + }, + "latent_size": 512, + "codebook_size": 1024, + "num_quantizers": 31, + "context_hidden_size": None, + "cas_config": { + "backbone_type": "t5gemma", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "is_encoder_decoder": False, + "encoder": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "use_cache": False, + "attention_dropout": 0.1, + }, + }, + }, + "mog_head_config": { + "intermediate_size": 4608, + "num_layers": 3, + "low_rank": 64, + "num_predictions": 1024, + "min_log_std": -4.0, + "eps": 1e-6, + }, + "p_uncond": 0.1, + "label_smoothing": 0.01, + "max_training_rate": 0.8, + "quantizer_dropout": 0.5, + "random_target_masking": False, + "exponent": 3.0, + }, + }, + "data": { + "add_text_bos_and_eos_in_each_turn": True, + "add_audio_prompt": True, + "audio_prompt_duration": 3.0, + "frame_length": 0.08, + "source_sample_rate": 16000, + "target_sample_rate": 22050, + }, + "exp_manager": {"explicit_log_dir": "/tmp/test_parity_tts_logs"}, + }, + }, + "data": { + "frame_length": 0.08, + "source_sample_rate": 16000, + "target_sample_rate": 22050, + "input_roles": ["user", "User"], + "output_roles": ["agent", "Assistant", "assistant", "Agent"], + }, + "exp_manager": {"explicit_log_dir": "/tmp/test_parity_logs"}, + } + + +def _build_parity_pipeline( + model_path: str, + audio_path: str, + output_dir: str, + *, + speaker_name: str | None = None, +) -> StreamingS2SPipeline: + """Build a :class:`StreamingS2SPipeline` configured for strict parity testing.""" + import 
librosa + + audio_np, _ = librosa.load(audio_path, sr=SAMPLE_RATE) + total_frames = math.ceil(len(audio_np) / FRAME_SIZE_SAMPLES) + chunk_secs = total_frames * FRAME_SIZE_SAMPLES / SAMPLE_RATE + + speaker_kw = {} + if speaker_name: + speaker_kw["speaker_name"] = speaker_name + + pipeline_cfg = OmegaConf.create( + { + "output_dir": output_dir, + "s2s": { + "model_path": model_path, + **speaker_kw, + "compute_dtype": "float32", + "engine_type": "native", + "deterministic": True, + "use_perception_cache": False, + "use_perception_cudagraph": False, + "use_llm_cache": False, + "top_p": 1.0, + "repetition_penalty": 1.0, + "temperature": 0.0, + "decode_audio": False, + }, + "streaming": { + "input_sample_rate": SAMPLE_RATE, + "output_sample_rate": 22050, + "batch_size": 1, + "att_context_size": [70, 0], + "chunk_size_in_secs": chunk_secs, + "buffer_size_in_secs": max(71 * 0.08, chunk_secs), + "request_type": "frame", + "max_len": 8192, + }, + } + ) + return S2SPipelineBuilder.build_pipeline(pipeline_cfg) + + +# --------------------------------------------------------------------------- +# Parity test -- tiny model (no real checkpoint needed) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_parity_tiny_model(tmp_path): + """Offline/incremental parity with a tiny random-weight model. + + Saves the model as an HF checkpoint, then loads it through the real + ``S2SPipelineBuilder`` so the test exercises the same code path as + ``test_parity_real_checkpoint``. 
+ """ + import json as _json + + audio_path = str(tmp_path / "test_audio.wav") + sf.write(audio_path, np.random.RandomState(42).randn(16000).astype(np.float32), 16000) + + cfg = _tiny_voicechat_config() + model = NemotronVoiceChat(cfg) + model.to("cuda") + model.eval() + + model_dir = str(tmp_path / "model") + model.save_pretrained(model_dir) + with open(os.path.join(model_dir, "config.json"), "w") as f: + _json.dump(cfg, f) + del model + torch.cuda.empty_cache() + + pipeline = _build_parity_pipeline(model_dir, audio_path, str(tmp_path / "output")) + report = run_parity_check(pipeline, audio_path) + assert_parity(report, strict=True, atol=0.0) + + +# --------------------------------------------------------------------------- +# Integration test -- real checkpoint (skipped when env vars are not set) +# --------------------------------------------------------------------------- + + +def _real_checkpoint_available() -> bool: + path = os.environ.get("PARITY_CHECKPOINT_PATH", "") + audio = os.environ.get("PARITY_AUDIO_PATH", "") + return bool(path) and os.path.isdir(path) and bool(audio) and os.path.isfile(audio) + + +@pytest.mark.skipif(not _real_checkpoint_available(), reason="set PARITY_CHECKPOINT_PATH and PARITY_AUDIO_PATH") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_parity_real_checkpoint(): + """Parity check using a real exported checkpoint. 
+ + Configure via environment variables:: + + PARITY_CHECKPOINT_PATH=/path/to/exported/checkpoint + PARITY_AUDIO_PATH=/path/to/test.wav + PARITY_SPEAKER_NAME= # optional + """ + import tempfile + + ckpt = os.environ["PARITY_CHECKPOINT_PATH"] + audio = os.environ["PARITY_AUDIO_PATH"] + speaker = os.environ.get("PARITY_SPEAKER_NAME") + + pipeline = _build_parity_pipeline( + ckpt, audio, tempfile.mkdtemp(prefix="parity-"), speaker_name=speaker, + ) + report = run_parity_check(pipeline, audio) + assert_parity(report, strict=True, atol=0.0) From a918f7fc268ccaabea87ff1d021f66cd41ea6ef0 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 25 Mar 2026 00:26:50 +0000 Subject: [PATCH 17/40] refactor streaming S2S pipeline: extract helpers, factor infer_one_step, add docs Signed-off-by: Elena Rastorgueva --- docs/source/speechlm2/intro.rst | 31 +- docs/source/speechlm2/streaming_inference.rst | 304 ++++++++++ .../nemotron_voicechat_inference_wrapper.py | 530 +++++++++--------- .../pipelines/streaming_s2s_pipeline.py | 281 ++++++---- .../streaming/framing/s2s_request_options.py | 2 +- .../test_offline_incremental_parity.py | 290 ++++------ 6 files changed, 873 insertions(+), 565 deletions(-) create mode 100644 docs/source/speechlm2/streaming_inference.rst diff --git a/docs/source/speechlm2/intro.rst b/docs/source/speechlm2/intro.rst index 94f0426d575b..ff4158002963 100644 --- a/docs/source/speechlm2/intro.rst +++ b/docs/source/speechlm2/intro.rst @@ -246,7 +246,35 @@ You can evaluate and run full-duplex inference using the `NemotronVoiceChat` pip print(f"Agent response: {generated_text}") # generated_speech can now be saved or played (sampled at model.target_sample_rate) - + +NemotronVoiceChat Streaming Inference +************************************* + +For real-time, chunk-by-chunk inference (as opposed to the offline mode shown +above), use the Streaming S2S Pipeline: + +.. 
code-block:: python + + from nemo.collections.speechlm2.inference import S2SPipelineBuilder + + pipeline = S2SPipelineBuilder.build_pipeline(cfg) + output = pipeline.run(audio_filepaths, options=options) + +Or from the command line: + +.. code-block:: bash + + python examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py \ + audio_file=/path/to/audio \ + s2s.model_path=/path/to/checkpoint \ + s2s.speaker_name="" \ + s2s.engine_type=native \ + s2s.system_prompt="You are a helpful assistant." \ + streaming.chunk_size_in_secs=0.24 \ + streaming.buffer_size_in_secs=1.68 + +See :doc:`streaming_inference` for full details on configuration, architecture, +and server integration. Training a Model ---------------- @@ -341,3 +369,4 @@ For more information, see additional sections in the SpeechLM2 docs: datasets configs training_and_scaling + streaming_inference diff --git a/docs/source/speechlm2/streaming_inference.rst b/docs/source/speechlm2/streaming_inference.rst new file mode 100644 index 000000000000..2313cd553315 --- /dev/null +++ b/docs/source/speechlm2/streaming_inference.rst @@ -0,0 +1,304 @@ +Streaming Inference +=================== + +The speechlm2 collection provides a streaming inference pipeline for +NemotronVoiceChat that processes audio in real time, chunk by chunk, and +produces both text and speech output incrementally. The pipeline follows the +same methodology as the NeMo ASR Inference Pipelines (see +``nemo.collections.asr.inference``). + +Overview +-------- + +The streaming inference stack has four layers: + +.. 
code-block:: text + + Entry Script s2s_streaming_infer.py (Hydra) + │ + ▼ + Pipeline StreamingS2SPipeline + │ - audio buffering + │ - state management + │ - file I/O + ▼ + Model Wrapper NemotronVoicechatInferenceWrapper + │ - infer_one_step() + │ - perception / LLM / TTS / codec + ▼ + Model NemotronVoiceChat + - DuplexSTTModel + DuplexEARTTS + +Quick Start +----------- + +Batch Inference from a Script +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The simplest way to run streaming inference is with the provided Hydra script: + +.. code-block:: bash + + python examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py \ + --config-path=examples/speechlm2/nemo_inference_pipelines/conf \ + --config-name=s2s_streaming \ + audio_file=/path/to/audio_or_directory_or_manifest.json \ + output_dir=./generated \ + s2s.model_path=/path/to/checkpoint \ + s2s.speaker_name="" \ + s2s.engine_type=native \ + s2s.system_prompt="You are a helpful assistant." \ + streaming.chunk_size_in_secs=0.24 \ + streaming.buffer_size_in_secs=1.68 + +This will: + +1. Load the NemotronVoiceChat checkpoint. +2. Stream each audio file through the pipeline in chunks. +3. Save generated ``.wav``, stereo (input+output), and ``.txt`` files under + ``output_dir``. + +Programmatic Usage +^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from nemo.collections.speechlm2.inference import S2SPipelineBuilder + + pipeline = S2SPipelineBuilder.build_pipeline(cfg) + output = pipeline.run(audio_filepaths, options=options) + + # output.texts -- generated agent text per file + # output.asr_texts -- recognized user text per file + # output.audio_filepaths -- paths to generated .wav files + + +Architecture +------------ + +The Core Loop +^^^^^^^^^^^^^ + +Like the ASR pipeline's ``BasePipeline.run()``, the S2S pipeline iterates +over chunks and calls a single step method: + +.. 
code-block:: python + + pipeline.open_session() + for frames in streamer: + pipeline.generate_step(frames) + pipeline.close_session() + return PipelineOutput(...) + +``generate_step()`` is the unified entry point used by **both** the batch +``run()`` method and server deployments. + + +What Happens Inside One Step +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Each call to ``generate_step(frames)`` performs: + +1. **Prefill detection** -- A zero-length first frame with a system prompt + triggers ``prefill_for_new_stream()``, which initializes the LLM KV cache + with the system prompt and the TTS speaker embedding. + +2. **Audio buffering** -- ``BatchedAudioBufferer`` (reused from ASR + infrastructure) maintains a sliding window of ``buffer_size_in_secs``. + +3. **Model inference** via ``infer_one_step(audio_buffer, state)``: + + a. **Perception** -- The audio buffer is encoded by the streaming + FastConformer encoder into frame embeddings. + b. **Per-frame LLM loop** -- For each of the ``num_frames_per_chunk`` + frames, the pipeline builds an input embedding (user audio + + previous-step text/ASR tokens), runs it through the LLM, and obtains + predicted text and ASR tokens. + c. **Per-frame TTS** -- Each predicted text token is fed into the EarTTS + model to produce audio codec codes. + d. **Codec decode** -- The accumulated codes are decoded into a waveform. + +4. **State updates** -- The context manager advances ``frame_idx`` and + updates the subword mask. + +5. **Output accumulation** -- Decoded audio and text are appended to the + per-stream ``S2SStreamingState``. + + +Two Kinds of State +^^^^^^^^^^^^^^^^^^ + +The pipeline maintains two separate state objects per stream: + +**StreamingDecodeState** (model level) + Lives in ``S2SContextManager`` slots. Contains LLM KV cache, TTS KV + cache, perception cache, codec cache, token workspaces (``gen_text``, + ``gen_asr_text``), and ``frame_idx``. 
Created by the wrapper, mutated + in-place by ``infer_one_step()``, destroyed at end-of-stream. + +**S2SStreamingState** (pipeline level) + Lives in the pipeline's ``_state_pool``. Accumulates generated audio + chunks, text strings, and word timings across steps. Kept alive until + ``close_session()`` so the final ``PipelineOutput`` can be assembled. + + +Configuration +------------- + +The streaming inference configuration is defined in +``examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml``. + +Key configuration groups: + +S2S Model Settings (``s2s``) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. list-table:: + :header-rows: 1 + :widths: 30 15 55 + + * - Parameter + - Default + - Description + * - ``model_path`` + - (required) + - Path to the NemotronVoiceChat HuggingFace checkpoint. + * - ``engine_type`` + - (required) + - ``native``, ``vllm_llm``, ``vllm_eartts``, or + ``vllm_llm_vllm_eartts``. + * - ``speaker_name`` + - ``null`` + - Registered speaker name (must match a speaker in the checkpoint). + * - ``system_prompt`` + - (required) + - Text injected into the LLM KV cache before audio streaming begins. + * - ``compute_dtype`` + - ``bfloat16`` + - Precision for LLM/embedding layers. + * - ``use_perception_cache`` + - ``true`` + - Cache-aware streaming for the perception encoder. + * - ``use_llm_cache`` + - ``true`` + - Use KV cache for incremental LLM decoding. + * - ``top_p`` + - ``0.5`` + - Top-p sampling threshold. + * - ``temperature`` + - ``0.3`` + - Sampling temperature. + * - ``deterministic`` + - ``false`` + - Force deterministic mode (native engine only). + * - ``profile_timing`` + - ``false`` + - Insert ``torch.cuda.synchronize()`` around each stage for accurate + per-stage timing. Disabled by default to avoid GPU stalls. + +Streaming Settings (``streaming``) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
list-table:: + :header-rows: 1 + :widths: 30 15 55 + + * - Parameter + - Default + - Description + * - ``chunk_size_in_secs`` + - (required) + - Audio processed per inference step. Must be a multiple of 0.08 s. + * - ``buffer_size_in_secs`` + - (required) + - Sliding-window size passed to the perception encoder. + * - ``batch_size`` + - ``1`` + - Number of concurrent streams (currently only 1 is supported). + * - ``max_len`` + - ``8192`` + - Maximum number of frames per stream. + +Padding Settings (top-level) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +At most one of these may be set: + +.. list-table:: + :header-rows: 1 + :widths: 30 15 55 + + * - Parameter + - Default + - Description + * - ``pad_audio_to_sec`` + - ``null`` + - Pad each input to a fixed duration. + * - ``pad_silence_ratio`` + - ``null`` + - Append silence equal to this fraction of the original duration. + * - ``pad_audio_by_sec`` + - ``null`` + - Append a fixed number of extra seconds of silence. + + +Server Integration +------------------ + +The same ``generate_step()`` method used by ``run()`` can be called directly +from a custom server. The zero-length Frame protocol handles prefill: + +.. code-block:: python + + # 1. Prefill system prompt (zero-length frame) + prefill_frame = Frame( + samples=torch.empty(0), + stream_id=stream_id, + is_first=True, is_last=False, + options=S2SRequestOptions(system_prompt=prompt), + ) + pipeline.generate_step([prefill_frame]) + + # 2. Stream audio chunks + for chunk in audio_source: + frame = Frame( + samples=chunk, + stream_id=stream_id, + is_first=(i == 0), is_last=(i == last), + ) + pipeline.generate_step([frame]) + + +Batch Size +---------- + +The pipeline currently supports ``batch_size=1`` (one stream at a time). + + +File Layout +----------- + +.. 
code-block:: text + + nemo/collections/speechlm2/inference/ + ├── __init__.py # Public exports + ├── factory/ + │ └── s2s_pipeline_builder.py # S2SPipelineBuilder + ├── pipelines/ + │ ├── s2s_pipeline_interface.py # Base: _state_pool, sessions + │ └── streaming_s2s_pipeline.py # StreamingS2SPipeline + ├── model_wrappers/ + │ ├── decode_state.py # StreamingDecodeState, InferenceStepResult + │ ├── nemotron_voicechat_inference_wrapper.py + │ ├── model_factory.py # Native / vLLM model interfaces + │ └── perception_cache.py # Perception cache + CUDA graphs + ├── streaming/ + │ ├── framing/ + │ │ └── s2s_request_options.py # S2SRequestOptions + │ └── state/ + │ ├── s2s_state.py # S2SStreamingState + │ └── s2s_context_manager.py # Slot-based decode-state lifecycle + ├── utils/ + │ ├── pipeline_utils.py # PipelineOutput, text helpers + │ └── audio_data.py # Manifest / folder loading + └── vllm/ # Optional vLLM engine backend diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index f04e438bb375..38fd7688ae40 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -23,7 +23,7 @@ import torchaudio from omegaconf import OmegaConf, DictConfig -from nemo.utils import logging +from nemo.utils import logging, str_to_dtype from transformers import DynamicCache from nemo.collections.speechlm2.models.nemotron_voicechat import NemotronVoiceChat @@ -109,13 +109,17 @@ def __init__(self, model_cfg: DictConfig): self.tts_system_prompt = model_cfg.get("tts_system_prompt", None) logging.info(f"TTS system prompt: {self.tts_system_prompt}") - compute_dtype = model_cfg.get("compute_dtype", "bfloat16") - self.dtype = self._resolve_dtype(compute_dtype) + self.dtype = 
str_to_dtype(model_cfg.get("compute_dtype", "bfloat16")) - self.device = self._resolve_device( - device=model_cfg.get("device"), - device_id=model_cfg.get("device_id"), - ) + device = model_cfg.get("device") + device_id = model_cfg.get("device_id") + if device is None: + self.device = DEFAULT_DEVICE + else: + device_str = str(device) + if device_id is not None and device_str.startswith("cuda") and ":" not in device_str: + device_str = f"{device_str}:{device_id}" + self.device = torch.device(device_str) logging.info("=" * 70) logging.info("INITIALIZING REALTIME STREAMING INFERENCE") @@ -130,6 +134,11 @@ def __init__(self, model_cfg: DictConfig): logging.info(f"Precision (effective): float32_matmul_precision={torch.get_float32_matmul_precision()}, cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}, cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}") logging.info("=" * 70) + # Profiling: when True, insert torch.cuda.synchronize() around each + # stage for accurate per-stage wall-clock timing. Disabled by default + # to avoid unnecessary GPU stalls in production. 
+ self._profile_timing = bool(model_cfg.get("profile_timing", False)) + # Cached TTS helpers populated during initialization/warmup self.first_context_subword_id = None self.generation_config = None @@ -173,63 +182,6 @@ def __init__(self, model_cfg: DictConfig): logging.info("NemotronVoicechatInferenceWrapper initialized successfully.") - @staticmethod - def _resolve_dtype(compute_dtype): - if isinstance(compute_dtype, torch.dtype): - return compute_dtype - if compute_dtype is None: - return torch.bfloat16 - if isinstance(compute_dtype, str): - key = compute_dtype.lower() - mapping = { - "bfloat16": torch.bfloat16, - "bf16": torch.bfloat16, - "float16": torch.float16, - "fp16": torch.float16, - "half": torch.float16, - "float32": torch.float32, - "fp32": torch.float32, - "full": torch.float32, - } - if key in mapping: - return mapping[key] - raise ValueError(f"Unsupported compute_dtype: {compute_dtype}") - - @staticmethod - def _resolve_device(device=None, device_id=None): - if isinstance(device, torch.device): - resolved_device = device - else: - if device is None: - resolved_device = DEFAULT_DEVICE - else: - device_str = str(device) - base = device_str - if device_id is not None and device_str.startswith("cuda") and ":" not in device_str: - base = f"{device_str}:{device_id}" - resolved_device = torch.device(base) - return resolved_device - - def _samples_per_audio_output_frame(self): - rate = getattr(self, "target_sample_rate", None) - if rate is None: - cfg_rate = None - try: - cfg_rate = self.model_cfg.get("tts_sample_rate", None) - except Exception: - cfg_rate = None - if cfg_rate is None: - try: - cfg_rate = self.model_cfg.get("output_sample_rate", None) - except Exception: - cfg_rate = None - if cfg_rate is not None: - rate = float(cfg_rate) - if rate is None: - rate = TTS_SAMPLE_RATE - samples = int(float(rate) * FRAME_SIZE_SEC) - return samples - def _initialize_model(self): """Initialize the NemotronVoiceChat model from an HF checkpoint.""" 
logging.info("Initializing model structure...") @@ -596,220 +548,80 @@ def infer_one_step( frame_idx = state.frame_idx start_time_one_step = time.time() - use_cache = state.llm_cache is not None - batch_size = state.gen_text.shape[0] + use_llm_cache = state.llm_cache is not None + B = state.gen_text.shape[0] + + predicted_tokens = torch.empty((B, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) + asr_predicted_tokens = torch.empty((B, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) + function_predicted_tokens = None + if self.model.stt_model.function_head is not None: + function_predicted_tokens = torch.empty((B, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) - predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) - asr_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) - function_predicted_tokens = torch.empty((batch_size, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) - debug_text_logits = [] - debug_asr_logits = [] - debug_input_embeds = [] - selected_frame_indices = [] + debug_text_logits, debug_asr_logits, debug_input_embeds, selected_frame_indices = [], [], [], [] # --- Stage 1: Perception --- source_encoded, state.perception_cache = self._run_perception( audio_input, frame_idx, num_frames_per_chunk, state.perception_cache, ) total_encoded_frames = source_encoded.shape[1] - if self.use_perception_cache and state.perception_cache is not None and state.perception_cache.is_initialized(): - # With cache: we get exactly num_frames_per_chunk output frames — use all directly + # With cache: we get exactly num_frames_per_chunk output frames base_frame_index = 0 else: - # Without cache: Use the second-to-last encoded frame (-2) as the "newest" frame embedding. 
- # This is because the model expects the chunk sizes to be size 10ms, 80ms, 80ms, 80ms, ...., - # but we pass in always 80ms, 80ms, 80ms.... - # e.g. - # (1) if we pass in just one 80ms chunk -> the model treats it as 10ms, then 70ms with 10ms silence padding at the end. - # (2) if we pass 80ms, 80ms -> the model treats it as 10ms, 80ms, 70ms with 10ms silence padding at the end. - # => we do not want to use the final embedding due to containing silence padding. We want to use the second-to-last embedding. - newest_frame_index = total_encoded_frames - 2 - base_frame_index = max(newest_frame_index - (num_frames_per_chunk - 1), 0) + # Without cache: use the second-to-last encoded frame as the + # "newest" because the model expects 10ms / 80ms / 80ms ... framing + # but we always feed 80ms chunks, so the final frame contains + # silence padding. + newest = total_encoded_frames - 2 + base_frame_index = max(newest - (num_frames_per_chunk - 1), 0) # --- Stage 2: Per-frame generation loop --- new_input_embeds = [] new_codes_for_decode = [] for frame_offset in range(num_frames_per_chunk): current_frame_idx = frame_idx + frame_offset - current_frame_index = min(base_frame_index + frame_offset, total_encoded_frames - 1) + current_frame_index = min(base_frame_index + frame_offset, source_encoded.shape[1] - 1) selected_frame_indices.append(current_frame_index) - current_frame_embedding = source_encoded[:, current_frame_index:current_frame_index + 1, :] + frame_embedding = source_encoded[:, current_frame_index:current_frame_index + 1, :] - current_input_emb = current_frame_embedding.clone() - current_input_emb *= self.model.stt_model.cfg.get("duplex_user_channel_weight", 1.0) - - has_fc = state.gen_function_text is not None - - if current_frame_idx == 0 and not has_prompt: - # Only add BOS if there's no prompt (BOS is already in prompt's position 0) - current_input_emb += self._get_bos_embedding() * self.model.stt_model.cfg.get( - "duplex_text_channel_weight", 1.0 - ) - 
current_input_emb += self._get_asr_bos_embedding() * self.model.stt_model.cfg.get( - "duplex_asr_text_weight", 1.0 - ) - if has_fc: - pad_id = self.model.stt_model.text_pad_id - fc_pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) - current_input_emb += self.model.stt_model.embed_tokens(fc_pad_token).to(dtype=self.dtype) - elif current_frame_idx == 0 and has_prompt: - # With prompt: first audio frame uses pad embedding (like offline_inference) - pad_id = self.model.stt_model.text_pad_id - pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) - pad_emb = self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) - pad_asr_emb = self.model.stt_model.embed_asr_tokens(pad_token).to(dtype=self.dtype) - current_input_emb += pad_emb - current_input_emb += pad_asr_emb - if has_fc: - current_input_emb += self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) - else: - # t > 0: add embeddings from model's own predictions at t-1 - last_token_emb = self.model.stt_model.embed_tokens( - state.gen_text[:, current_frame_idx - 1] - ) * self.model.stt_model.cfg.get("duplex_text_channel_weight", 1.0) - last_asr_token_emb = self.model.stt_model.embed_asr_tokens( - state.gen_asr_text[:, current_frame_idx - 1] - ) * self.model.stt_model.cfg.get("duplex_asr_text_weight", 1.0) - current_input_emb += last_token_emb + last_asr_token_emb - if has_fc: - last_fc_token_emb = self.model.stt_model.embed_tokens(state.gen_function_text[:, current_frame_idx - 1]) - current_input_emb += last_fc_token_emb.to(dtype=self.dtype) + input_emb = self._build_input_embedding( + frame_embedding, current_frame_idx, state, has_prompt, + ) if return_debug: - debug_input_embeds.append(current_input_emb.detach().cpu()) - - start_stt_model = time.time() - - if use_cache or self.use_vllm_llm: - if self.use_vllm_llm: - ans = self.model_llm_interface( - current_input_emb, - request_id=effective_request_id, - 
generated_tokens=state.gen_text, - current_step=current_frame_idx - ) - else: - cache_pos = torch.tensor( - [state.llm_cache_position_offset + frame_offset], device=self.device - ) - ans = self.model_llm_interface( - current_input_emb, - cache=state.llm_cache, - cache_position=cache_pos, - generated_tokens=state.gen_text, - current_step=current_frame_idx, - return_logits=return_debug, - ) - state.llm_cache = ans["cache"] - else: - new_input_embeds.append(current_input_emb) - full_input_embeds = torch.cat(state.input_embeds_history + new_input_embeds, dim=1) - ans = self.model_llm_interface( - full_input_embeds, - cache=None, - generated_tokens=state.gen_text, - current_step=current_frame_idx, - return_logits=return_debug, - ) + debug_input_embeds.append(input_emb.detach().cpu()) - torch.cuda.synchronize() - time_stt_model = time.time() - start_stt_model - logging.info(f"Time taken for stt_model: {time_stt_model:.3f}s") + ans = self._run_llm_step( + input_emb, state, frame_offset, effective_request_id, + current_frame_idx, use_llm_cache, return_debug, new_input_embeds, + ) - predicted_token = ans["predicted_token"] - asr_predicted_token = ans["asr_predicted_token"] if return_debug and "text_logits" in ans: debug_text_logits.append(ans["text_logits"][:, -1].detach().cpu()) if return_debug and "asr_logits" in ans and ans["asr_logits"] is not None: debug_asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) - state.gen_text[:, current_frame_idx] = predicted_token - predicted_tokens[:, frame_offset] = predicted_token - - state.gen_asr_text[:, current_frame_idx] = asr_predicted_token - asr_predicted_tokens[:, frame_offset] = asr_predicted_token + state.gen_text[:, current_frame_idx] = ans["predicted_token"] + predicted_tokens[:, frame_offset] = ans["predicted_token"] + state.gen_asr_text[:, current_frame_idx] = ans["asr_predicted_token"] + asr_predicted_tokens[:, frame_offset] = ans["asr_predicted_token"] if "function_predicted_token" in ans: - 
function_predicted_tokens[:, frame_offset] = ans["function_predicted_token"] + if function_predicted_tokens is not None: + function_predicted_tokens[:, frame_offset] = ans["function_predicted_token"] if state.gen_function_text is not None: state.gen_function_text[:, current_frame_idx] = ans["function_predicted_token"] - # Apply forced turn taking based on ASR results self._maybe_apply_forced_turn_taking(current_frame_idx, state.gen_text, state.gen_asr_text) - # Update predicted_tokens with any changes made by forced turn taking predicted_tokens[:, frame_offset] = state.gen_text[:, current_frame_idx] if self.decode_audio: - current_subword_id = state.gen_text[:, current_frame_idx].unsqueeze(-1) - - if current_frame_idx == 0: - if self.first_context_subword_id is None: - raise RuntimeError("first_context_subword_id is not initialized. Ensure TTS warmup ran successfully.") - prev_subword_id = self.first_context_subword_id - else: - prev_subword_id = state.gen_text[:, current_frame_idx-1].unsqueeze(-1) - - current_subword_mask = state.subword_mask[:, current_frame_idx].unsqueeze(-1) - - if self.generation_config is None: - raise RuntimeError("generation_config is not initialized. 
Ensure TTS warmup ran successfully.") - - start_tts_model = time.time() - inputs = { - "current_subword_id": current_subword_id, - "prev_subword_id": prev_subword_id, - "current_subword_mask": current_subword_mask, - "prev_audio_tokens": state.tts_code, - "past_key_values": state.tts_past_key_values, - "guidance_enabled": True, - "generation_config": self.generation_config, - "ignore_eos_flag_stop": True, - } - if self.use_vllm_eartts: - inputs["request_id"] = effective_request_id - - state.tts_code, state.tts_past_key_values = self.model.tts_model.infer_codes_one_step(**inputs) - - torch.cuda.synchronize() - time_tts_model = time.time() - start_tts_model - logging.info(f"Time taken for tts_model: {time_tts_model:.3f}s") - - new_codes_for_decode.append(state.tts_code.clone()) - - # Potentially overwrite the audio token with silence tokens (for feeding to the audio token predictor) - if self.model.cfg.get('inference_force_speech_silence_on_eos', None): - silence_codes = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand(state.tts_code.shape) - state.tts_code = torch.where( - current_subword_id.unsqueeze(-1) == self.model.tts_model.text_eos_id, - silence_codes, - state.tts_code, - ) - - # --- Stage 3: Audio decode --- - decoded_audio_new = None - if self.decode_audio: - logging.info(f"\nDecoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") - - start_time_decode = time.time() - with fp32_precision(), torch.no_grad(): - new_codes_tensor = torch.cat(new_codes_for_decode, dim=1) - if hasattr(self.model.tts_model, '_control_codes'): - from nemo.collections.speechlm2.models.duplex_ear_tts import replace_control_speech_codes - new_codes_tensor = replace_control_speech_codes( - new_codes_tensor, - self.model.tts_model._control_codes, - getattr(self.model.tts_model, 'codec_silence_tokens', None), - ) - new_code_len = torch.tensor( - [new_codes_tensor.shape[1]], dtype=torch.long, device=self.device - ) - decoded_audio_new, _ = 
self.model.tts_model.audio_codec.decode( - new_codes_tensor, new_code_len, cache=state.tts_codec_cache, + new_code = self._run_tts_step( + state, current_frame_idx, effective_request_id, ) + new_codes_for_decode.append(new_code) - torch.cuda.synchronize() - time_audio_codec = time.time() - start_time_decode - logging.info(f"Time taken for audio_codec: {time_audio_codec:.3f}s") + # --- Stage 3: Audio decode --- + decoded_audio_new = self._decode_audio(new_codes_for_decode, state, frame_idx, num_frames_per_chunk) # --- Stage 4: Token -> string conversion --- predicted_text_strs = self._tokens_to_strings(predicted_tokens) @@ -819,14 +631,15 @@ def infer_one_step( logging.info(f'frame {frame_idx}: AGENT txt: {predicted_text_strs}') # --- Update remaining state fields --- - if not use_cache: + if not use_llm_cache: state.input_embeds_history = state.input_embeds_history + new_input_embeds - if use_cache: + if use_llm_cache: state.llm_cache_position_offset += num_frames_per_chunk - torch.cuda.synchronize() - time_for_one_step = time.time() - start_time_one_step - logging.info(f'frame {frame_idx}: Time taken for one step: {time_for_one_step:.3f}s') + if self._profile_timing: + torch.cuda.synchronize() + time_for_one_step = time.time() - start_time_one_step + logging.info(f'frame {frame_idx}: Time taken for one step: {time_for_one_step:.3f}s') debug = None if return_debug: @@ -840,17 +653,230 @@ def infer_one_step( "asr_logits": torch.stack(debug_asr_logits, dim=1) if debug_asr_logits else None, } - func_tokens = function_predicted_tokens if self.model.stt_model.function_head is not None else None return InferenceStepResult( predicted_text_tokens=predicted_tokens, asr_predicted_text_tokens=asr_predicted_tokens, predicted_text_strs=predicted_text_strs, asr_predicted_text_strs=asr_predicted_text_strs, decoded_audio=decoded_audio_new, - function_predicted_text_tokens=func_tokens, + function_predicted_text_tokens=function_predicted_tokens, debug=debug, ) + # 
------------------------------------------------------------------ + # infer_one_step sub-stages + # ------------------------------------------------------------------ + + def _build_input_embedding( + self, + frame_embedding: torch.Tensor, + current_frame_idx: int, + state: StreamingDecodeState, + has_prompt: bool, + ) -> torch.Tensor: + """Compose the LLM input embedding for a single frame. + + Combines the perception embedding (user channel) with the text / + ASR / function-call channel embeddings from the previous step. + At frame 0 this is either BOS (no prompt) or pad (after prompt). + + IMPORTANT: The arithmetic order here must match offline_inference + exactly (floating-point addition is not associative). For t > 0 + the text and ASR embeddings are summed first, then added to the + perception embedding in a single ``+=``. For t == 0 the sequential + ``+=`` pattern matches the offline path. + """ + stt = self.model.stt_model + emb = frame_embedding.clone() + emb *= stt.cfg.get("duplex_user_channel_weight", 1.0) + + has_fc = state.gen_function_text is not None + + if current_frame_idx == 0 and not has_prompt: + emb += self._get_bos_embedding() * stt.cfg.get("duplex_text_channel_weight", 1.0) + emb += self._get_asr_bos_embedding() * stt.cfg.get("duplex_asr_text_weight", 1.0) + if has_fc: + pad_token = torch.full((1,), fill_value=stt.text_pad_id, device=self.device, dtype=torch.long) + emb += stt.embed_tokens(pad_token).to(dtype=self.dtype) + + elif current_frame_idx == 0 and has_prompt: + pad_token = torch.full((1,), fill_value=stt.text_pad_id, device=self.device, dtype=torch.long) + emb += stt.embed_tokens(pad_token).to(dtype=self.dtype) + emb += stt.embed_asr_tokens(pad_token).to(dtype=self.dtype) + if has_fc: + emb += stt.embed_tokens(pad_token).to(dtype=self.dtype) + + else: + # Sum text + ASR first, then add once (must match offline operation order) + prev = current_frame_idx - 1 + last_token_emb = stt.embed_tokens( + state.gen_text[:, prev] + ) * 
stt.cfg.get("duplex_text_channel_weight", 1.0) + last_asr_token_emb = stt.embed_asr_tokens( + state.gen_asr_text[:, prev] + ) * stt.cfg.get("duplex_asr_text_weight", 1.0) + emb += last_token_emb + last_asr_token_emb + if has_fc: + emb += stt.embed_tokens(state.gen_function_text[:, prev]).to(dtype=self.dtype) + + return emb + + def _run_llm_step( + self, + input_emb: torch.Tensor, + state: StreamingDecodeState, + frame_offset: int, + request_id: str, + current_frame_idx: int, + use_llm_cache: bool, + return_debug: bool, + new_input_embeds: list, + ) -> dict: + """Run one LLM forward pass (native cache, vLLM, or full-history). + + Updates ``state.llm_cache`` in-place for cached paths. For the + no-cache fallback, appends to *new_input_embeds* (list, mutated). + """ + start_stt_model = time.time() + + if use_llm_cache or self.use_vllm_llm: + if self.use_vllm_llm: + ans = self.model_llm_interface( + input_emb, + request_id=request_id, + generated_tokens=state.gen_text, + current_step=current_frame_idx, + ) + else: + cache_pos = torch.tensor( + [state.llm_cache_position_offset + frame_offset], device=self.device, + ) + ans = self.model_llm_interface( + input_emb, + cache=state.llm_cache, + cache_position=cache_pos, + generated_tokens=state.gen_text, + current_step=current_frame_idx, + return_logits=return_debug, + ) + state.llm_cache = ans["cache"] + else: + new_input_embeds.append(input_emb) + full_input_embeds = torch.cat(state.input_embeds_history + new_input_embeds, dim=1) + ans = self.model_llm_interface( + full_input_embeds, + cache=None, + generated_tokens=state.gen_text, + current_step=current_frame_idx, + return_logits=return_debug, + ) + + if self._profile_timing: + torch.cuda.synchronize() + time_stt_model = time.time() - start_stt_model + logging.info(f"Time taken for stt_model: {time_stt_model:.3f}s") + + return ans + + def _run_tts_step( + self, + state: StreamingDecodeState, + current_frame_idx: int, + request_id: str, + ) -> torch.Tensor: + """Run one 
TTS code-generation step. + + Mutates ``state.tts_code`` and ``state.tts_past_key_values`` in-place. + Returns the new code tensor (cloned) for later batch decoding. + """ + current_subword_id = state.gen_text[:, current_frame_idx].unsqueeze(-1) + + if current_frame_idx == 0: + if self.first_context_subword_id is None: + raise RuntimeError("first_context_subword_id is not initialized. Ensure TTS warmup ran successfully.") + prev_subword_id = self.first_context_subword_id + else: + prev_subword_id = state.gen_text[:, current_frame_idx - 1].unsqueeze(-1) + + current_subword_mask = state.subword_mask[:, current_frame_idx].unsqueeze(-1) + + if self.generation_config is None: + raise RuntimeError("generation_config is not initialized. Ensure TTS warmup ran successfully.") + + start_tts_model = time.time() + inputs = { + "current_subword_id": current_subword_id, + "prev_subword_id": prev_subword_id, + "current_subword_mask": current_subword_mask, + "prev_audio_tokens": state.tts_code, + "past_key_values": state.tts_past_key_values, + "guidance_enabled": True, + "generation_config": self.generation_config, + "ignore_eos_flag_stop": True, + } + if self.use_vllm_eartts: + inputs["request_id"] = request_id + + state.tts_code, state.tts_past_key_values = self.model.tts_model.infer_codes_one_step(**inputs) + + if self._profile_timing: + torch.cuda.synchronize() + time_tts_model = time.time() - start_tts_model + logging.info(f"Time taken for tts_model: {time_tts_model:.3f}s") + + new_code = state.tts_code.clone() + + if self.model.cfg.get('inference_force_speech_silence_on_eos', None): + silence_codes = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand(state.tts_code.shape) + state.tts_code = torch.where( + current_subword_id.unsqueeze(-1) == self.model.tts_model.text_eos_id, + silence_codes, + state.tts_code, + ) + + return new_code + + def _decode_audio( + self, + new_codes_for_decode: list[torch.Tensor], + state: StreamingDecodeState, + frame_idx: int, + 
num_frames_per_chunk: int, + ) -> Optional[torch.Tensor]: + """Decode accumulated TTS codes into a waveform. + + Returns the decoded audio tensor or *None* when ``decode_audio`` + is disabled or no codes were produced. + """ + if not self.decode_audio or not new_codes_for_decode: + return None + + logging.info(f"Decoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") + + start_time_decode = time.time() + with fp32_precision(), torch.no_grad(): + new_codes_tensor = torch.cat(new_codes_for_decode, dim=1) + if hasattr(self.model.tts_model, '_control_codes'): + from nemo.collections.speechlm2.models.duplex_ear_tts import replace_control_speech_codes + new_codes_tensor = replace_control_speech_codes( + new_codes_tensor, + self.model.tts_model._control_codes, + getattr(self.model.tts_model, 'codec_silence_tokens', None), + ) + new_code_len = torch.tensor( + [new_codes_tensor.shape[1]], dtype=torch.long, device=self.device, + ) + decoded_audio, _ = self.model.tts_model.audio_codec.decode( + new_codes_tensor, new_code_len, cache=state.tts_codec_cache, + ) + + if self._profile_timing: + torch.cuda.synchronize() + time_audio_codec = time.time() - start_time_decode + logging.info(f"Time taken for audio_codec: {time_audio_codec:.3f}s") + + return decoded_audio + def _run_perception( self, audio_input: torch.Tensor, @@ -876,9 +902,11 @@ def _run_perception( return_encoder_emb=True, ) - torch.cuda.synchronize() - time_perception = time.time() - start_perception - logging.info(f"Time taken for perception: {time_perception:.3f}s") + if self._profile_timing: + torch.cuda.synchronize() + time_perception = time.time() - start_perception + logging.info(f"Time taken for perception: {time_perception:.3f}s") + source_encoded = source_encoded.to(self.dtype) return source_encoded, perception_cache diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 
6f931643b1a6..11b6dfd1f835 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -50,6 +50,7 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper # ------------------------------------------------------------------ self.s2s_model = s2s_model self.device = self.s2s_model.device + self.decode_audio = self.s2s_model.decode_audio self.collect_debug = False # ------------------------------------------------------------------ @@ -65,7 +66,6 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper "StreamingS2SPipeline currently supports only single-stream inference " "(streaming.batch_size must be 1)." ) - # ------------------------------------------------------------------ # Chunk & buffer sizes @@ -136,7 +136,7 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper super().__init__() - # -------------------------------- ---------------------------------- + # ------------------------------------------------------------------ # State helpers # ------------------------------------------------------------------ def create_state(self) -> S2SStreamingState: @@ -177,7 +177,7 @@ def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[ state.append_step_output(sample_audio, text=piece, asr_text=asr_piece) - def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_paddings: List[int], ready_feats: List[bool]): + def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], ready_feats: List[bool]): """Generate speech for chunks in *batch* using a shared ContextManager.""" if len(frames) == 0: return @@ -217,11 +217,6 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_p if audio_buffer.dim() == 1: audio_buffer = audio_buffer.unsqueeze(0) audio_buffer = audio_buffer.to(self.s2s_model.device, 
dtype=self.s2s_model.dtype) - - # Trim the buffer to exclude left padding (zeros at the beginning before buffer is filled) - left_pad = left_paddings[0] - if left_pad > 0: - audio_buffer = audio_buffer[:, left_pad:] result = self.s2s_model.infer_one_step( audio_input=audio_buffer, @@ -270,7 +265,7 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], left_p def prefill_for_new_stream(self, stream_id: int, system_prompt: str | None = None) -> bool: """Prepare the pipeline for a new stream by resetting context and prefilling the system prompt. - This is the public API for prefill-only calls (e.g. from the Triton backend) + This is the public API for prefill-only calls (e.g. from a server backend) that need to initialize TTS speaker embeddings and/or inject a system prompt into the LLM KV cache *without* processing any audio. @@ -349,7 +344,7 @@ def generate_step(self, frames: List[Frame]): prompt in ``options``, this is treated as a **prefill-only** request: the context manager and system prompt are initialized but no audio inference runs. This is the unified protocol used by both the CLI - (``run()``) and the Triton backend. + (``run()``) and server backends. """ # Detect prefill-only frame: is_first + zero-length audio if (len(frames) == 1 @@ -363,10 +358,13 @@ def generate_step(self, frames: List[Frame]): return buffers, left_paddings = self.bufferer.update(frames) + # This is a workaround for the fact that the audio buffer does left + # padding, but the rest of the code requires no padding at all. 
+ buffers = [b[lp:] for b, lp in zip(buffers, left_paddings)] ready_feats = [True] * len(frames) with torch.no_grad(), torch.inference_mode(): - self.inner_generate_step(frames, buffers, left_paddings, ready_feats) + self.inner_generate_step(frames, buffers, ready_feats) # ------------------------------------------------------------------ # Finalization helpers @@ -377,66 +375,63 @@ def _finalize_and_save_finished_streams( audio_filepaths: List[str], saved_paths_by_stream: dict[int, str], ) -> None: - """Finalize any streams that ended in this batch and save their audio.""" + """Finalize any streams that ended in this batch and save their outputs.""" for frame in frames: if frame.is_last: stream_id = frame.stream_id state = self.get_or_create_state(stream_id) - # Flush remaining buffered samples and assemble waveform if hasattr(state, "finalize"): state.finalize() - # Concatenate emitted chunks and squeeze (B=1,C=1) to mono waveform - generated_audio = state.audio_buffer - # Ensure 1D mono waveform and float32 dtype for soundfile - if generated_audio.dim() == 3 and generated_audio.size(0) == 1 and generated_audio.size(1) == 1: - generated_audio = generated_audio.squeeze(0).squeeze(0) - elif generated_audio.dim() == 2 and generated_audio.size(0) == 1: - generated_audio = generated_audio.squeeze(0) - generated_audio = generated_audio.to(torch.float32) - - # Build output paths in subdirectories under output_dir + in_path = audio_filepaths[stream_id] base = os.path.splitext(os.path.basename(in_path))[0] - - wav_dir = os.path.join(self.output_dir, "wav") - stereo_dir = os.path.join(self.output_dir, "stereo") txt_dir = os.path.join(self.output_dir, "txt") - os.makedirs(wav_dir, exist_ok=True) - os.makedirs(stereo_dir, exist_ok=True) os.makedirs(txt_dir, exist_ok=True) - out_path = os.path.join(wav_dir, f"{base}.wav") - - # Write audio to disk - if generated_audio.numel() > 0: - sf.write(out_path, generated_audio.detach().cpu().numpy(), self.output_sample_rate) - - # Also 
save a stereo file with input (ch0) and output (ch1) - # Load input with librosa (handles mono conversion and resampling) - input_np, _ = librosa.load(in_path, sr=self.output_sample_rate, mono=True) - input_audio = torch.from_numpy(input_np).to(torch.float32) - gen_cpu = generated_audio.detach().cpu().to(input_audio.dtype) - - # Prepend silence to output channel to account for - # the one-chunk processing delay: the server can't - # produce output until it has received a full input chunk. - delay_samples = int(self.chunk_size_in_secs * self.output_sample_rate) - silence = torch.zeros(delay_samples, dtype=gen_cpu.dtype) - gen_cpu = torch.cat([silence, gen_cpu], dim=-1) - - gen_len = int(gen_cpu.shape[-1]) - in_len = int(input_audio.shape[-1]) - max_len = max(gen_len, in_len) - if in_len < max_len: - input_audio = torch.cat([input_audio, torch.zeros(max_len - in_len, dtype=input_audio.dtype)], dim=-1) - if gen_len < max_len: - gen_cpu = torch.cat([gen_cpu, torch.zeros(max_len - gen_len, dtype=gen_cpu.dtype)], dim=-1) - stereo = torch.stack([input_audio, gen_cpu], dim=0).transpose(0, 1) - stereo_path = os.path.join(stereo_dir, f"{base}_input_output.wav") - sf.write(stereo_path, stereo.detach().cpu().numpy(), self.output_sample_rate) - - # Save accumulated text + out_path = None + if self.decode_audio: + # Squeeze (B=1,C=1) to 1D mono waveform for soundfile + generated_audio = state.audio_buffer + if generated_audio.dim() == 3 and generated_audio.size(0) == 1 and generated_audio.size(1) == 1: + generated_audio = generated_audio.squeeze(0).squeeze(0) + elif generated_audio.dim() == 2 and generated_audio.size(0) == 1: + generated_audio = generated_audio.squeeze(0) + generated_audio = generated_audio.to(torch.float32) + + wav_dir = os.path.join(self.output_dir, "wav") + stereo_dir = os.path.join(self.output_dir, "stereo") + os.makedirs(wav_dir, exist_ok=True) + os.makedirs(stereo_dir, exist_ok=True) + + out_path = os.path.join(wav_dir, f"{base}.wav") + if 
generated_audio.numel() > 0: + sf.write(out_path, generated_audio.detach().cpu().numpy(), self.output_sample_rate) + + # Save a stereo file with input (ch0) and output (ch1) + input_np, _ = librosa.load(in_path, sr=self.output_sample_rate, mono=True) + input_audio = torch.from_numpy(input_np).to(torch.float32) + gen_cpu = generated_audio.detach().cpu().to(input_audio.dtype) + + # Prepend silence to output channel to account for the one-chunk + # processing delay: the server can't produce output until it has + # received a full input chunk. + delay_samples = int(self.chunk_size_in_secs * self.output_sample_rate) + silence = torch.zeros(delay_samples, dtype=gen_cpu.dtype) + gen_cpu = torch.cat([silence, gen_cpu], dim=-1) + + # Pad the shorter channel so both have equal length + gen_len = int(gen_cpu.shape[-1]) + in_len = int(input_audio.shape[-1]) + max_len = max(gen_len, in_len) + if in_len < max_len: + input_audio = torch.cat([input_audio, torch.zeros(max_len - in_len, dtype=input_audio.dtype)], dim=-1) + if gen_len < max_len: + gen_cpu = torch.cat([gen_cpu, torch.zeros(max_len - gen_len, dtype=gen_cpu.dtype)], dim=-1) + stereo = torch.stack([input_audio, gen_cpu], dim=0).transpose(0, 1) + stereo_path = os.path.join(stereo_dir, f"{base}_input_output.wav") + sf.write(stereo_path, stereo.detach().cpu().numpy(), self.output_sample_rate) + text_out = state.output_text_str if isinstance(text_out, str): try: @@ -445,7 +440,6 @@ def _finalize_and_save_finished_streams( except Exception: pass - # Save accumulated ASR text asr_text_out = state.output_asr_text_str if isinstance(asr_text_out, str) and asr_text_out: try: @@ -455,8 +449,8 @@ def _finalize_and_save_finished_streams( pass saved_paths_by_stream[stream_id] = out_path - - # Keep state until outputs are assembled; will be cleared on close_session + # Keep state in _state_pool until _build_pipeline_output; + # it will be cleared on close_session(). 
# ------------------------------------------------------------------ @@ -499,7 +493,6 @@ def run( batch_size=self.batch_size, pad_last_frame=True, ) - streamer.set_audio_filepaths(audio_filepaths, options) streamer.set_progress_bar(progress_bar) @@ -512,68 +505,112 @@ def run( self.open_session() for frames in streamer: - # Unified prefill protocol: if the first frame of a new stream - # carries a system prompt, emit a zero-length prefill frame first. - if (len(frames) == 1 - and frames[0].is_first - and frames[0].options is not None - and hasattr(frames[0].options, "system_prompt") - and frames[0].options.system_prompt): - prefill_frame = Frame( - samples=torch.empty(0), - stream_id=frames[0].stream_id, - is_first=True, - is_last=False, - options=frames[0].options, - ) - self.generate_step([prefill_frame]) - - # If padding is configured, intercept last frames so the - # bufferer/context stay alive for the silence-padding phase. - # Padding is generated immediately (same iteration) to avoid - # the next stream's setup destroying this stream's context. 
- pad_targets: dict[int, float] = {} - if self.pad_audio_to_sec or self.pad_silence_ratio or self.pad_audio_by_sec: - processed_frames = [] - for frame in frames: - if frame.is_last: - elapsed = streamer.elapsed_durations[frame.stream_id] - remaining = self._padding_remaining_secs(elapsed) - if remaining > 0: - processed_frames.append(Frame( - samples=frame.samples, - stream_id=frame.stream_id, - is_first=frame.is_first, - is_last=False, - length=frame.length, - options=frame.options, - )) - pad_targets[frame.stream_id] = remaining - continue - processed_frames.append(frame) - frames = processed_frames - + self._maybe_prefill(frames) + frames, pad_targets = self._apply_padding(frames, streamer) self.generate_step(frames) self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream) + self._generate_silence_padding(pad_targets, chunk_samples, audio_filepaths, saved_paths_by_stream) + + output = self._build_pipeline_output(audio_filepaths, saved_paths_by_stream) + self.close_session() + return output + + # ------------------------------------------------------------------ + # run() helpers + # ------------------------------------------------------------------ + + def _maybe_prefill(self, frames: List[Frame]) -> None: + """If the first frame of a new stream carries a system prompt, emit a + zero-length prefill frame through ``generate_step`` before inference + begins. This is the unified prefill protocol used by both ``run()`` + and server backends. 
+ """ + if (len(frames) == 1 + and frames[0].is_first + and frames[0].options is not None + and hasattr(frames[0].options, "system_prompt") + and frames[0].options.system_prompt): + prefill_frame = Frame( + samples=torch.empty(0), + stream_id=frames[0].stream_id, + is_first=True, + is_last=False, + options=frames[0].options, + ) + self.generate_step([prefill_frame]) + + def _apply_padding( + self, + frames: List[Frame], + streamer: ContinuousBatchedFrameStreamer, + ) -> tuple[List[Frame], dict[int, float]]: + """If padding is configured, intercept last frames so the bufferer and + context stay alive for the silence-padding phase. Returns the + (possibly modified) frames and a dict mapping ``stream_id`` to the + remaining seconds of silence to append. + """ + pad_targets: dict[int, float] = {} + if not (self.pad_audio_to_sec or self.pad_silence_ratio or self.pad_audio_by_sec): + return frames, pad_targets - # Generate silence padding before the next iteration adds a new stream - for stream_id, remaining_secs in pad_targets.items(): - num_pad_frames = max(1, round(remaining_secs / self.chunk_size_in_secs)) - for i in range(num_pad_frames): - is_last = (i == num_pad_frames - 1) - silence_frame = Frame( - samples=torch.zeros(chunk_samples), - stream_id=stream_id, - is_first=False, - is_last=is_last, - length=chunk_samples, + processed_frames = [] + for frame in frames: + if frame.is_last: + elapsed = streamer.elapsed_durations[frame.stream_id] + remaining = self._padding_remaining_secs(elapsed) + if remaining > 0: + processed_frames.append(Frame( + samples=frame.samples, + stream_id=frame.stream_id, + is_first=frame.is_first, + is_last=False, + length=frame.length, + options=frame.options, + )) + pad_targets[frame.stream_id] = remaining + continue + processed_frames.append(frame) + return processed_frames, pad_targets + + def _generate_silence_padding( + self, + pad_targets: dict[int, float], + chunk_samples: int, + audio_filepaths: List[str], + saved_paths_by_stream: 
dict[int, str], + ) -> None: + """Generate silence-padding frames for streams that need them. + + Must run in the same iteration as the real last frame to avoid the next + stream's setup destroying this stream's context. + """ + for stream_id, remaining_secs in pad_targets.items(): + num_pad_frames = max(1, round(remaining_secs / self.chunk_size_in_secs)) + for i in range(num_pad_frames): + is_last = (i == num_pad_frames - 1) + silence_frame = Frame( + samples=torch.zeros(chunk_samples), + stream_id=stream_id, + is_first=False, + is_last=is_last, + length=chunk_samples, + ) + self.generate_step([silence_frame]) + if is_last: + self._finalize_and_save_finished_streams( + [silence_frame], audio_filepaths, saved_paths_by_stream ) - self.generate_step([silence_frame]) - if is_last: - self._finalize_and_save_finished_streams( - [silence_frame], audio_filepaths, saved_paths_by_stream - ) - # Build outputs before closing the session + + def _build_pipeline_output( + self, + audio_filepaths: List[str], + saved_paths_by_stream: dict[int, str], + ) -> PipelineOutput: + """Assemble final ``PipelineOutput`` from accumulated per-stream state. + + Must be called *before* ``close_session()`` since it reads from the + state pool. 
+ """ texts = [] words = [] asr_texts = [] @@ -638,8 +675,6 @@ def run( state = self.get_or_create_state(idx) debug_data.append(getattr(state, "debug_steps", [])) - self.close_session() - return PipelineOutput( texts=texts, words=words, diff --git a/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py index 4bbb222b1149..88012f41c941 100644 --- a/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py +++ b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py @@ -21,7 +21,7 @@ class S2SRequestOptions: Attached to the first ``Frame`` of each stream via the ``options`` field so that the pipeline can read per-stream configuration at the - start of every new audio file / Triton sequence. + start of every new audio stream. """ system_prompt: str | None = None diff --git a/tests/collections/speechlm2/test_offline_incremental_parity.py b/tests/collections/speechlm2/test_offline_incremental_parity.py index 5c6cd66e148e..dd80975c85a6 100644 --- a/tests/collections/speechlm2/test_offline_incremental_parity.py +++ b/tests/collections/speechlm2/test_offline_incremental_parity.py @@ -37,7 +37,6 @@ from __future__ import annotations -import json import math import os import time @@ -60,57 +59,38 @@ from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions from nemo.collections.speechlm2.models import NemotronVoiceChat +_CONF_YAML = os.path.join( + os.path.dirname(__file__), + "../../../examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml", +) + # --------------------------------------------------------------------------- -# Comparison helpers +# Helpers # --------------------------------------------------------------------------- -def compare_logits( - a: torch.Tensor, - b: torch.Tensor, -) -> dict[str, Any]: - """Prefix-aware comparison of two ``(B, T, V)`` logit 
tensors.""" - a = a.detach().cpu().float() - b = b.detach().cpu().float() - prefix_len = min(a.shape[1], b.shape[1]) - if prefix_len == 0: - return {"prefix_len": 0, "match": True, "max_abs_diff": 0.0, "mean_abs_diff": 0.0} - ap, bp = a[:, :prefix_len], b[:, :prefix_len] - diff = (ap - bp).abs() - reduce_dims = tuple(i for i in range(diff.dim()) if i != 1) - per_step_max = diff.amax(dim=reduce_dims) if reduce_dims else diff - first_diff_step = None - nonzero = (per_step_max > 0).nonzero(as_tuple=False) - if nonzero.numel() > 0: - first_diff_step = int(nonzero[0].item()) - return { - "prefix_len": prefix_len, - "match": bool(torch.equal(ap, bp)), - "max_abs_diff": float(diff.max()), - "mean_abs_diff": float(diff.mean()), - "first_diff_step": first_diff_step, - } - - -def compare_tokens( +def _compare_tensors( a: torch.Tensor | None, b: torch.Tensor | None, ) -> dict[str, Any]: - """Compare two ``(B, T)`` token tensors with prefix-aware logic.""" + """Prefix-aware comparison of two tensors (tokens or logits, any shape).""" if a is None or b is None: return {"match": None, "note": "one or both tensors missing"} - a, b = a.detach().cpu(), b.detach().cpu() - prefix_len = min(a.shape[-1], b.shape[-1]) - if prefix_len == 0: + a, b = a.detach().cpu().float(), b.detach().cpu().float() + T = min(a.shape[1], b.shape[1]) + if T == 0: return {"prefix_len": 0, "match": True} - ap, bp = a[..., :prefix_len], b[..., :prefix_len] - match = bool(torch.equal(ap, bp)) - first_diff_index = None + ap, bp = a[:, :T], b[:, :T] + diff = (ap - bp).abs() + match = bool(diff.max() == 0) + result: dict[str, Any] = {"prefix_len": T, "match": match, "max_abs_diff": float(diff.max())} if not match: - diffs = (ap != bp).flatten().nonzero(as_tuple=False) - if diffs.numel() > 0: - first_diff_index = int(diffs[0].item()) - return {"prefix_len": prefix_len, "match": match, "first_diff_index": first_diff_index} + reduce = tuple(i for i in range(diff.dim()) if i != 1) + per_step = diff.amax(dim=reduce) 
if reduce else diff.squeeze() + nonzero = (per_step > 0).nonzero(as_tuple=False) + if nonzero.numel(): + result["first_diff_step"] = int(nonzero[0].item()) + return result def _merge_incremental_debug_steps(steps: list[dict[str, Any]]) -> dict[str, Any]: @@ -125,43 +105,36 @@ def _merge_incremental_debug_steps(steps: list[dict[str, Any]]) -> dict[str, Any } +def _load_and_pad_audio( + audio_path: str, device: torch.device, dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Load audio, zero-pad to a whole number of 80 ms frames, return ``(audio, lens)``.""" + import librosa + + audio_np, _ = librosa.load(audio_path, sr=SAMPLE_RATE) + padded_len = math.ceil(len(audio_np) / FRAME_SIZE_SAMPLES) * FRAME_SIZE_SAMPLES + audio = torch.nn.functional.pad( + torch.tensor(audio_np, device=device, dtype=dtype).unsqueeze(0), + (0, max(0, padded_len - len(audio_np))), + ) + return audio, torch.tensor([audio.shape[1]], device=device, dtype=torch.long) + + def run_parity_check( pipeline: StreamingS2SPipeline, audio_path: str, *, system_prompt: str | None = None, ) -> dict[str, Any]: - """Run offline and pipeline-based incremental inference, return comparison. - - The offline path calls :meth:`NemotronVoiceChat.offline_inference` - (the same code path used by ``nemotron_voicechat_eval.py``). - The incremental path runs through :meth:`StreamingS2SPipeline.run`, - which is the real production code path (buffering, framing, prefill). + """Run offline and incremental inference on the same audio, return comparison. Only STT-level tokens and logits are compared; TTS is irrelevant for the core parity invariant. """ - import librosa - wrapper = pipeline.s2s_model - t0 = time.time() - logging.info("=" * 60) - logging.info("PARITY CHECK -- offline vs. 
incremental (pipeline)") - logging.info("=" * 60) - logging.info(f"Audio : {audio_path}") - logging.info(f"Prompt: {system_prompt or '(none)'}") + audio, audio_lens = _load_and_pad_audio(audio_path, wrapper.device, wrapper.dtype) - logging.info("Loading audio ...") - audio_np, _ = librosa.load(audio_path, sr=SAMPLE_RATE) - total_frames = math.ceil(len(audio_np) / FRAME_SIZE_SAMPLES) - padded_len = total_frames * FRAME_SIZE_SAMPLES - audio = torch.tensor(audio_np, device=wrapper.device, dtype=wrapper.dtype).unsqueeze(0) - if audio.shape[1] < padded_len: - audio = torch.nn.functional.pad(audio, (0, padded_len - audio.shape[1])) - audio_lens = torch.tensor([audio.shape[1]], device=wrapper.device, dtype=torch.long) - logging.info(f" {len(audio_np)} samples (padded to {audio.shape[1]}), {total_frames} frames, {len(audio_np)/SAMPLE_RATE:.2f}s") - - # -- prompt tokens for offline -- + # Prompt tokens for the offline path prompt_tokens = prompt_token_lens = None if system_prompt: tok = wrapper.tokenizer @@ -169,20 +142,19 @@ def run_parity_check( prompt_tokens = torch.tensor(ids, device=wrapper.device, dtype=torch.long).unsqueeze(0) prompt_token_lens = torch.tensor([len(ids)], device=wrapper.device, dtype=torch.long) - # -- ensure speaker info for offline_inference's mandatory TTS init -- + # offline_inference requires speaker info for TTS init if wrapper.speaker_name is not None: OmegaConf.update(wrapper.model.cfg, "inference_speaker_name", wrapper.speaker_name, force_add=True) elif wrapper.speaker_reference: OmegaConf.update(wrapper.model.cfg, "inference_speaker_reference", wrapper.speaker_reference, force_add=True) speaker_kw: dict[str, Any] = {} - cfg = wrapper.model.cfg - if not cfg.get("inference_speaker_name", None) and not cfg.get("inference_speaker_reference", None): + if not wrapper.model.cfg.get("inference_speaker_name") and not wrapper.model.cfg.get("inference_speaker_reference"): speaker_kw["speaker_audio"] = torch.randn(1, 22050, device=wrapper.device) 
speaker_kw["speaker_audio_lens"] = torch.tensor([22050], device=wrapper.device, dtype=torch.long) - # ---- Offline path (eval.py code path) ---- - logging.info("Running offline_inference (with return_logits=True) ...") - t_offline = time.time() + # -- Offline -- + logging.info("Running offline_inference ...") + t0 = time.time() offline = wrapper.model.offline_inference( input_signal=audio, input_signal_lens=audio_lens, @@ -192,66 +164,36 @@ def run_parity_check( return_logits=True, **speaker_kw, ) - logging.info(f" offline_inference done in {time.time() - t_offline:.2f}s") + logging.info(f" offline done in {time.time() - t0:.2f}s") - # ---- Incremental path (pipeline.run) ---- + # -- Incremental -- logging.info("Running incremental inference (pipeline.run) ...") - t_incr = time.time() + t0 = time.time() pipeline.collect_debug = True pipeline_output = pipeline.run( [audio_path], options=[S2SRequestOptions(system_prompt=system_prompt)], ) - logging.info(f" pipeline.run done in {time.time() - t_incr:.2f}s") + logging.info(f" incremental done in {time.time() - t0:.2f}s") - # Extract tokens and logits from pipeline output inc_tokens = pipeline_output.token_texts[0] if pipeline_output.token_texts else None inc_asr_tokens = pipeline_output.token_asr_texts[0] if pipeline_output.token_asr_texts else None + inc_debug = _merge_incremental_debug_steps( + pipeline_output.debug_data[0] if pipeline_output.debug_data and pipeline_output.debug_data[0] else [] + ) - incremental_debug = {} - if pipeline_output.debug_data and pipeline_output.debug_data[0]: - incremental_debug = _merge_incremental_debug_steps(pipeline_output.debug_data[0]) - - # ---- Build comparison report ---- - logging.info("Comparing outputs ...") + # -- Compare -- report: dict[str, Any] = { - "audio_path": audio_path, - "total_frames": total_frames, - "system_prompt": system_prompt, - "offline_text": offline.get("text", [""])[0], - "token_comparison": compare_tokens(offline.get("tokens_text"), inc_tokens), - 
"asr_token_comparison": compare_tokens(offline.get("tokens_text_src"), inc_asr_tokens), + "token_comparison": _compare_tensors(offline.get("tokens_text"), inc_tokens), + "asr_token_comparison": _compare_tensors(offline.get("tokens_text_src"), inc_asr_tokens), } - - off_tl = offline.get("text_logits") - inc_tl = incremental_debug.get("text_logits") - if off_tl is not None and inc_tl is not None: - report["text_logit_comparison"] = compare_logits(off_tl, inc_tl) - - off_al = offline.get("asr_logits") - inc_al = incremental_debug.get("asr_logits") - if off_al is not None and inc_al is not None: - report["asr_logit_comparison"] = compare_logits(off_al, inc_al) - - # ---- Summary ---- - elapsed = time.time() - t0 - logging.info("-" * 60) - logging.info("PARITY CHECK SUMMARY") - logging.info("-" * 60) - tc = report.get("token_comparison", {}) - ac = report.get("asr_token_comparison", {}) - logging.info(f" Text tokens match : {tc.get('match')} (prefix_len={tc.get('prefix_len')})") - logging.info(f" ASR tokens match : {ac.get('match')} (prefix_len={ac.get('prefix_len')})") - for tag, key in [("Text logits", "text_logit_comparison"), ("ASR logits", "asr_logit_comparison")]: - lc = report.get(key, {}) - if lc: - logging.info( - f" {tag:16s}: match={lc.get('match')}, " - f"max_abs_diff={lc.get('max_abs_diff', 0):.6e}, " - f"mean_abs_diff={lc.get('mean_abs_diff', 0):.6e}" - ) - logging.info(f" Total time: {elapsed:.2f}s") - logging.info("=" * 60) + for key, off_key, inc_key in [ + ("text_logit_comparison", "text_logits", "text_logits"), + ("asr_logit_comparison", "asr_logits", "asr_logits"), + ]: + off_t, inc_t = offline.get(off_key), inc_debug.get(inc_key) + if off_t is not None and inc_t is not None: + report[key] = _compare_tensors(off_t, inc_t) return report @@ -262,45 +204,18 @@ def assert_parity( strict: bool = True, atol: float = 0.0, ) -> None: - """Raise :class:`AssertionError` if parity checks in *report* fail. 
- - Args: - report: Dict returned by :func:`run_parity_check`. - strict: When ``True``, also assert logit-level match. - atol: Absolute tolerance for logit comparison (only used when *strict*). - """ + """Raise ``AssertionError`` if parity checks in *report* fail.""" failures: list[str] = [] - - tc = report.get("token_comparison", {}) - if tc.get("match") is False: - failures.append(f"text tokens diverge at index {tc.get('first_diff_index')}") - - ac = report.get("asr_token_comparison", {}) - if ac.get("match") is False: - failures.append(f"ASR tokens diverge at index {ac.get('first_diff_index')}") - + for key in ("token_comparison", "asr_token_comparison"): + c = report.get(key, {}) + if c.get("match") is False: + failures.append(f"{key}: diverge at step {c.get('first_diff_step')}") if strict: for key in ("text_logit_comparison", "asr_logit_comparison"): - lc = report.get(key, {}) - if lc and lc.get("match") is False: - max_diff = lc.get("max_abs_diff", float("inf")) - if max_diff > atol: - failures.append( - f"{key}: max_abs_diff={max_diff:.6e} > atol={atol:.6e}, " - f"first diff at step {lc.get('first_diff_step')}" - ) - - if failures: - detail = json.dumps( - {k: v for k, v in report.items() if k != "debug"}, - indent=2, - default=str, - ) - raise AssertionError( - "Offline/incremental parity failed:\n - " - + "\n - ".join(failures) - + f"\n\nFull report:\n{detail}" - ) + c = report.get(key, {}) + if c.get("match") is False and c.get("max_abs_diff", 0) > atol: + failures.append(f"{key}: max_abs_diff={c['max_abs_diff']:.2e} > atol={atol:.2e}") + assert not failures, "Parity failed:\n " + "\n ".join(failures) # --------------------------------------------------------------------------- @@ -483,47 +398,44 @@ def _build_parity_pipeline( *, speaker_name: str | None = None, ) -> StreamingS2SPipeline: - """Build a :class:`StreamingS2SPipeline` configured for strict parity testing.""" + """Build a :class:`StreamingS2SPipeline` configured for strict parity testing. 
+ + Loads ``s2s_streaming.yaml`` as the base config and applies + parity-specific overrides (deterministic, float32, no caches, greedy). + """ import librosa audio_np, _ = librosa.load(audio_path, sr=SAMPLE_RATE) total_frames = math.ceil(len(audio_np) / FRAME_SIZE_SAMPLES) chunk_secs = total_frames * FRAME_SIZE_SAMPLES / SAMPLE_RATE - speaker_kw = {} + cfg = OmegaConf.load(_CONF_YAML) + overrides = { + "audio_file": audio_path, + "output_dir": output_dir, + "s2s": { + "model_path": model_path, + "engine_type": "native", + "compute_dtype": "float32", + "deterministic": True, + "decode_audio": False, + "use_perception_cache": False, + "use_perception_cudagraph": False, + "use_llm_cache": False, + "system_prompt": None, + "top_p": 1.0, + "repetition_penalty": 1.0, + "temperature": 0.0, + }, + "streaming": { + "chunk_size_in_secs": chunk_secs, + "buffer_size_in_secs": max(71 * 0.08, chunk_secs), + }, + } if speaker_name: - speaker_kw["speaker_name"] = speaker_name - - pipeline_cfg = OmegaConf.create( - { - "output_dir": output_dir, - "s2s": { - "model_path": model_path, - **speaker_kw, - "compute_dtype": "float32", - "engine_type": "native", - "deterministic": True, - "use_perception_cache": False, - "use_perception_cudagraph": False, - "use_llm_cache": False, - "top_p": 1.0, - "repetition_penalty": 1.0, - "temperature": 0.0, - "decode_audio": False, - }, - "streaming": { - "input_sample_rate": SAMPLE_RATE, - "output_sample_rate": 22050, - "batch_size": 1, - "att_context_size": [70, 0], - "chunk_size_in_secs": chunk_secs, - "buffer_size_in_secs": max(71 * 0.08, chunk_secs), - "request_type": "frame", - "max_len": 8192, - }, - } - ) - return S2SPipelineBuilder.build_pipeline(pipeline_cfg) + overrides["s2s"]["speaker_name"] = speaker_name + cfg = OmegaConf.merge(cfg, OmegaConf.create(overrides)) + return S2SPipelineBuilder.build_pipeline(cfg) # --------------------------------------------------------------------------- From 277511b2215e4b4a8a585bbab4ee9808e6c950bc Mon 
Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 25 Mar 2026 04:35:13 +0000 Subject: [PATCH 18/40] in test: use existing audio file, allow system prompt, specify params for parity Signed-off-by: Elena Rastorgueva --- .../test_offline_incremental_parity.py | 64 +++++++++++-------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/tests/collections/speechlm2/test_offline_incremental_parity.py b/tests/collections/speechlm2/test_offline_incremental_parity.py index dd80975c85a6..612bbf98769b 100644 --- a/tests/collections/speechlm2/test_offline_incremental_parity.py +++ b/tests/collections/speechlm2/test_offline_incremental_parity.py @@ -18,12 +18,11 @@ tiny model with random weights (no checkpoint needed, requires only a GPU). ``test_parity_real_checkpoint`` does the same on a real exported checkpoint -and is skipped unless -the following environment variables point to a real exported checkpoint:: +and is skipped unless ``PARITY_CHECKPOINT_PATH`` is set:: PARITY_CHECKPOINT_PATH=/path/to/exported/checkpoint - PARITY_AUDIO_PATH=/path/to/test.wav - PARITY_SPEAKER_NAME= # optional + PARITY_AUDIO_PATH=/path/to/test.wav # optional, defaults to force_align_test.mp3 + PARITY_SPEAKER_NAME= # optional Run from the NeMo repo root (use ``-s`` to see live progress):: @@ -31,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_offline_incremental_parity.py -v -s # include integration test - PARITY_CHECKPOINT_PATH=... PARITY_AUDIO_PATH=... \\ + PARITY_CHECKPOINT_PATH=... 
\\ CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_offline_incremental_parity.py -v -s """ @@ -63,6 +62,11 @@ os.path.dirname(__file__), "../../../examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml", ) +_FORCE_ALIGN_AUDIO = os.path.join( + os.path.dirname(__file__), + "test_data", + "force_align_test.mp3", +) # --------------------------------------------------------------------------- # Helpers @@ -183,6 +187,11 @@ def run_parity_check( ) # -- Compare -- + # offline_inference returns logits for ALL positions (including prompt), + # while the incremental path only produces logits for audio positions. + # Trim the prompt prefix from offline logits so the two are aligned. + prompt_len = prompt_tokens.shape[1] if prompt_tokens is not None else 0 + report: dict[str, Any] = { "token_comparison": _compare_tensors(offline.get("tokens_text"), inc_tokens), "asr_token_comparison": _compare_tensors(offline.get("tokens_text_src"), inc_asr_tokens), @@ -192,6 +201,8 @@ def run_parity_check( ("asr_logit_comparison", "asr_logits", "asr_logits"), ]: off_t, inc_t = offline.get(off_key), inc_debug.get(inc_key) + if off_t is not None and prompt_len > 0: + off_t = off_t[:, prompt_len:] if off_t is not None and inc_t is not None: report[key] = _compare_tensors(off_t, inc_t) @@ -391,12 +402,16 @@ def _tiny_voicechat_config() -> dict: } +_MOCK_SYSTEM_PROMPT = "This is a mock prompt for the test" + + def _build_parity_pipeline( model_path: str, audio_path: str, output_dir: str, *, speaker_name: str | None = None, + system_prompt: str | None = _MOCK_SYSTEM_PROMPT, ) -> StreamingS2SPipeline: """Build a :class:`StreamingS2SPipeline` configured for strict parity testing. 
@@ -415,21 +430,21 @@ def _build_parity_pipeline( "output_dir": output_dir, "s2s": { "model_path": model_path, - "engine_type": "native", - "compute_dtype": "float32", - "deterministic": True, - "decode_audio": False, - "use_perception_cache": False, - "use_perception_cudagraph": False, - "use_llm_cache": False, - "system_prompt": None, - "top_p": 1.0, - "repetition_penalty": 1.0, - "temperature": 0.0, + "engine_type": "native", # offline model can only be run with "native" - no vllm support + "compute_dtype": "float32", # online code would only cast some layers to "compute_dtype" => let's keep everything in float32 for parity + "deterministic": False, # "deterministic" doesn't seem to be necessary for results to match, so let's go without it + "decode_audio": False, # parity test is only for comparing text outputs, not audio + "use_perception_cache": False, # results are slightly different with & without cache. offline does not use perception cache + "use_perception_cudagraph": False, # because not using perception cache + "use_llm_cache": False, # llm cache on/off will affect results. Offline code does not currently support llm cache. 
+ "system_prompt": system_prompt, # use a system prompt to make test more "difficult" + "top_p": 1.0, # greedy decoding because offline decoding does not support sampling parameters + "repetition_penalty": 1.0, # greedy decoding because offline decoding does not support sampling parameters + "temperature": 1.0, # greedy decoding because offline decoding does not support sampling parameters }, "streaming": { "chunk_size_in_secs": chunk_secs, - "buffer_size_in_secs": max(71 * 0.08, chunk_secs), + "buffer_size_in_secs": max(71 * 0.08, chunk_secs), # buffer size needs to be equal or longer than the audio input to guarantee parity }, } if speaker_name: @@ -469,7 +484,7 @@ def test_parity_tiny_model(tmp_path): torch.cuda.empty_cache() pipeline = _build_parity_pipeline(model_dir, audio_path, str(tmp_path / "output")) - report = run_parity_check(pipeline, audio_path) + report = run_parity_check(pipeline, audio_path, system_prompt=_MOCK_SYSTEM_PROMPT) assert_parity(report, strict=True, atol=0.0) @@ -480,11 +495,10 @@ def test_parity_tiny_model(tmp_path): def _real_checkpoint_available() -> bool: path = os.environ.get("PARITY_CHECKPOINT_PATH", "") - audio = os.environ.get("PARITY_AUDIO_PATH", "") - return bool(path) and os.path.isdir(path) and bool(audio) and os.path.isfile(audio) + return bool(path) and os.path.isdir(path) -@pytest.mark.skipif(not _real_checkpoint_available(), reason="set PARITY_CHECKPOINT_PATH and PARITY_AUDIO_PATH") +@pytest.mark.skipif(not _real_checkpoint_available(), reason="set PARITY_CHECKPOINT_PATH") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") def test_parity_real_checkpoint(): """Parity check using a real exported checkpoint. 
@@ -492,17 +506,17 @@ def test_parity_real_checkpoint(): Configure via environment variables:: PARITY_CHECKPOINT_PATH=/path/to/exported/checkpoint - PARITY_AUDIO_PATH=/path/to/test.wav - PARITY_SPEAKER_NAME= # optional + PARITY_AUDIO_PATH=/path/to/test.wav # optional, defaults to force_align_test.mp3 + PARITY_SPEAKER_NAME= # optional """ import tempfile ckpt = os.environ["PARITY_CHECKPOINT_PATH"] - audio = os.environ["PARITY_AUDIO_PATH"] + audio = os.environ.get("PARITY_AUDIO_PATH") or _FORCE_ALIGN_AUDIO speaker = os.environ.get("PARITY_SPEAKER_NAME") pipeline = _build_parity_pipeline( ckpt, audio, tempfile.mkdtemp(prefix="parity-"), speaker_name=speaker, ) - report = run_parity_check(pipeline, audio) + report = run_parity_check(pipeline, audio, system_prompt=_MOCK_SYSTEM_PROMPT) assert_parity(report, strict=True, atol=0.0) From dc6a7590d49b816941b38be7c8117e8727b9edd3 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Mon, 30 Mar 2026 22:06:17 +0000 Subject: [PATCH 19/40] Refactor voicechat tests: shared fixtures, no-crash sweep, deterministic parity Signed-off-by: Elena Rastorgueva --- .../inference/model_wrappers/model_factory.py | 12 +- .../nemotron_voicechat_inference_wrapper.py | 4 + tests/collections/speechlm2/conftest.py | 280 ++++++++++++++++++ .../speechlm2/test_nemotron_voicechat.py | 215 +------------- ...est_nemotron_voicechat_pipeline_nocrash.py | 170 +++++++++++ ...est_nemotron_voicechat_pipeline_parity.py} | 247 ++------------- 6 files changed, 494 insertions(+), 434 deletions(-) create mode 100644 tests/collections/speechlm2/conftest.py create mode 100644 tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py rename tests/collections/speechlm2/{test_offline_incremental_parity.py => test_nemotron_voicechat_pipeline_parity.py} (52%) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index 770005bbf2eb..9f8616fdcbac 100644 --- 
a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -162,7 +162,17 @@ def _sample_text_token( indices_to_remove = sorted_indices[sorted_indices_to_remove] batch_logits[indices_to_remove] = float('-inf') - # Sample from the filtered distribution + # Fall back to greedy if logits contain NaN or inf + if batch_logits.isnan().any() or batch_logits.isinf().any(): + logging.warning( + f"_sample_text_token: logits contain NaN or inf at step {current_step}, batch {b}: " + f"nan={batch_logits.isnan().sum().item()}, " + f"inf={batch_logits.isinf().sum().item()}, " + f"min={batch_logits[~batch_logits.isnan()].min().item() if not batch_logits.isnan().all() else 'all_nan'}, " + f"max={batch_logits[~batch_logits.isnan()].max().item() if not batch_logits.isnan().all() else 'all_nan'}" + ) + sampled_tokens[b] = greedy_tokens[b] + continue probs = torch.softmax(batch_logits, dim=-1) sampled_tokens[b] = torch.multinomial(probs, num_samples=1).item() diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 38fd7688ae40..1a3a6e2cdfe3 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -92,6 +92,10 @@ def __init__(self, model_cfg: DictConfig): torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.use_deterministic_algorithms(True, warn_only=False) + else: + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.use_deterministic_algorithms(False) self.model_cfg = model_cfg diff --git a/tests/collections/speechlm2/conftest.py b/tests/collections/speechlm2/conftest.py new file mode 100644 index 
000000000000..d37130392370 --- /dev/null +++ b/tests/collections/speechlm2/conftest.py @@ -0,0 +1,280 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures for speechlm2 tests.""" + +from __future__ import annotations + +import json +import os + +# nemotron_voicechat_pipeline_{parity,nocrash} tests set +# torch.use_deterministic_algorithms(True), which requires CuBLAS to have a +# deterministic workspace. CuBLAS reads this env var only once — at +# initialization (first CUDA matmul in the process) — so it must be set here, +# before any fixture or test triggers CUDA work. The setting is harmless for +# non-deterministic tests: it only reserves 32 KB of extra GPU workspace and +# has no effect unless deterministic mode is active. +os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + +import numpy as np +import pytest +import soundfile as sf +import torch + +from nemo.collections.speechlm2.models import NemotronVoiceChat + +_pretrained_llm = "TinyLlama/TinyLlama_v1.1" +if os.path.exists("/home/TestData/speechlm/pretrained_models"): + _pretrained_llm = "/home/TestData/speechlm/pretrained_models/TinyLlama--TinyLlama_v1.1" + + +def _tiny_voicechat_config(*, predict_user_text: bool = True, streaming_encoder: bool = False) -> dict: + """Return a minimal NemotronVoiceChat config with random weights. + + Args: + predict_user_text: Enable ASR head for user text prediction. 
+ streaming_encoder: When True, configure the conformer encoder for + cache-aware streaming (causal convolutions, chunked_limited + attention) matching the real checkpoint. When False, use + default (non-causal) settings suitable for offline tests. + """ + encoder_cfg: dict = { + "_target_": "nemo.collections.asr.modules.ConformerEncoder", + "feat_in": 80, + "d_model": 512, + "n_heads": 8, + "n_layers": 1, + "subsampling_factor": 8, + } + if streaming_encoder: + encoder_cfg.update({ + "subsampling": "dw_striding", + "causal_downsampling": True, + "att_context_size": [70, 0], + "att_context_style": "chunked_limited", + "conv_kernel_size": 9, + "conv_context_size": "causal", + }) + + return { + "model": { + "scoring_asr": "stt_en_fastconformer_transducer_large", + "stt": { + "model": { + "pretrained_llm": _pretrained_llm, + "pretrained_weights": False, + "predict_user_text": predict_user_text, + "audio_loss_weight": 1, + "text_loss_weight": 3, + "source_sample_rate": 16000, + "validation_save_path": "/tmp/test_duplex_stt_logs", + "perception": { + "_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule", + "preprocessor": { + "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", + "features": 80, + }, + "encoder": encoder_cfg, + "modality_adapter": { + "_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector", + "d_model": 512, + }, + "output_dim": 2048, + }, + "optimizer": {"_target_": "torch.optim.AdamW"}, + }, + "data": {"source_sample_rate": 16000}, + "exp_manager": {"explicit_log_dir": "/tmp/test_duplex_stt_logs"}, + }, + "speech_generation": { + "model": { + "pretrained_lm_name": _pretrained_llm, + "pretrained_ae_dir": None, + "pretrained_tts_model": None, + "scoring_asr": "stt_en_fastconformer_transducer_large", + "freeze_params": [r"^audio_codec\..+$", r"^embed_tokens\..+$"], + "bos_token": "", + "eos_token": "", + "pad_token": "", + "audio_codec_run_dtype": "float32", + 
"prevent_freeze_params": [], + "audio_save_path": "", + "inference_guidance_scale": 0.5, + "inference_noise_scale": 0.8, + "inference_top_p_or_k": 0.8, + "inference_guidance_enabled": False, + "subword_mask_exactly_as_eartts": False, + "context_hidden_mask_exactly_as_eartts": False, + "optimizer": { + "_target_": "torch.optim.AdamW", + "lr": 4e-5, + "betas": [0.9, 0.98], + "weight_decay": 0, + "foreach": True, + }, + "lr_scheduler": { + "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", + "warmup_steps": 2500, + "min_lr": 1e-6, + "max_steps": 100_000_000, + }, + "codec_config": { + "latent_size": 512, + "n_fft": 16, + "hop_length": 4, + "base_hidden_size": 384, + "channel_mult": [1, 2, 4], + "rates": [7, 7, 9], + "num_blocks": 3, + "kernel_size": 7, + "groups": 1, + "codebook_size": 1024, + "num_quantizers": 31, + "wav_to_token_ratio": 1764, + }, + "tts_config": { + "use_gated_fusion_for_text_audio": True, + "disable_eos_prediction": True, + "use_bos_eos_emb": True, + "use_subword_flag_emb": True, + "num_delay_speech_tokens": 2, + "backbone_type": "gemma3_text", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "attention_dropout": 0.1, + "use_cache": False, + }, + "latent_size": 512, + "codebook_size": 1024, + "num_quantizers": 31, + "context_hidden_size": None, + "cas_config": { + "backbone_type": "t5gemma", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "is_encoder_decoder": False, + "encoder": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "use_cache": False, + "attention_dropout": 0.1, + }, + }, + }, + "mog_head_config": { + "intermediate_size": 4608, + "num_layers": 3, + "low_rank": 64, + 
"num_predictions": 1024, + "min_log_std": -4.0, + "eps": 1e-6, + }, + "p_uncond": 0.1, + "label_smoothing": 0.01, + "max_training_rate": 0.8, + "quantizer_dropout": 0.5, + "random_target_masking": False, + "exponent": 3.0, + }, + }, + "data": { + "add_text_bos_and_eos_in_each_turn": True, + "add_audio_prompt": True, + "audio_prompt_duration": 3.0, + "frame_length": 0.08, + "source_sample_rate": 16000, + "target_sample_rate": 22050, + }, + "exp_manager": {"explicit_log_dir": "/tmp/test_duplex_stt_logs"}, + }, + }, + "data": { + "frame_length": 0.08, + "source_sample_rate": 16000, + "target_sample_rate": 22050, + "input_roles": ["user", "User"], + "output_roles": ["agent", "Assistant", "assistant", "Agent"], + }, + "exp_manager": {"explicit_log_dir": "/tmp/test_parity_logs"}, + } + + +@pytest.fixture(scope="session") +def tiny_model_artifacts(tmp_path_factory): + """Build a tiny NemotronVoiceChat with random weights, write test audio files. + + Session-scoped so the model is built only once across all test modules. + The fixture returns only file paths (immutable), so sharing is safe. + + Returns ``(model_dir, audio_path, speaker_ref_path)``. + """ + base = tmp_path_factory.mktemp("tiny_model") + + audio_path = str(base / "test_audio.wav") + sf.write(audio_path, np.random.RandomState(42).randn(3 * 16000).astype(np.float32), 16000) + + speaker_ref_path = str(base / "speaker_ref.wav") + sf.write(speaker_ref_path, np.random.RandomState(99).randn(22050).astype(np.float32), 22050) + + cfg = _tiny_voicechat_config(streaming_encoder=True) + model = NemotronVoiceChat(cfg) + model.to("cuda") + model.eval() + + model_dir = str(base / "model") + model.save_pretrained(model_dir) + + # save_pretrained writes the tokenizer to llm_artifacts/, but config.json + # still references the HF hub name (e.g. "TinyLlama/TinyLlama_v1.1"). 
+ # Save the LLM model config alongside the tokenizer so llm_artifacts/ + # is a complete local model reference, then rewrite config.json to point + # at it. This avoids HuggingFace network requests on every from_pretrained. + llm_artifacts = os.path.join(model_dir, "llm_artifacts") + model.stt_model.llm.config.save_pretrained(llm_artifacts) + cfg["model"]["stt"]["model"]["pretrained_llm"] = llm_artifacts + cfg["model"]["speech_generation"]["model"]["pretrained_lm_name"] = llm_artifacts + with open(os.path.join(model_dir, "config.json"), "w") as f: + json.dump(cfg, f) + + del model + torch.cuda.empty_cache() + + return model_dir, audio_path, speaker_ref_path + + +@pytest.fixture(scope="session") +def tiny_voicechat_model(): + """Build a tiny NemotronVoiceChat model (predict_user_text=False). + + Used by ``test_nemotron_voicechat.py`` for validation and offline + generation tests. + """ + cfg = _tiny_voicechat_config(predict_user_text=False) + model = NemotronVoiceChat(cfg) + if torch.cuda.is_available(): + model.to("cuda") + return model diff --git a/tests/collections/speechlm2/test_nemotron_voicechat.py b/tests/collections/speechlm2/test_nemotron_voicechat.py index 4f2b431e4f25..5afd176f08fc 100644 --- a/tests/collections/speechlm2/test_nemotron_voicechat.py +++ b/tests/collections/speechlm2/test_nemotron_voicechat.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os - import pytest import torch from lhotse import CutSet, SupervisionSegment @@ -21,218 +19,11 @@ from nemo.collections.common.data.utils import move_data_to_device from nemo.collections.speechlm2 import DuplexSTTDataset -from nemo.collections.speechlm2.models import NemotronVoiceChat - -if torch.cuda.is_available(): - torch.set_default_device('cuda') - - -pretrained_llm = "TinyLlama/TinyLlama_v1.1" -if os.path.exists("/home/TestData/speechlm/pretrained_models"): - pretrained_llm = "/home/TestData/speechlm/pretrained_models/TinyLlama--TinyLlama_v1.1" - -# STT sampling rate -source_sample_rate = 16000 -# TTS sampling rate -target_sample_rate = 22050 - - -def create_model( - predict_user_text=False, - force_use_noise_augmentation=False, - old_noise_prob=0.0, - old_noise_min_snr=0.0, - old_noise_max_snr=0.0, -): - """Helper function to create a model with configurable settings.""" - test_stt_cfg = { - "model": { - "pretrained_llm": pretrained_llm, - "pretrained_weights": False, - "audio_loss_weight": 1, - "text_loss_weight": 3, - "source_sample_rate": source_sample_rate, - "validation_save_path": "/tmp/test_duplex_stt_logs", - "perception": { - "_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule", - "preprocessor": { - "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", - "features": 80, - }, - "encoder": { - "_target_": "nemo.collections.asr.modules.ConformerEncoder", - "feat_in": 80, - "d_model": 512, - "n_heads": 8, - "n_layers": 1, - "subsampling_factor": 8, - }, - "modality_adapter": { - "_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector", - "d_model": 512, - }, - "output_dim": 2048, - }, - "predict_user_text": predict_user_text, - "force_use_noise_augmentation": force_use_noise_augmentation, - "old_noise_prob": old_noise_prob, - "old_noise_min_snr": old_noise_min_snr, - "old_noise_max_snr": old_noise_max_snr, - "optimizer": {"_target_": "torch.optim.AdamW"}, - }, - "data": 
{ - "source_sample_rate": 16000, - }, - "exp_manager": { - "explicit_log_dir": "/tmp/test_duplex_stt_logs", - }, - } - - test_tts_config = { - "model": { - "pretrained_lm_name": pretrained_llm, - "pretrained_ae_dir": None, - "pretrained_tts_model": None, - "scoring_asr": "stt_en_fastconformer_transducer_large", - "freeze_params": [ - r"^audio_codec\..+$", # Keep audio codec frozen as it only provides supervision for training. - r"^embed_tokens\..+$", # Keep embed_tokens frozen as done in eartts - ], - "bos_token": "", - "eos_token": "", - "pad_token": "", - "audio_codec_run_dtype": "float32", - "prevent_freeze_params": [], - "audio_save_path": "", - "inference_guidance_scale": 0.5, - "inference_noise_scale": 0.8, - "inference_top_p_or_k": 0.8, - "inference_guidance_enabled": False, - "subword_mask_exactly_as_eartts": False, - "context_hidden_mask_exactly_as_eartts": False, - "optimizer": { - "_target_": "torch.optim.AdamW", - "lr": 4e-5, - "betas": [0.9, 0.98], - "weight_decay": 0, - "foreach": True, - }, - "lr_scheduler": { - "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", - "warmup_steps": 2500, - "min_lr": 1e-6, - "max_steps": 100_000_000, - }, - "codec_config": { - "latent_size": 512, - "n_fft": 16, - "hop_length": 4, - "base_hidden_size": 384, - "channel_mult": [1, 2, 4], - "rates": [7, 7, 9], - "num_blocks": 3, - "kernel_size": 7, - "groups": 1, - "codebook_size": 1024, - "num_quantizers": 31, - "wav_to_token_ratio": 1764, - }, - "tts_config": { - "use_gated_fusion_for_text_audio": True, - "disable_eos_prediction": True, - "use_bos_eos_emb": True, - "use_subword_flag_emb": True, - "num_delay_speech_tokens": 2, - "backbone_type": "gemma3_text", - "backbone_model_class": None, - "backbone_config_class": None, - "backbone_config": { - "hidden_size": 1152, - "intermediate_size": 4608, - "num_hidden_layers": 1, - "num_attention_heads": 16, - "num_key_value_heads": 16, - "head_dim": 72, - "attention_dropout": 0.1, - "use_cache": False, - }, - 
"latent_size": 512, - "codebook_size": 1024, - "num_quantizers": 31, - "context_hidden_size": None, - "cas_config": { - "backbone_type": "t5gemma", - "backbone_model_class": None, - "backbone_config_class": None, - "backbone_config": { - "is_encoder_decoder": False, - "encoder": { - "hidden_size": 1152, - "intermediate_size": 4608, - "num_hidden_layers": 1, - "num_attention_heads": 16, - "num_key_value_heads": 16, - "head_dim": 72, - "use_cache": False, - "attention_dropout": 0.1, - }, - }, - }, - "mog_head_config": { - "intermediate_size": 4608, - "num_layers": 3, - "low_rank": 64, - "num_predictions": 1024, - "min_log_std": -4.0, - "eps": 1e-6, - }, - "p_uncond": 0.1, - "label_smoothing": 0.01, - "max_training_rate": 0.8, - "quantizer_dropout": 0.5, - "random_target_masking": False, - "exponent": 3.0, - }, - }, - "data": { - "add_text_bos_and_eos_in_each_turn": True, - "add_audio_prompt": True, - "audio_prompt_duration": 3.0, - "frame_length": 0.08, - "source_sample_rate": source_sample_rate, - "target_sample_rate": target_sample_rate, - }, - "exp_manager": { - "explicit_log_dir": "/tmp/test_duplex_stt_logs", - }, - } - - test_config = { - "model": { - "scoring_asr": "stt_en_fastconformer_transducer_large", - "stt": test_stt_cfg, - "speech_generation": test_tts_config, - }, - "data": { - "frame_length": 0.08, - "source_sample_rate": source_sample_rate, - "target_sample_rate": target_sample_rate, - "input_roles": ["user", "User"], - "output_roles": ["agent", "Assistant", "assistant", "Agent"], - }, - "exp_manager": { - "explicit_log_dir": "/tmp/test_nemotron_voicechat_logs", - }, - } - model = NemotronVoiceChat(test_config) - if torch.cuda.is_available(): - model.to("cuda") - return model @pytest.fixture(scope="session") -def model(): - return create_model(predict_user_text=False) +def model(tiny_voicechat_model): + return tiny_voicechat_model @pytest.fixture(scope="session") @@ -240,7 +31,7 @@ def dataset(model): return DuplexSTTDataset( 
model.stt_model.tokenizer, frame_length=0.08, - source_sample_rate=source_sample_rate, + source_sample_rate=16000, input_roles=["user"], output_roles=["assistant"], ) diff --git a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py new file mode 100644 index 000000000000..18ab6b4861b0 --- /dev/null +++ b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py @@ -0,0 +1,170 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""No-crash pipeline tests for NemotronVoiceChat streaming inference. + +Exercises ``StreamingS2SPipeline.run()`` with a tiny random-weight model +under various config combinations. Each test verifies only that the +pipeline completes without raising — no output quality checks. 
+ +Run from the NeMo repo root:: + + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py -v -s +""" + +from __future__ import annotations + +import os +import tempfile +from typing import Any + +import pytest +import torch +from omegaconf import OmegaConf + +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline + +_CONF_YAML = os.path.join( + os.path.dirname(__file__), + "../../../examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml", +) +_MOCK_SYSTEM_PROMPT = "This is a mock prompt for the test" + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def _build_no_crash_pipeline( + model_path: str, + audio_path: str, + output_dir: str, + *, + s2s_overrides: dict[str, Any] | None = None, + streaming_overrides: dict[str, Any] | None = None, +) -> StreamingS2SPipeline: + """Build a :class:`StreamingS2SPipeline` with custom overrides for no-crash testing.""" + cfg = OmegaConf.load(_CONF_YAML) + + s2s_cfg: dict[str, Any] = { + "model_path": model_path, + "engine_type": "native", + "compute_dtype": "float32", + "deterministic": False, + "decode_audio": False, + "use_perception_cache": False, + "use_perception_cudagraph": False, + "use_llm_cache": False, + "system_prompt": None, + "top_p": 1.0, + "repetition_penalty": 1.0, + "temperature": 1.0, + } + streaming_cfg: dict[str, Any] = { + "chunk_size_in_secs": 0.08, + "buffer_size_in_secs": 71 * 0.08, + } + + if s2s_overrides: + s2s_cfg.update(s2s_overrides) + if streaming_overrides: + streaming_cfg.update(streaming_overrides) + + overrides = { + "audio_file": audio_path, + "output_dir": output_dir, + "s2s": s2s_cfg, + "streaming": streaming_cfg, + } + cfg = OmegaConf.merge(cfg, 
OmegaConf.create(overrides)) + return S2SPipelineBuilder.build_pipeline(cfg) + + +# --------------------------------------------------------------------------- +# Parametrized configs +# --------------------------------------------------------------------------- + +# Text-only configs (decode_audio=False): minimal STT-path smoke checks. +# Most config variations are folded into the audio tests below. +_TEXT_CONFIGS = [ + pytest.param({}, {}, id="baseline"), + pytest.param( + {"use_llm_cache": True, "use_perception_cache": True}, + {}, + id="both_caches", + ), +] + +# Audio configs (decode_audio=True): exercises the full STT + TTS pipeline. +_AUDIO_CONFIGS = [ + pytest.param({}, {}, id="baseline"), + pytest.param( + {"use_llm_cache": True, "use_perception_cache": True, "system_prompt": _MOCK_SYSTEM_PROMPT}, + {"chunk_size_in_secs": 0.24}, + id="both_caches_prompt_multiframe", + ), + pytest.param( + {"use_llm_cache": True, "top_p": 0.9, "temperature": 0.7, "repetition_penalty": 1.1}, + {}, + id="sampling", + ), + pytest.param( + {"use_tts_subword_cache": True, "use_tts_torch_compile": True}, + {}, + id="tts_optimizations", + ), + pytest.param( + {"deterministic": True, "temperature": 0.0}, + {}, + id="deterministic", + ), +] + +# --------------------------------------------------------------------------- +# Tests (tiny_model_artifacts fixture is provided by conftest.py) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize("s2s_overrides,streaming_overrides", _TEXT_CONFIGS) +def test_pipeline_no_crash_tiny_model(tiny_model_artifacts, s2s_overrides, streaming_overrides): + """Run the streaming pipeline with various configs and verify it doesn't crash.""" + model_dir, audio_path, _ = tiny_model_artifacts + output_dir = tempfile.mkdtemp(prefix="no-crash-text-") + + pipeline = _build_no_crash_pipeline( + model_dir, audio_path, output_dir, + 
s2s_overrides=s2s_overrides, streaming_overrides=streaming_overrides, + ) + result = pipeline.run([audio_path]) + assert result is not None + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize("s2s_overrides,streaming_overrides", _AUDIO_CONFIGS) +def test_pipeline_no_crash_tiny_model_decode_audio(tiny_model_artifacts, s2s_overrides, streaming_overrides): + """Run the streaming pipeline with decode_audio=True and verify it doesn't crash.""" + model_dir, audio_path, speaker_ref_path = tiny_model_artifacts + output_dir = tempfile.mkdtemp(prefix="no-crash-audio-") + + audio_overrides = {"decode_audio": True, "speaker_reference": speaker_ref_path} + audio_overrides.update(s2s_overrides) + + pipeline = _build_no_crash_pipeline( + model_dir, audio_path, output_dir, + s2s_overrides=audio_overrides, streaming_overrides=streaming_overrides, + ) + result = pipeline.run([audio_path]) + assert result is not None diff --git a/tests/collections/speechlm2/test_offline_incremental_parity.py b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py similarity index 52% rename from tests/collections/speechlm2/test_offline_incremental_parity.py rename to tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py index 612bbf98769b..b4f61941c4d7 100644 --- a/tests/collections/speechlm2/test_offline_incremental_parity.py +++ b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py @@ -27,23 +27,22 @@ Run from the NeMo repo root (use ``-s`` to see live progress):: # unit tests only - CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_offline_incremental_parity.py -v -s + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py -v -s # include integration test PARITY_CHECKPOINT_PATH=... 
\\ - CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_offline_incremental_parity.py -v -s + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py -v -s """ from __future__ import annotations import math import os +import tempfile import time from typing import Any -import numpy as np import pytest -import soundfile as sf import torch from omegaconf import OmegaConf @@ -56,7 +55,6 @@ ) from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions -from nemo.collections.speechlm2.models import NemotronVoiceChat _CONF_YAML = os.path.join( os.path.dirname(__file__), @@ -67,6 +65,7 @@ "test_data", "force_align_test.mp3", ) +_MOCK_SYSTEM_PROMPT = "This is a mock prompt for the test" # --------------------------------------------------------------------------- # Helpers @@ -138,7 +137,6 @@ def run_parity_check( wrapper = pipeline.s2s_model audio, audio_lens = _load_and_pad_audio(audio_path, wrapper.device, wrapper.dtype) - # Prompt tokens for the offline path prompt_tokens = prompt_token_lens = None if system_prompt: tok = wrapper.tokenizer @@ -146,7 +144,6 @@ def run_parity_check( prompt_tokens = torch.tensor(ids, device=wrapper.device, dtype=torch.long).unsqueeze(0) prompt_token_lens = torch.tensor([len(ids)], device=wrapper.device, dtype=torch.long) - # offline_inference requires speaker info for TTS init if wrapper.speaker_name is not None: OmegaConf.update(wrapper.model.cfg, "inference_speaker_name", wrapper.speaker_name, force_add=True) elif wrapper.speaker_reference: @@ -229,189 +226,13 @@ def assert_parity( assert not failures, "Parity failed:\n " + "\n ".join(failures) -# --------------------------------------------------------------------------- -# Tiny-model configuration (derived from test_nemotron_voicechat.py) -# 
--------------------------------------------------------------------------- - -_pretrained_llm = "TinyLlama/TinyLlama_v1.1" -if os.path.exists("/home/TestData/speechlm/pretrained_models"): - _pretrained_llm = "/home/TestData/speechlm/pretrained_models/TinyLlama--TinyLlama_v1.1" - - -def _tiny_voicechat_config() -> dict: - """Return a minimal NemotronVoiceChat config with random weights.""" - return { - "model": { - "scoring_asr": "stt_en_fastconformer_transducer_large", - "stt": { - "model": { - "pretrained_llm": _pretrained_llm, - "pretrained_weights": False, - "predict_user_text": True, - "audio_loss_weight": 1, - "text_loss_weight": 3, - "source_sample_rate": 16000, - "validation_save_path": "/tmp/test_parity_stt_logs", - "perception": { - "_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule", - "preprocessor": { - "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", - "features": 80, - }, - "encoder": { - "_target_": "nemo.collections.asr.modules.ConformerEncoder", - "feat_in": 80, - "d_model": 512, - "n_heads": 8, - "n_layers": 1, - "subsampling_factor": 8, - }, - "modality_adapter": { - "_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector", - "d_model": 512, - }, - "output_dim": 2048, - }, - "optimizer": {"_target_": "torch.optim.AdamW"}, - }, - "data": {"source_sample_rate": 16000}, - "exp_manager": {"explicit_log_dir": "/tmp/test_parity_stt_logs"}, - }, - "speech_generation": { - "model": { - "pretrained_lm_name": _pretrained_llm, - "pretrained_ae_dir": None, - "pretrained_tts_model": None, - "scoring_asr": "stt_en_fastconformer_transducer_large", - "freeze_params": [r"^audio_codec\..+$", r"^embed_tokens\..+$"], - "bos_token": "", - "eos_token": "", - "pad_token": "", - "audio_codec_run_dtype": "float32", - "prevent_freeze_params": [], - "audio_save_path": "", - "inference_guidance_scale": 0.5, - "inference_noise_scale": 0.8, - "inference_top_p_or_k": 0.8, - 
"inference_guidance_enabled": False, - "subword_mask_exactly_as_eartts": False, - "context_hidden_mask_exactly_as_eartts": False, - "optimizer": { - "_target_": "torch.optim.AdamW", - "lr": 4e-5, - "betas": [0.9, 0.98], - "weight_decay": 0, - "foreach": True, - }, - "lr_scheduler": { - "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", - "warmup_steps": 2500, - "min_lr": 1e-6, - "max_steps": 100_000_000, - }, - "codec_config": { - "latent_size": 512, - "n_fft": 16, - "hop_length": 4, - "base_hidden_size": 384, - "channel_mult": [1, 2, 4], - "rates": [7, 7, 9], - "num_blocks": 3, - "kernel_size": 7, - "groups": 1, - "codebook_size": 1024, - "num_quantizers": 31, - "wav_to_token_ratio": 1764, - }, - "tts_config": { - "use_gated_fusion_for_text_audio": True, - "disable_eos_prediction": True, - "use_bos_eos_emb": True, - "use_subword_flag_emb": True, - "num_delay_speech_tokens": 2, - "backbone_type": "gemma3_text", - "backbone_model_class": None, - "backbone_config_class": None, - "backbone_config": { - "hidden_size": 1152, - "intermediate_size": 4608, - "num_hidden_layers": 1, - "num_attention_heads": 16, - "num_key_value_heads": 16, - "head_dim": 72, - "attention_dropout": 0.1, - "use_cache": False, - }, - "latent_size": 512, - "codebook_size": 1024, - "num_quantizers": 31, - "context_hidden_size": None, - "cas_config": { - "backbone_type": "t5gemma", - "backbone_model_class": None, - "backbone_config_class": None, - "backbone_config": { - "is_encoder_decoder": False, - "encoder": { - "hidden_size": 1152, - "intermediate_size": 4608, - "num_hidden_layers": 1, - "num_attention_heads": 16, - "num_key_value_heads": 16, - "head_dim": 72, - "use_cache": False, - "attention_dropout": 0.1, - }, - }, - }, - "mog_head_config": { - "intermediate_size": 4608, - "num_layers": 3, - "low_rank": 64, - "num_predictions": 1024, - "min_log_std": -4.0, - "eps": 1e-6, - }, - "p_uncond": 0.1, - "label_smoothing": 0.01, - "max_training_rate": 0.8, - 
"quantizer_dropout": 0.5, - "random_target_masking": False, - "exponent": 3.0, - }, - }, - "data": { - "add_text_bos_and_eos_in_each_turn": True, - "add_audio_prompt": True, - "audio_prompt_duration": 3.0, - "frame_length": 0.08, - "source_sample_rate": 16000, - "target_sample_rate": 22050, - }, - "exp_manager": {"explicit_log_dir": "/tmp/test_parity_tts_logs"}, - }, - }, - "data": { - "frame_length": 0.08, - "source_sample_rate": 16000, - "target_sample_rate": 22050, - "input_roles": ["user", "User"], - "output_roles": ["agent", "Assistant", "assistant", "Agent"], - }, - "exp_manager": {"explicit_log_dir": "/tmp/test_parity_logs"}, - } - - -_MOCK_SYSTEM_PROMPT = "This is a mock prompt for the test" - - def _build_parity_pipeline( model_path: str, audio_path: str, output_dir: str, *, speaker_name: str | None = None, - system_prompt: str | None = _MOCK_SYSTEM_PROMPT, + system_prompt: str | None = None, ) -> StreamingS2SPipeline: """Build a :class:`StreamingS2SPipeline` configured for strict parity testing. @@ -430,21 +251,21 @@ def _build_parity_pipeline( "output_dir": output_dir, "s2s": { "model_path": model_path, - "engine_type": "native", # offline model can only be run with "native" - no vllm support - "compute_dtype": "float32", # online code would only cast some layers to "compute_dtype" => let's keep everything in float32 for parity - "deterministic": False, # "deterministic" doesn't seem to be necessary for results to match, so let's go without it - "decode_audio": False, # parity test is only for comparing text outputs, not audio - "use_perception_cache": False, # results are slightly different with & without cache. offline does not use perception cache - "use_perception_cudagraph": False, # because not using perception cache - "use_llm_cache": False, # llm cache on/off will affect results. Offline code does not currently support llm cache. 
- "system_prompt": system_prompt, # use a system prompt to make test more "difficult" - "top_p": 1.0, # greedy decoding because offline decoding does not support sampling parameters - "repetition_penalty": 1.0, # greedy decoding because offline decoding does not support sampling parameters - "temperature": 1.0, # greedy decoding because offline decoding does not support sampling parameters + "engine_type": "native", + "compute_dtype": "float32", + "deterministic": True, + "decode_audio": False, + "use_perception_cache": False, + "use_perception_cudagraph": False, + "use_llm_cache": False, + "system_prompt": system_prompt, + "top_p": 1.0, + "repetition_penalty": 1.0, + "temperature": 1.0, }, "streaming": { "chunk_size_in_secs": chunk_secs, - "buffer_size_in_secs": max(71 * 0.08, chunk_secs), # buffer size needs to be equal or longer than the audio input to guarantee parity + "buffer_size_in_secs": max(71 * 0.08, chunk_secs), }, } if speaker_name: @@ -454,36 +275,21 @@ def _build_parity_pipeline( # --------------------------------------------------------------------------- -# Parity test -- tiny model (no real checkpoint needed) +# Parity test -- tiny model (uses tiny_model_artifacts from conftest.py) # --------------------------------------------------------------------------- @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -def test_parity_tiny_model(tmp_path): +def test_parity_tiny_model(tiny_model_artifacts): """Offline/incremental parity with a tiny random-weight model. - Saves the model as an HF checkpoint, then loads it through the real - ``S2SPipelineBuilder`` so the test exercises the same code path as - ``test_parity_real_checkpoint``. + Loads the model through the real ``S2SPipelineBuilder`` so the test + exercises the same code path as ``test_parity_real_checkpoint``. 
""" - import json as _json - - audio_path = str(tmp_path / "test_audio.wav") - sf.write(audio_path, np.random.RandomState(42).randn(16000).astype(np.float32), 16000) - - cfg = _tiny_voicechat_config() - model = NemotronVoiceChat(cfg) - model.to("cuda") - model.eval() + model_dir, audio_path, _ = tiny_model_artifacts + output_dir = tempfile.mkdtemp(prefix="parity-tiny-") - model_dir = str(tmp_path / "model") - model.save_pretrained(model_dir) - with open(os.path.join(model_dir, "config.json"), "w") as f: - _json.dump(cfg, f) - del model - torch.cuda.empty_cache() - - pipeline = _build_parity_pipeline(model_dir, audio_path, str(tmp_path / "output")) + pipeline = _build_parity_pipeline(model_dir, audio_path, output_dir, system_prompt=_MOCK_SYSTEM_PROMPT) report = run_parity_check(pipeline, audio_path, system_prompt=_MOCK_SYSTEM_PROMPT) assert_parity(report, strict=True, atol=0.0) @@ -509,14 +315,13 @@ def test_parity_real_checkpoint(): PARITY_AUDIO_PATH=/path/to/test.wav # optional, defaults to force_align_test.mp3 PARITY_SPEAKER_NAME= # optional """ - import tempfile - ckpt = os.environ["PARITY_CHECKPOINT_PATH"] audio = os.environ.get("PARITY_AUDIO_PATH") or _FORCE_ALIGN_AUDIO speaker = os.environ.get("PARITY_SPEAKER_NAME") pipeline = _build_parity_pipeline( - ckpt, audio, tempfile.mkdtemp(prefix="parity-"), speaker_name=speaker, + ckpt, audio, tempfile.mkdtemp(prefix="parity-"), + speaker_name=speaker, system_prompt=_MOCK_SYSTEM_PROMPT, ) report = run_parity_check(pipeline, audio, system_prompt=_MOCK_SYSTEM_PROMPT) assert_parity(report, strict=True, atol=0.0) From 6e98c8597ae6a6f2d4de94c321f316e8a205e131 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Mon, 30 Mar 2026 23:37:51 +0000 Subject: [PATCH 20/40] Fix byte-level BPE decoding in raw output: unify tokens_to_str and tokens_to_str_raw Signed-off-by: Elena Rastorgueva --- .../nemotron_voicechat_inference_wrapper.py | 22 +++--- .../pipelines/streaming_s2s_pipeline.py | 10 +-- 
.../inference/streaming/state/s2s_state.py | 2 +- .../inference/utils/pipeline_utils.py | 26 ------- .../collections/speechlm2/parts/text_utils.py | 69 +++++++++++++++++-- 5 files changed, 83 insertions(+), 46 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 1a3a6e2cdfe3..3926d97c0ea3 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -39,6 +39,7 @@ InferenceStepResult, StreamingDecodeState, ) +from nemo.collections.speechlm2.parts.text_utils import _decode_tokens_with_specials # --- Configuration --- @@ -917,10 +918,17 @@ def _run_perception( def _tokens_to_strings(self, token_ids: torch.Tensor) -> list[str]: """Convert a [B, T] tensor of token IDs to a list of strings. - Uses tokens_to_text (convert_tokens_to_string) so byte-level BPE is - decoded properly (e.g. "é" -> "é") and leading spaces from - Ġ-prefixed tokens are preserved for correct concatenation of - incremental chunks: " Musée" + " National" -> " Musée National". + Uses ``_decode_tokens_with_specials`` so byte-level BPE is decoded + properly (e.g. ``âĢĻ`` -> ``'``) via HF ``convert_tokens_to_string``. + + Leading spaces are preserved in the output: in byte-level BPE, + word-initial tokens carry a space prefix that ``convert_tokens_to_string`` + keeps intact. So callers can concatenate successive chunk strings to + recover properly spaced text. A leading space means "new word"; no + leading space means the token continues the previous word. For + example, three chunks producing ``"Hi"``, ``" how can"``, + ``" I help"`` concatenate to ``"Hi how can I help"`` (not + ``"Hihow canI help"``). 
NOTE: multi-byte UTF-8 characters whose BPE tokens span two frames will show as replacement chars (U+FFFD) because each frame is decoded @@ -928,10 +936,8 @@ def _tokens_to_strings(self, token_ids: torch.Tensor) -> list[str]: """ result = [] for tok_ids_b in token_ids: - tok_ids_b = tok_ids_b.tolist() - toks = self.tokenizer.ids_to_tokens(tok_ids_b) - toks = [t for t in toks if t != ''] - result.append(self.tokenizer.tokens_to_text(toks)) + toks = self.tokenizer.ids_to_tokens(tok_ids_b.tolist()) + result.append(_decode_tokens_with_specials(toks, self.tokenizer, keep_pad=False)) return result def abort_request(self, request_id: Optional[str]) -> bool: diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 11b6dfd1f835..5bb97dcbe739 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -35,7 +35,7 @@ from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from nemo.collections.speechlm2.inference.streaming.state.s2s_context_manager import S2SContextManager from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions -from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput, tokens_to_str_raw +from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput from nemo.utils import logging @@ -237,7 +237,7 @@ def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], ready_ self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk) # Save full token tensors to state before the context is destroyed, - # so we can run tokens_to_str / tokens_to_str_raw post-hoc. + # so we can run tokens_to_str post-hoc. 
for stream_id, eos_flag in zip(stream_ids, eos_flags): if eos_flag: ctx = self.context_manager.slot_contexts[ @@ -650,14 +650,14 @@ def _build_pipeline_output( tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0] ) raw_texts.append( - tokens_to_str_raw(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] + tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0] ) raw_asr_texts.append( - tokens_to_str_raw(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] + tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0] ) if gen_function_text is not None: fc_text = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=False)[0] - fc_text_raw = tokens_to_str_raw(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id)[0] + fc_text_raw = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0] logging.info(f"Function calling channel: {fc_text}") else: token_texts.append(None) diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py index 7ddb6c1f0f0b..76fd9aa8682c 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py @@ -51,7 +51,7 @@ class S2SStreamingState: # Snapshots of full token-ID tensors, saved from StreamingDecodeState # before the decode context is destroyed at end-of-stream. - # Used for post-hoc tokens_to_str / tokens_to_str_raw conversion. + # Used for post-hoc tokens_to_str conversion. 
final_gen_text: Optional[torch.Tensor] = None final_gen_asr_text: Optional[torch.Tensor] = None final_gen_function_text: Optional[torch.Tensor] = None diff --git a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py index d61e2998d5ee..0886431dce20 100644 --- a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py +++ b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py @@ -20,31 +20,6 @@ from nemo.collections.asr.inference.utils.text_segment import Word -def tokens_to_str_raw(tokens: torch.Tensor, lengths: torch.Tensor, tokenizer, pad_id: int) -> list: - """Convert token IDs to text strings, preserving ALL special tokens including pad tokens. - - Unlike ``tokens_to_str``, this function uses ``ids_to_tokens`` which preserves - special tokens and does NOT filter out any tokens (including pad tokens like - ````). - - Args: - tokens: Token IDs tensor (B, T). - lengths: Length of each sequence (B,). - tokenizer: Tokenizer for decoding. - pad_id: Pad token ID (kept for API compatibility with ``tokens_to_str``). - - Returns: - List of decoded text strings with ALL special tokens preserved. - """ - ans = [] - for hyp_ids, hyp_len in zip(tokens.cpu(), lengths.cpu()): - hyp_ids = hyp_ids[:hyp_len] - toks = tokenizer.ids_to_tokens(hyp_ids.tolist()) - toks = [tok.replace('Ġ', ' ') for tok in toks] - ans.append("".join(toks)) - return ans - - def clean_pred_text(text: str) -> str: """Clean prediction text by removing special markers, timestamps, punctuation, and lowercasing. 
@@ -57,7 +32,6 @@ def clean_pred_text(text: str) -> str: text = re.sub(r'<\$[\d.]+\$>', '', text) text = re.sub(r'<\|[\d.]+\|>', '', text) text = re.sub(r'', '', text) - text = text.replace('\u0120', ' ') text = text.lower() text = re.sub(r'[^\w\s]', '', text) return ' '.join(text.split()) diff --git a/nemo/collections/speechlm2/parts/text_utils.py b/nemo/collections/speechlm2/parts/text_utils.py index 8c2a36facf9b..8b983fab4e5f 100644 --- a/nemo/collections/speechlm2/parts/text_utils.py +++ b/nemo/collections/speechlm2/parts/text_utils.py @@ -16,6 +16,57 @@ from nemo.collections.common.tokenizers import AutoTokenizer +def _decode_tokens_with_specials( + token_strings: list[str], + tokenizer, + pad_token_str: str = '', + keep_pad: bool = False, +) -> str: + """Decode token strings with proper byte-level BPE handling. + + Groups consecutive non-special tokens and decodes each group via + ``tokenizer.tokens_to_text()`` (HF ``convert_tokens_to_string``), which + properly reverses byte-level BPE encoding (e.g. ``âĢĻ`` -> ``'``). + Special tokens are never passed to ``convert_tokens_to_string`` — they + are either inserted as literal strings or dropped entirely. + + Args: + token_strings: Raw token strings from ``tokenizer.ids_to_tokens()``. + tokenizer: Tokenizer with ``tokens_to_text``, ``bos_token``, and + ``eos_token`` attributes (NeMo ``AutoTokenizer`` or similar). + pad_token_str: String representation of the pad token. + keep_pad: If True, preserve all special tokens as literal strings + in the output. If False, strip them. + """ + # Build special-token set from explicit bos/eos/pad — same approach as + # filter_special_tokens() and model_factory._extract_special_token_ids_from_nemo(). 
+ special_tokens = {pad_token_str} + bos = getattr(tokenizer, 'bos_token', None) + eos = getattr(tokenizer, 'eos_token', None) + if bos: + special_tokens.add(bos) + if eos: + special_tokens.add(eos) + + result_parts: list[str] = [] + segment: list[str] = [] + + for tok in token_strings: + if tok in special_tokens: + if segment: + result_parts.append(tokenizer.tokens_to_text(segment)) + segment = [] + if keep_pad: + result_parts.append(tok) + else: + segment.append(tok) + + if segment: + result_parts.append(tokenizer.tokens_to_text(segment)) + + return ''.join(result_parts) + + def tokens_to_str( tokens: torch.Tensor, lengths: torch.Tensor, @@ -23,9 +74,10 @@ def tokens_to_str( pad_id: int, eval_text_turn_taking: bool = False, show_eot_timestamps: bool = False, + keep_pad: bool = False, ) -> list[str]: """ - Convert token IDs to text strings, filtering out special tokens. + Convert token IDs to text strings with proper byte-level BPE decoding. Args: tokens: Token IDs tensor (B, T) @@ -34,14 +86,18 @@ def tokens_to_str( pad_id: Pad token ID to filter out eval_text_turn_taking: If True, insert timestamps at bos/eos positions show_eot_timestamps: If True, also insert timestamps at end-of-text (first pad after BOS) + keep_pad: If True, preserve all special tokens (including pad) as literal + strings in the output. Useful for "raw" output that shows the full + token stream. If False (default), special tokens are stripped. Returns: List of decoded text strings """ + pad_token_str = tokenizer.ids_to_tokens([pad_id])[0] ans = [] - # Helper function to filter special tokens from token IDs - # This filtering is applied regardless of eval_text_turn_taking mode + # Helper function to filter special tokens from token IDs. + # This filtering is applied regardless of eval_text_turn_taking mode. 
def filter_special_tokens(token_ids): # Filter out pad token_ids = token_ids[token_ids != pad_id] @@ -102,8 +158,9 @@ def filter_special_tokens(token_ids): out_str.append(tokenizer.ids_to_text(remaining_ids)) ans.append(" ".join(out_str)) else: - # For non-turn-taking mode: filter out ALL special tokens, return only pure text hyp_ids = hyp_ids[:hyp_len] - hyp_ids = filter_special_tokens(hyp_ids) - ans.append(tokenizer.ids_to_text(hyp_ids)) + toks = tokenizer.ids_to_tokens(hyp_ids.tolist()) + ans.append( + _decode_tokens_with_specials(toks, tokenizer, pad_token_str=pad_token_str, keep_pad=keep_pad) + ) return ans From 819e5f45e2af74c2ad36e75aef767dc1a5130d3b Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Tue, 31 Mar 2026 01:48:42 +0000 Subject: [PATCH 21/40] use whisper normalizer for wer calculation Signed-off-by: Elena Rastorgueva --- .../speechlm2/inference/utils/pipeline_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py index 0886431dce20..7dfee3b6ed0f 100644 --- a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py +++ b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py @@ -16,25 +16,31 @@ from typing import List, Optional import torch +from whisper_normalizer.english import EnglishTextNormalizer from nemo.collections.asr.inference.utils.text_segment import Word +_whisper_normalizer = EnglishTextNormalizer() + def clean_pred_text(text: str) -> str: - """Clean prediction text by removing special markers, timestamps, punctuation, and lowercasing. + """Clean prediction text for fair WER comparison. - Useful for fair WER comparison between predicted and ground-truth text. 
+ First strips model-specific tokens (turn markers, timestamps, pad tokens) + that the Whisper normalizer doesn't know about, then applies + ``EnglishTextNormalizer`` — the same normalizer used by the offline eval + metrics in ``speechlm2.parts.metrics.wer``. """ if not text: return "" + # Strip model-specific tokens text = text.lstrip('^') text = re.sub(r'', '', text) text = re.sub(r'<\$[\d.]+\$>', '', text) text = re.sub(r'<\|[\d.]+\|>', '', text) text = re.sub(r'', '', text) - text = text.lower() - text = re.sub(r'[^\w\s]', '', text) - return ' '.join(text.split()) + # Normalize with Whisper's EnglishTextNormalizer (same as offline eval) + return _whisper_normalizer(text) class PipelineOutput: From e8e7151ef694aa0a22ccdc6a0c1f2a3230111e06 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Tue, 31 Mar 2026 20:46:25 +0000 Subject: [PATCH 22/40] remove unnecessary logging in perception cache step Signed-off-by: Elena Rastorgueva --- .../inference/model_wrappers/perception_cache.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py index ab7dbe494bf5..32fcc868e93a 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py @@ -368,12 +368,10 @@ def step( streaming_cfg = self.streaming_cfg audio_len = torch.tensor([audio_input.shape[1]], dtype=torch.long, device=self.device) - _t_start_preprocessor = time.time() processed_signal, _ = self.preprocessor( input_signal=audio_input, length=audio_len, ) - logging.info(f"preprocessor time: {time.time() - _t_start_preprocessor:.3f}s") if isinstance(streaming_cfg.chunk_size, list): chunk_size_first = streaming_cfg.chunk_size[0] @@ -408,13 +406,9 @@ def step( ) num_sub_steps = num_frames_per_chunk // base_step_size - start_time = time.time() - encoded_chunks = [] for 
sub_step in range(num_sub_steps): - sub_step_start_time = time.time() - sub_frame_idx = frame_idx + (sub_step * base_step_size) is_first_sub_step = (sub_frame_idx == 0) @@ -516,8 +510,6 @@ def step( encoded_chunk = perception.proj(encoded_adapted.transpose(1, 2)) - torch.cuda.synchronize() - logging.info(f" Sub-step {sub_step}/{num_sub_steps} (sub_frame_idx={sub_frame_idx}, first={is_first_sub_step}): {time.time() - sub_step_start_time:.4f}s") encoded_chunks.append(encoded_chunk) if len(encoded_chunks) > 1: @@ -525,9 +517,6 @@ def step( else: encoded_chunk = encoded_chunks[0] - torch.cuda.synchronize() - logging.info(f"Time taken for encoder ({num_sub_steps} sub-steps): {time.time() - start_time}") - new_perception_cache = PerceptionCacheState( cache_last_channel=cache_last_channel, cache_last_time=cache_last_time, From eebea30279b97f23de0f39606a61de045110121b Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Tue, 31 Mar 2026 22:51:11 +0000 Subject: [PATCH 23/40] vectorize rep penalty; fix sampling - nan/inf check before top-p filtering Signed-off-by: Elena Rastorgueva --- .../inference/model_wrappers/model_factory.py | 70 +++++++++++-------- .../nemotron_voicechat_inference_wrapper.py | 6 ++ 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index 9f8616fdcbac..e7fc7a26fe25 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -80,6 +80,13 @@ def __init__( self.repetition_penalty = repetition_penalty self.temperature = temperature + # Pre-built tensor for special-token filtering in repetition penalty. + # Lazily moved to the right device on first use (see _sample_text_token). 
+ self._special_ids_tensor: Optional[torch.Tensor] = ( + torch.tensor(sorted(self.special_token_ids), dtype=torch.long) + if self.special_token_ids else None + ) + def _sample_text_token( self, logits: torch.Tensor, @@ -115,6 +122,10 @@ def _sample_text_token( # For each batch, if greedy is special token, use greedy; otherwise sample sampled_tokens = greedy_tokens.clone() + # Ensure cached special-token tensor is on the right device (once). + if self._special_ids_tensor is not None and self._special_ids_tensor.device != device: + self._special_ids_tensor = self._special_ids_tensor.to(device) + for b in range(B): # If greedy token is a special token, keep it (no sampling) if greedy_tokens[b].item() in self.special_token_ids: @@ -123,30 +134,44 @@ def _sample_text_token( # Not a special token - apply repetition penalty and sampling batch_logits = logits[b].clone() # (V,) - # Apply repetition penalty + # Apply repetition penalty (vectorized, no Python loop) if self.repetition_penalty != 1.0 and current_step > 0: - prev_tokens = generated_tokens[b, :current_step] - unique_prev = prev_tokens.unique() + unique_prev = generated_tokens[b, :current_step].unique() # Exclude special tokens from penalty - if self.special_token_ids: - # Use unique_prev.device to ensure tensors are on the same device - # (generated_tokens may be on a different device than logits, e.g., vLLM returns CPU logits) - special_tensor = torch.tensor(list(self.special_token_ids), device=unique_prev.device) - mask = ~torch.isin(unique_prev, special_tensor) - unique_prev = unique_prev[mask] - - for token_id in unique_prev: - token_id = token_id.item() - if batch_logits[token_id] > 0: - batch_logits[token_id] = batch_logits[token_id] / self.repetition_penalty - else: - batch_logits[token_id] = batch_logits[token_id] * self.repetition_penalty + if self._special_ids_tensor is not None: + ids_t = self._special_ids_tensor + if ids_t.device != unique_prev.device: + ids_t = ids_t.to(unique_prev.device) + 
unique_prev = unique_prev[~torch.isin(unique_prev, ids_t)] + + if unique_prev.numel() > 0: + prev_logits = batch_logits[unique_prev] + # Positive logits are divided, negative logits are multiplied + # (same as the standard repetition_penalty convention) + batch_logits[unique_prev] = torch.where( + prev_logits > 0, + prev_logits / self.repetition_penalty, + prev_logits * self.repetition_penalty, + ) # Apply temperature scaling if self.temperature != 1.0: batch_logits = batch_logits / self.temperature - # Apply top-p sampling + # Fall back to greedy if logits are non-finite before top-p + # (top-p intentionally introduces -inf, so check must happen first) + if not torch.isfinite(batch_logits).all(): + logging.warning( + f"_sample_text_token: logits contain NaN or inf at step {current_step}, batch {b}: " + f"nan={batch_logits.isnan().sum().item()}, " + f"inf={batch_logits.isinf().sum().item()}, " + f"min={batch_logits[~batch_logits.isnan()].min().item() if not batch_logits.isnan().all() else 'all_nan'}, " + f"max={batch_logits[~batch_logits.isnan()].max().item() if not batch_logits.isnan().all() else 'all_nan'}" + ) + sampled_tokens[b] = greedy_tokens[b] + continue + + # Apply top-p (nucleus) sampling if self.top_p < 1.0: sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True) sorted_probs = torch.softmax(sorted_logits, dim=-1) @@ -162,17 +187,6 @@ def _sample_text_token( indices_to_remove = sorted_indices[sorted_indices_to_remove] batch_logits[indices_to_remove] = float('-inf') - # Fall back to greedy if logits contain NaN or inf - if batch_logits.isnan().any() or batch_logits.isinf().any(): - logging.warning( - f"_sample_text_token: logits contain NaN or inf at step {current_step}, batch {b}: " - f"nan={batch_logits.isnan().sum().item()}, " - f"inf={batch_logits.isinf().sum().item()}, " - f"min={batch_logits[~batch_logits.isnan()].min().item() if not batch_logits.isnan().all() else 'all_nan'}, " - 
f"max={batch_logits[~batch_logits.isnan()].max().item() if not batch_logits.isnan().all() else 'all_nan'}" - ) - sampled_tokens[b] = greedy_tokens[b] - continue probs = torch.softmax(batch_logits, dim=-1) sampled_tokens[b] = torch.multinomial(probs, num_samples=1).item() diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 3926d97c0ea3..6e9ab711a594 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -78,6 +78,12 @@ def __init__(self, model_cfg: DictConfig): matmul_precision = str(model_cfg.get("matmul_precision", "medium")) torch.set_float32_matmul_precision(matmul_precision) + # Deterministic mode: guarantees identical *text* outputs (from the STT/LLM + # heads) across runs for the same inputs, even when sampling is enabled + # (top_p < 1, temperature != 1, repetition_penalty != 1). This works + # because we fix the global PyTorch RNG seeds and force all CUDA ops + # to use deterministic algorithm implementations. + # Not compatible with vLLM engines (raises an error below). 
self._deterministic = bool(model_cfg.get("deterministic", False)) if self._deterministic: engine_type = model_cfg.get("engine_type", "native") From de230d9068a85feaf2afae4218a031a3a1e57685 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Tue, 31 Mar 2026 23:42:32 +0000 Subject: [PATCH 24/40] Preserve BOS/EOS as literal strings in decoded text output Signed-off-by: Elena Rastorgueva --- nemo/collections/speechlm2/parts/text_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/nemo/collections/speechlm2/parts/text_utils.py b/nemo/collections/speechlm2/parts/text_utils.py index 8b983fab4e5f..b5805ea84627 100644 --- a/nemo/collections/speechlm2/parts/text_utils.py +++ b/nemo/collections/speechlm2/parts/text_utils.py @@ -27,8 +27,10 @@ def _decode_tokens_with_specials( Groups consecutive non-special tokens and decodes each group via ``tokenizer.tokens_to_text()`` (HF ``convert_tokens_to_string``), which properly reverses byte-level BPE encoding (e.g. ``âĢĻ`` -> ``'``). - Special tokens are never passed to ``convert_tokens_to_string`` — they - are either inserted as literal strings or dropped entirely. + Special tokens (BOS, EOS, PAD) are never passed to + ``convert_tokens_to_string``. BOS/EOS are always kept as literal + strings so that turn boundaries are visible. PAD tokens are kept + only when *keep_pad* is True. Args: token_strings: Raw token strings from ``tokenizer.ids_to_tokens()``. @@ -38,11 +40,11 @@ def _decode_tokens_with_specials( keep_pad: If True, preserve all special tokens as literal strings in the output. If False, strip them. """ - # Build special-token set from explicit bos/eos/pad — same approach as - # filter_special_tokens() and model_factory._extract_special_token_ids_from_nemo(). - special_tokens = {pad_token_str} bos = getattr(tokenizer, 'bos_token', None) eos = getattr(tokenizer, 'eos_token', None) + + # All tokens that must not go through convert_tokens_to_string. 
+ special_tokens = {pad_token_str} if bos: special_tokens.add(bos) if eos: @@ -56,7 +58,10 @@ def _decode_tokens_with_specials( if segment: result_parts.append(tokenizer.tokens_to_text(segment)) segment = [] - if keep_pad: + if tok == pad_token_str: + if keep_pad: + result_parts.append(tok) + else: result_parts.append(tok) else: segment.append(tok) From 8b849c1481e9b872d860280411dbb053fdbd79d6 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 1 Apr 2026 02:26:57 +0000 Subject: [PATCH 25/40] update triton code; bugfix for vllm dtype/device Signed-off-by: Elena Rastorgueva --- .../triton/client_streaming.py | 3 +- .../voicechat/1/infer_streaming.py | 48 +++++++++++-------- .../triton/start_triton.sh | 45 +++++++++++------ .../inference/model_wrappers/model_factory.py | 2 + .../nemotron_voicechat_inference_wrapper.py | 3 +- 5 files changed, 65 insertions(+), 36 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py index 78ca02373a77..d6dbd1ba878c 100644 --- a/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py +++ b/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py @@ -105,6 +105,7 @@ def send_sequence_end(client, sequence_id): outputs = [ grpcclient.InferRequestedOutput("output_text"), + grpcclient.InferRequestedOutput("output_asr_text"), grpcclient.InferRequestedOutput("output_audio"), ] @@ -115,7 +116,7 @@ def send_sequence_end(client, sequence_id): outputs=outputs, sequence_id=sequence_id, sequence_start=False, - sequence_end=True, # This is the key - properly end the sequence + sequence_end=True, ) logger.info("Sequence ended successfully") diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py index 7f9ee3aba227..edca2a0d3b43 100644 --- 
a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py +++ b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py @@ -43,35 +43,45 @@ def _resolve_env_overrides(self, cfg): env vars, while sharing the same s2s_streaming.yaml used by the CLI. Env var mapping (cfg key -> env var, default): - s2s.model_path -> S2S_MODEL_PATH (required) - s2s.llm_checkpoint_path -> S2S_LLM_CHECKPOINT_PATH (required) - s2s.speaker_reference -> S2S_SPEAKER_REFERENCE (required) - s2s.engine_type -> S2S_ENGINE_TYPE (default: native) - s2s.system_prompt -> S2S_SYSTEM_PROMPT (default: none) - s2s.tts_system_prompt -> S2S_TTS_SYSTEM_PROMPT (default: none) + s2s.model_path -> S2S_MODEL_PATH (required) + s2s.speaker_reference -> S2S_SPEAKER_REFERENCE (optional) + s2s.speaker_name -> S2S_SPEAKER_NAME (optional) + s2s.engine_type -> S2S_ENGINE_TYPE (default: native) + s2s.deterministic -> S2S_DETERMINISTIC (default: false) + s2s.use_llm_cache -> S2S_USE_LLM_CACHE (default: true) + s2s.use_tts_subword_cache -> S2S_USE_TTS_SUBWORD_CACHE (default: false) + s2s.system_prompt -> S2S_SYSTEM_PROMPT (optional) + s2s.tts_system_prompt -> S2S_TTS_SYSTEM_PROMPT (optional) streaming.chunk_size_in_secs -> S2S_CHUNK_SIZE_IN_SECS (default: 0.08) streaming.buffer_size_in_secs -> S2S_BUFFER_SIZE_IN_SECS (default: 5.6) """ env_overrides = { # Required - "s2s.model_path": ("S2S_MODEL_PATH", None), - "s2s.llm_checkpoint_path": ("S2S_LLM_CHECKPOINT_PATH", None), - "s2s.speaker_reference": ("S2S_SPEAKER_REFERENCE", None), - # Optional (with defaults) - "s2s.engine_type": ("S2S_ENGINE_TYPE", "native"), - "s2s.system_prompt": ("S2S_SYSTEM_PROMPT", None), - "s2s.tts_system_prompt": ("S2S_TTS_SYSTEM_PROMPT", None), + "s2s.model_path": ("S2S_MODEL_PATH", None), + # Speaker identity (set one or both) + "s2s.speaker_reference": ("S2S_SPEAKER_REFERENCE", None), + "s2s.speaker_name": ("S2S_SPEAKER_NAME", None), + # Engine & precision + 
"s2s.engine_type": ("S2S_ENGINE_TYPE", "native"), + "s2s.deterministic": ("S2S_DETERMINISTIC", False), + # Cache / speedup flags + "s2s.use_llm_cache": ("S2S_USE_LLM_CACHE", True), + "s2s.use_tts_subword_cache": ("S2S_USE_TTS_SUBWORD_CACHE", False), + # Prompts + "s2s.system_prompt": ("S2S_SYSTEM_PROMPT", None), + "s2s.tts_system_prompt": ("S2S_TTS_SYSTEM_PROMPT", None), + # Streaming "streaming.chunk_size_in_secs": ("S2S_CHUNK_SIZE_IN_SECS", 0.08), - "streaming.buffer_size_in_secs": ("S2S_BUFFER_SIZE_IN_SECS", 5.6), + "streaming.buffer_size_in_secs":("S2S_BUFFER_SIZE_IN_SECS", 5.6), } for cfg_key, (env_var, default) in env_overrides.items(): - val = os.environ.get(env_var) - if val is not None: - if default is not None and isinstance(default, bool): + val = os.environ.get(env_var, "") + if val: + if isinstance(default, bool): val = val.lower() in ("true", "1", "yes") - elif default is not None and isinstance(default, float): + elif isinstance(default, float): val = float(val) - elif default is not None and isinstance(default, int): + elif isinstance(default, int): val = int(val) OmegaConf.update(cfg, cfg_key, val, force_add=True) elif default is not None: diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh b/examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh index 8f78f1e47b1f..d42b035fde48 100755 --- a/examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh +++ b/examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh @@ -19,23 +19,26 @@ # Fields marked ??? in the YAML are resolved from environment variables below. 
# # Usage: -# S2S_MODEL_PATH=/path/to/eartts_ckpt \ -# S2S_LLM_CHECKPOINT_PATH=/path/to/llm_ckpt \ -# S2S_SPEAKER_REFERENCE=/path/to/speaker.wav \ +# S2S_MODEL_PATH=/path/to/hf_checkpoint \ +# S2S_SPEAKER_NAME=MySpeaker \ # ./start_triton.sh # # Environment variables (required): -# S2S_MODEL_PATH - Path to the EarTTS / S2S checkpoint -# S2S_LLM_CHECKPOINT_PATH - Path to the LLM checkpoint +# S2S_MODEL_PATH - Path to the HF-format checkpoint directory +# +# Environment variables (speaker identity — set at least one): # S2S_SPEAKER_REFERENCE - Path to a speaker reference .wav file +# S2S_SPEAKER_NAME - Registered speaker name from the checkpoint # # Environment variables (optional): # S2S_ENGINE_TYPE - Engine type (default: native) +# S2S_DETERMINISTIC - "true"/"false": deterministic mode (default: false) +# S2S_USE_LLM_CACHE - "true"/"false": LLM KV cache (default: true) +# S2S_USE_TTS_SUBWORD_CACHE - "true"/"false": TTS subword cache (default: false) # S2S_SYSTEM_PROMPT - LLM system prompt text (default: none) -# S2S_TTS_SYSTEM_PROMPT - TTS system prompt, (default: none) +# S2S_TTS_SYSTEM_PROMPT - TTS system prompt (default: none) # S2S_CHUNK_SIZE_IN_SECS - Chunk size in seconds, multiple of 0.08 (default: 0.08) # S2S_BUFFER_SIZE_IN_SECS - Audio buffer size in seconds (default: 5.6) -# S2S_USE_CODEC_CACHE - "true"/"false": incremental codec decode (default: true) # S2S_TRITON_CONFIG_PATH - Override the YAML config file path # MODEL_REPO_DIR - Override the Triton model repository path @@ -45,33 +48,45 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # backend (infer_streaming.py reads them via os.environ). 
# ======================== -# Model paths (required) +# Model path (required) +# ======================== +export S2S_MODEL_PATH="${S2S_MODEL_PATH:?Please set S2S_MODEL_PATH to the HF-format checkpoint directory}" + +# ======================== +# Speaker identity (set at least one) # ======================== -export S2S_MODEL_PATH="${S2S_MODEL_PATH:?Please set S2S_MODEL_PATH to the EarTTS / S2S checkpoint path}" -export S2S_LLM_CHECKPOINT_PATH="${S2S_LLM_CHECKPOINT_PATH:?Please set S2S_LLM_CHECKPOINT_PATH to the LLM checkpoint path}" -export S2S_SPEAKER_REFERENCE="${S2S_SPEAKER_REFERENCE:?Please set S2S_SPEAKER_REFERENCE to a speaker reference .wav file}" +export S2S_SPEAKER_REFERENCE="${S2S_SPEAKER_REFERENCE:-}" +export S2S_SPEAKER_NAME="${S2S_SPEAKER_NAME:-}" +if [ -z "${S2S_SPEAKER_REFERENCE}" ] && [ -z "${S2S_SPEAKER_NAME}" ]; then + echo "ERROR: Set at least one of S2S_SPEAKER_REFERENCE or S2S_SPEAKER_NAME" + exit 1 +fi # ======================== # Optional overrides # ======================== export S2S_ENGINE_TYPE="${S2S_ENGINE_TYPE:-native}" +export S2S_DETERMINISTIC="${S2S_DETERMINISTIC:-}" +export S2S_USE_LLM_CACHE="${S2S_USE_LLM_CACHE:-}" +export S2S_USE_TTS_SUBWORD_CACHE="${S2S_USE_TTS_SUBWORD_CACHE:-}" export S2S_SYSTEM_PROMPT="${S2S_SYSTEM_PROMPT:-}" export S2S_TTS_SYSTEM_PROMPT="${S2S_TTS_SYSTEM_PROMPT:-}" export S2S_CHUNK_SIZE_IN_SECS="${S2S_CHUNK_SIZE_IN_SECS:-0.08}" export S2S_BUFFER_SIZE_IN_SECS="${S2S_BUFFER_SIZE_IN_SECS:-5.6}" -export S2S_USE_CODEC_CACHE="${S2S_USE_CODEC_CACHE:-true}" export S2S_TRITON_CONFIG_PATH="${S2S_TRITON_CONFIG_PATH:-${SCRIPT_DIR}/../conf/s2s_streaming.yaml}" export MODEL_REPO_DIR="${MODEL_REPO_DIR:-${SCRIPT_DIR}/model_repo_s2s}" echo "=== S2S Triton Server ===" echo " S2S_MODEL_PATH: ${S2S_MODEL_PATH}" -echo " S2S_LLM_CHECKPOINT_PATH: ${S2S_LLM_CHECKPOINT_PATH}" -echo " S2S_SPEAKER_REFERENCE: ${S2S_SPEAKER_REFERENCE}" +echo " S2S_SPEAKER_REFERENCE: ${S2S_SPEAKER_REFERENCE:-}" +echo " S2S_SPEAKER_NAME: 
${S2S_SPEAKER_NAME:-}" echo " S2S_ENGINE_TYPE: ${S2S_ENGINE_TYPE}" +echo " S2S_DETERMINISTIC: ${S2S_DETERMINISTIC:-}" +echo " S2S_USE_LLM_CACHE: ${S2S_USE_LLM_CACHE:-}" +echo " S2S_USE_TTS_SUBWORD_CACHE: ${S2S_USE_TTS_SUBWORD_CACHE:-}" echo " S2S_CHUNK_SIZE_IN_SECS: ${S2S_CHUNK_SIZE_IN_SECS}" echo " S2S_BUFFER_SIZE_IN_SECS: ${S2S_BUFFER_SIZE_IN_SECS}" -echo " S2S_USE_CODEC_CACHE: ${S2S_USE_CODEC_CACHE}" echo " S2S_SYSTEM_PROMPT: ${S2S_SYSTEM_PROMPT:-}" echo " S2S_TTS_SYSTEM_PROMPT: ${S2S_TTS_SYSTEM_PROMPT:-}" echo " MODEL_REPO_DIR: ${MODEL_REPO_DIR}" diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index e7fc7a26fe25..4d808d3c16fe 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -145,6 +145,8 @@ def _sample_text_token( unique_prev = unique_prev[~torch.isin(unique_prev, ids_t)] if unique_prev.numel() > 0: + if unique_prev.device != batch_logits.device: + unique_prev = unique_prev.to(batch_logits.device) prev_logits = batch_logits[unique_prev] # Positive logits are divided, negative logits are multiplied # (same as the standard repetition_penalty convention) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 6e9ab711a594..1cec405f283a 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -241,7 +241,8 @@ def _initialize_model(self): # Convert some S2S components to the configured dtype logging.info(f"Converting some S2S components to {self.dtype} (keeping perception & TTS in float32)...") - self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype) 
+ if self.model.stt_model.llm is not None: + self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype) self.model.stt_model.lm_head = self.model.stt_model.lm_head.to(self.dtype) self.model.stt_model.embed_tokens = self.model.stt_model.embed_tokens.to(self.dtype) self.model.stt_model.asr_head = self.model.stt_model.asr_head.to(self.dtype) From d3db7008b1f8a29700bd9fbf5c8c05eee9a2adb9 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 1 Apr 2026 04:55:03 +0000 Subject: [PATCH 26/40] Always send prefill before audio streaming; fix bfloat16 audio output Signed-off-by: Elena Rastorgueva --- .../triton/client_streaming.py | 61 ++++++++++--------- .../voicechat/1/infer_streaming.py | 10 +-- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py index d6dbd1ba878c..6f95ea16c85d 100644 --- a/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py +++ b/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py @@ -134,42 +134,43 @@ def send_sequence_end(client, sequence_id): sequence_id = random.randint(1, 2**63 - 1) # Generate random uint64 value try: - # If a system prompt is provided, send a separate prefill request first: - # zero-length audio + system_prompt, with sequence_start=True. - prefill_sent = False - if args.system_prompt is not None: - logger.info(f"Sending prefill request with system_prompt ({len(args.system_prompt)} chars)") - empty_audio = np.zeros((1, 0), dtype=np.float32) - prefill_inputs = [ - grpcclient.InferInput( - "audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype) - ), - ] - prefill_inputs[0].set_data_from_numpy(empty_audio) + # Always send a prefill request first (zero-length audio, sequence_start=True). + # This initializes the TTS speaker embedding and system prompt for the session. 
+ # If --system_prompt is provided, it is included; otherwise the server uses + # its configured default. + logger.info("Sending prefill request%s", + f" with system_prompt ({len(args.system_prompt)} chars)" if args.system_prompt else "") + empty_audio = np.zeros((1, 0), dtype=np.float32) + prefill_inputs = [ + grpcclient.InferInput( + "audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype) + ), + ] + prefill_inputs[0].set_data_from_numpy(empty_audio) + if args.system_prompt is not None: prompt_np = np.array([args.system_prompt.encode("utf-8")], dtype=object) prompt_input = grpcclient.InferInput("system_prompt", prompt_np.shape, "BYTES") prompt_input.set_data_from_numpy(prompt_np) prefill_inputs.append(prompt_input) - prefill_outputs = [ - grpcclient.InferRequestedOutput("output_text"), - grpcclient.InferRequestedOutput("output_asr_text"), - grpcclient.InferRequestedOutput("output_audio"), - ] + prefill_outputs = [ + grpcclient.InferRequestedOutput("output_text"), + grpcclient.InferRequestedOutput("output_asr_text"), + grpcclient.InferRequestedOutput("output_audio"), + ] - prefill_start = time.time() - client.infer( - model_name, - prefill_inputs, - request_id=str(uuid.uuid4()), - outputs=prefill_outputs, - sequence_id=sequence_id, - sequence_start=True, - sequence_end=False, - ) - logger.info(f"Prefill completed in {time.time() - prefill_start:.3f}s") - prefill_sent = True + prefill_start = time.time() + client.infer( + model_name, + prefill_inputs, + request_id=str(uuid.uuid4()), + outputs=prefill_outputs, + sequence_id=sequence_id, + sequence_start=True, + sequence_end=False, + ) + logger.info(f"Prefill completed in {time.time() - prefill_start:.3f}s") for idx, audio_chunk in tqdm(enumerate(audio_signal_chunks)): inputs = [ @@ -193,7 +194,7 @@ def send_sequence_end(client, sequence_id): request_id=str(uuid.uuid4()), outputs=outputs, sequence_id=sequence_id, - sequence_start=(idx == 0 and not prefill_sent), + sequence_start=False, 
sequence_end=idx == len(audio_signal_chunks) - 1, ) end_time = time.time() diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py index edca2a0d3b43..b0beb18207d4 100644 --- a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py +++ b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py @@ -307,9 +307,11 @@ def get_generations(self, frames: List[Frame]) -> List[Tuple]: def execute(self, requests: Iterable) -> List[pb_utils.InferenceResponse]: """Execute the model and return the responses. - Zero-length audio with ``sequence_start=True`` and a ``system_prompt`` - is treated as a prefill-only request by the pipeline (no fake audio - needed). All other requests are normal audio generation. + Clients MUST send a prefill request (zero-length audio with + ``sequence_start=True``) before streaming audio. The prefill + initializes the TTS speaker embedding and system prompt for the + session. Sending audio on the first request without a prefill + will produce degraded speaker voice quality. 
Returns: - output_audio: float32 array of generated audio samples @@ -329,7 +331,7 @@ def execute(self, requests: Iterable) -> List[pb_utils.InferenceResponse]: responses = [] for audio, text, asr_text in generations: if isinstance(audio, torch.Tensor): - audio_np = audio.detach().cpu().numpy().astype(np.float32) + audio_np = audio.detach().cpu().float().numpy() if audio_np.ndim == 1: audio_np = audio_np.reshape(1, -1) else: From 81a752e80a1ba63aa45fc0d68b8d9a284df42279 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 1 Apr 2026 05:38:50 +0000 Subject: [PATCH 27/40] remove triton code to keep PR simple Signed-off-by: Elena Rastorgueva --- .../triton/client_streaming.py | 249 ------------ .../voicechat/1/infer_streaming.py | 354 ------------------ .../model_repo_s2s/voicechat/config.pbtxt | 90 ----- .../triton/start_triton.sh | 103 ----- 4 files changed, 796 deletions(-) delete mode 100644 examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py delete mode 100644 examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py delete mode 100644 examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/config.pbtxt delete mode 100755 examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py deleted file mode 100644 index 6f95ea16c85d..000000000000 --- a/examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Streaming Triton client for the S2S voicechat model. - -Usage: - python client_streaming.py \ - --host localhost --port 8001 \ - --audio_filename /path/to/input.wav \ - --dur_test_audio 30 -""" - -import argparse -import uuid -import random -import sys - -import librosa -import numpy as np -import soundfile as sf -import time -import tritonclient.grpc as grpcclient -from tqdm import tqdm -from tritonclient.utils import * - -# Use Python's built-in logging so this script can run without NeMo installed -import logging -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -logger = logging.getLogger(__name__) - -# Default values -DEFAULT_HOST = "localhost" -DEFAULT_PORT = 8001 -DEFAULT_NUM_FRAMES_PER_CHUNK = 1 -DEFAULT_DUR_TEST_AUDIO = 30 - -parser = argparse.ArgumentParser(description="Streaming client for voicechat model") -parser.add_argument("--host", type=str, default=DEFAULT_HOST, help=f"Triton server host (default: {DEFAULT_HOST})") -parser.add_argument("--port", type=int, default=DEFAULT_PORT, help=f"Triton server port (default: {DEFAULT_PORT})") -parser.add_argument("--num_frames_per_chunk", type=int, default=DEFAULT_NUM_FRAMES_PER_CHUNK, help=f"Number of 80ms frames per inference step (default: {DEFAULT_NUM_FRAMES_PER_CHUNK})") -parser.add_argument("--audio_filename", type=str, required=True, help="Path to input audio file") -parser.add_argument("--dur_test_audio", type=int, default=DEFAULT_DUR_TEST_AUDIO, help=f"Duration of test audio in seconds; audio will be padded or trimmed to this length (default: {DEFAULT_DUR_TEST_AUDIO})") 
-parser.add_argument("--output_dir", type=str, default=".", help="Directory to save output audio files (default: current directory)") -parser.add_argument("--system_prompt", type=str, default=None, help="System prompt to send to the model on the first request (overrides server default)") -args = parser.parse_args() - -model_name = "voicechat" -audio_file = args.audio_filename - -NUM_FRAMES_PER_CHUNK = args.num_frames_per_chunk -DUR_TEST_AUDIO = args.dur_test_audio -INPUT_CHUNK_SIZE_SAMPLES = int(16000 * 0.08) * NUM_FRAMES_PER_CHUNK # number of samples per input chunk -NUM_CHUNKS_TEST_AUDIO = int(DUR_TEST_AUDIO / (0.08 * NUM_FRAMES_PER_CHUNK)) -print(f"{NUM_CHUNKS_TEST_AUDIO=}") - -times_spend_on_inference = [] - - -def get_audio_as_chunks(audio_file): - audio_signal, sr = librosa.load(audio_file, sr=16000) - audio_signal = np.expand_dims(audio_signal, axis=0) - - padded_len_samples = int(NUM_CHUNKS_TEST_AUDIO * INPUT_CHUNK_SIZE_SAMPLES) - audio_signal_padded = np.zeros((1, padded_len_samples), dtype=np.float32) - - if padded_len_samples > audio_signal.shape[1]: # actually doing padding - audio_signal_padded[:, : audio_signal.shape[1]] = audio_signal - else: # actually need to trim (because audio is longer than maxlen) - audio_signal_padded = audio_signal[:, :padded_len_samples] - - audio_signal_chunks = [ - audio_signal_padded[:, i : i + INPUT_CHUNK_SIZE_SAMPLES] - for i in range(0, audio_signal_padded.shape[1], INPUT_CHUNK_SIZE_SAMPLES) - ] - - return audio_signal_chunks - - -def send_sequence_end(client, sequence_id): - """Send a final request with sequence_end=True to properly clean up the sequence""" - try: - logger.info(f"Sending sequence_end=True for sequence_id={sequence_id}") - - # Send empty audio chunk with sequence_end=True - empty_audio = np.zeros((1, INPUT_CHUNK_SIZE_SAMPLES), dtype=np.float32) - - inputs = [ - grpcclient.InferInput( - "audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype) - ), - ] - 
inputs[0].set_data_from_numpy(empty_audio) - - outputs = [ - grpcclient.InferRequestedOutput("output_text"), - grpcclient.InferRequestedOutput("output_asr_text"), - grpcclient.InferRequestedOutput("output_audio"), - ] - - response = client.infer( - model_name, - inputs, - request_id=str(uuid.uuid4()), - outputs=outputs, - sequence_id=sequence_id, - sequence_start=False, - sequence_end=True, - ) - logger.info("Sequence ended successfully") - - except Exception as e: - logger.error(f"Error ending sequence: {e}") - -with grpcclient.InferenceServerClient(f"{args.host}:{args.port}") as client: - audio_signal_chunks = get_audio_as_chunks(audio_file) - - generated_text = [] - generated_asr_text = [] - generated_audio = [] - - # Generate a numeric sequence ID instead of string UUID to match UINT64 type - sequence_id = random.randint(1, 2**63 - 1) # Generate random uint64 value - - try: - # Always send a prefill request first (zero-length audio, sequence_start=True). - # This initializes the TTS speaker embedding and system prompt for the session. - # If --system_prompt is provided, it is included; otherwise the server uses - # its configured default. 
- logger.info("Sending prefill request%s", - f" with system_prompt ({len(args.system_prompt)} chars)" if args.system_prompt else "") - empty_audio = np.zeros((1, 0), dtype=np.float32) - prefill_inputs = [ - grpcclient.InferInput( - "audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype) - ), - ] - prefill_inputs[0].set_data_from_numpy(empty_audio) - - if args.system_prompt is not None: - prompt_np = np.array([args.system_prompt.encode("utf-8")], dtype=object) - prompt_input = grpcclient.InferInput("system_prompt", prompt_np.shape, "BYTES") - prompt_input.set_data_from_numpy(prompt_np) - prefill_inputs.append(prompt_input) - - prefill_outputs = [ - grpcclient.InferRequestedOutput("output_text"), - grpcclient.InferRequestedOutput("output_asr_text"), - grpcclient.InferRequestedOutput("output_audio"), - ] - - prefill_start = time.time() - client.infer( - model_name, - prefill_inputs, - request_id=str(uuid.uuid4()), - outputs=prefill_outputs, - sequence_id=sequence_id, - sequence_start=True, - sequence_end=False, - ) - logger.info(f"Prefill completed in {time.time() - prefill_start:.3f}s") - - for idx, audio_chunk in tqdm(enumerate(audio_signal_chunks)): - inputs = [ - grpcclient.InferInput( - "audio_signal", audio_chunk.shape, np_to_triton_dtype(audio_chunk.dtype) - ), - ] - - inputs[0].set_data_from_numpy(audio_chunk) - - outputs = [ - grpcclient.InferRequestedOutput("output_text"), - grpcclient.InferRequestedOutput("output_asr_text"), - grpcclient.InferRequestedOutput("output_audio"), - ] - - start_time = time.time() - response = client.infer( - model_name, - inputs, - request_id=str(uuid.uuid4()), - outputs=outputs, - sequence_id=sequence_id, - sequence_start=False, - sequence_end=idx == len(audio_signal_chunks) - 1, - ) - end_time = time.time() - - result = response.get_response() - output_text = response.as_numpy("output_text") - output_asr_text = response.as_numpy("output_asr_text") - output_audio = response.as_numpy("output_audio") - - 
generated_text.extend([i.decode("utf-8") for i in output_text]) - generated_asr_text.extend([i.decode("utf-8") for i in output_asr_text]) - - if output_audio.shape[1] > 0: - times_spend_on_inference.append(end_time - start_time) - generated_audio.append(output_audio) - - except KeyboardInterrupt: - logger.info("\nKeyboardInterrupt received! Calling send_sequence_end...") - send_sequence_end(client, sequence_id) - logger.info("Sequence cleanup completed. Exiting...") - sys.exit(0) - - logger.info("Agent text: " + "".join([str(i) for i in generated_text])) - logger.info("ASR text (user's speech): " + "".join([str(i) for i in generated_asr_text])) - generated_audio = np.concatenate(generated_audio, axis=1) - - import os - os.makedirs(args.output_dir, exist_ok=True) - - output_audio_path = os.path.join(args.output_dir, "output_audio.wav") - sf.write(output_audio_path, generated_audio.squeeze(0), 22050) - logger.info(f"Agent audio written to {output_audio_path}") - - # Save audio file with both input and output in each channel - # Resample input to 22050 Hz, and pad shorter file to the same length as the longer one - input_audio, sr = librosa.load(audio_file, sr=22050) - generated_audio_1d = generated_audio.squeeze(0) # Convert from [1, T] to [T] - maxlen = max(input_audio.shape[0], generated_audio_1d.shape[0]) - input_audio = np.pad(input_audio, (0, maxlen - input_audio.shape[0]), mode="constant") - generated_audio_1d = np.pad(generated_audio_1d, (0, maxlen - generated_audio_1d.shape[0]), mode="constant") - both_audio = np.column_stack([input_audio, generated_audio_1d]) # Create stereo: [T, 2] - combined_path = os.path.join(args.output_dir, "input_and_output_combined.wav") - sf.write(combined_path, both_audio, 22050) - logger.info(f"Input and output combined audio written to {combined_path}") - - logger.info(f"Average time spend on inference: {np.mean(times_spend_on_inference)}") - logger.info(f"std of time spend on inference: {np.std(times_spend_on_inference)}") - 
logger.info(f"Median time spend on inference: {np.median(times_spend_on_inference)}") - logger.info(f"Min time spend on inference: {np.min(times_spend_on_inference)}") - logger.info(f"Max time spend on inference: {np.max(times_spend_on_inference)}") - logger.info(f"All times spend on inference: {[round(i, 4) for i in times_spend_on_inference]}") - logger.info(f"Number of chunks: {len(times_spend_on_inference)}") diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py deleted file mode 100644 index b0beb18207d4..000000000000 --- a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations -from typing import List, Iterable, Tuple -import os -import numpy as np -import torch - -from nemo.collections.asr.inference.streaming.framing.request import Frame -from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder -from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions - -import triton_python_backend_utils as pb_utils - -from omegaconf import OmegaConf -from nemo.utils import logging -import time - - -class TritonPythonModel: - """Triton Python model for streaming S2S generation. - - This model uses the NeMo S2S pipeline to generate speech from speech input. - Every Python model that is created must have "TritonPythonModel" as the class name. - """ - - def _resolve_env_overrides(self, cfg): - """Resolve ??? placeholders in the config from environment variables. - - This allows start_triton.sh to control model paths and settings via - env vars, while sharing the same s2s_streaming.yaml used by the CLI. 
- - Env var mapping (cfg key -> env var, default): - s2s.model_path -> S2S_MODEL_PATH (required) - s2s.speaker_reference -> S2S_SPEAKER_REFERENCE (optional) - s2s.speaker_name -> S2S_SPEAKER_NAME (optional) - s2s.engine_type -> S2S_ENGINE_TYPE (default: native) - s2s.deterministic -> S2S_DETERMINISTIC (default: false) - s2s.use_llm_cache -> S2S_USE_LLM_CACHE (default: true) - s2s.use_tts_subword_cache -> S2S_USE_TTS_SUBWORD_CACHE (default: false) - s2s.system_prompt -> S2S_SYSTEM_PROMPT (optional) - s2s.tts_system_prompt -> S2S_TTS_SYSTEM_PROMPT (optional) - streaming.chunk_size_in_secs -> S2S_CHUNK_SIZE_IN_SECS (default: 0.08) - streaming.buffer_size_in_secs -> S2S_BUFFER_SIZE_IN_SECS (default: 5.6) - """ - env_overrides = { - # Required - "s2s.model_path": ("S2S_MODEL_PATH", None), - # Speaker identity (set one or both) - "s2s.speaker_reference": ("S2S_SPEAKER_REFERENCE", None), - "s2s.speaker_name": ("S2S_SPEAKER_NAME", None), - # Engine & precision - "s2s.engine_type": ("S2S_ENGINE_TYPE", "native"), - "s2s.deterministic": ("S2S_DETERMINISTIC", False), - # Cache / speedup flags - "s2s.use_llm_cache": ("S2S_USE_LLM_CACHE", True), - "s2s.use_tts_subword_cache": ("S2S_USE_TTS_SUBWORD_CACHE", False), - # Prompts - "s2s.system_prompt": ("S2S_SYSTEM_PROMPT", None), - "s2s.tts_system_prompt": ("S2S_TTS_SYSTEM_PROMPT", None), - # Streaming - "streaming.chunk_size_in_secs": ("S2S_CHUNK_SIZE_IN_SECS", 0.08), - "streaming.buffer_size_in_secs":("S2S_BUFFER_SIZE_IN_SECS", 5.6), - } - for cfg_key, (env_var, default) in env_overrides.items(): - val = os.environ.get(env_var, "") - if val: - if isinstance(default, bool): - val = val.lower() in ("true", "1", "yes") - elif isinstance(default, float): - val = float(val) - elif isinstance(default, int): - val = int(val) - OmegaConf.update(cfg, cfg_key, val, force_add=True) - elif default is not None: - OmegaConf.update(cfg, cfg_key, default, force_add=True) - - def load_model(self, config_path: str): - """Load the S2S pipeline from 
a YAML config file. - - Args: - config_path: Path to a shared YAML config file (s2s_streaming.yaml). - Fields marked ??? are resolved from environment variables - exported by start_triton.sh. - """ - cfg = OmegaConf.load(config_path) - self._resolve_env_overrides(cfg) - - self.pipeline = S2SPipelineBuilder.build_pipeline(cfg) - self.pipeline.open_session() - - # Compute chunk size in samples from the pipeline's config - self.chunk_size = int(self.pipeline.chunk_size_in_secs * self.pipeline.input_sample_rate) - - # Track text positions to return only incremental updates - self.text_positions = {} # stream_id -> last_text_length - self.asr_text_positions = {} # stream_id -> last_asr_text_length - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - # Config path: set S2S_TRITON_CONFIG_PATH env var (start_triton.sh does this automatically). - config_path = os.environ.get("S2S_TRITON_CONFIG_PATH") - if not config_path: - raise ValueError( - "S2S_TRITON_CONFIG_PATH environment variable is not set. " - "Use start_triton.sh or set it to the path of s2s_streaming.yaml." - ) - logging.info(f"Loading S2S Triton model from config: {config_path}") - self.load_model(config_path) - - # Warm up the inference engine(s) with a throwaway prefill so the - # first real client request doesn't pay one-time initialization cost. 
- self.pipeline.warmup() - - def finalize(self) -> None: - """Finalize the model.""" - # Close the session, clear state pool, and empty CUDA cache - self.pipeline.close_session() - torch.cuda.empty_cache() - - def validate_and_convert_audio(self, audio_signal: np.ndarray) -> torch.Tensor: - """Validate that the audio chunk matches the expected size and convert to tensor.""" - if audio_signal.ndim == 2: - audio_signal = audio_signal.flatten() - - if len(audio_signal) != self.chunk_size: - expected_frames = self.pipeline.num_frames_per_chunk - actual_secs = len(audio_signal) / self.pipeline.input_sample_rate - raise ValueError( - f"Audio chunk size mismatch: received {len(audio_signal)} samples ({actual_secs:.3f}s) " - f"but server expects {self.chunk_size} samples " - f"({self.pipeline.chunk_size_in_secs}s = {expected_frames} frame(s)). " - f"Make sure the client's num_frames_per_chunk matches the server's " - f"chunk_size_in_secs={self.pipeline.chunk_size_in_secs}." - ) - - return torch.tensor(audio_signal, dtype=torch.float32) - - def triton_requests_to_frames(self, requests: Iterable) -> List[Frame]: - """ - Convert Triton inference requests into streaming audio Frames. - - Extracts audio data and sequence batching controls (START, END, CORRID) - from each Triton request and wraps them in Frame dataclasses for the - streaming S2S pipeline. - - Since max_batch_size=0, processes one request at a time. 
- - Returns: - List of Frame objects (one per request) - """ - frames = [] - - for request in requests: - # Get audio input - audio_signal = pb_utils.get_input_tensor_by_name(request, "audio_signal").as_numpy() - - # Extract sequence batching metadata from Triton control inputs - # These are automatically populated when client uses sequence_start/end/id - is_first = False - is_last = False - stream_id = 0 - - try: - start_tensor = pb_utils.get_input_tensor_by_name(request, "START") - if start_tensor is not None: - is_first = bool(start_tensor.as_numpy()[0]) - except Exception: - pass - - try: - end_tensor = pb_utils.get_input_tensor_by_name(request, "END") - if end_tensor is not None: - is_last = bool(end_tensor.as_numpy()[0]) - except Exception: - pass - - try: - corrid_tensor = pb_utils.get_input_tensor_by_name(request, "CORRID") - if corrid_tensor is not None: - stream_id = int(corrid_tensor.as_numpy()[0]) - except Exception: - pass - - # Extract optional per-stream system prompt (sent on the first request) - frame_options = None - if is_first: - system_prompt = None - try: - prompt_tensor = pb_utils.get_input_tensor_by_name(request, "system_prompt") - if prompt_tensor is not None: - raw = prompt_tensor.as_numpy()[0] - system_prompt = raw.decode("utf-8") if isinstance(raw, bytes) else str(raw) - except Exception: - pass - if system_prompt is None: - system_prompt = self.pipeline.system_prompt - frame_options = S2SRequestOptions(system_prompt=system_prompt) - - # Zero-length audio = prefill-only frame; pass through without validation - if audio_signal.size == 0: - samples = torch.empty(0, dtype=torch.float32) - else: - samples = self.validate_and_convert_audio(audio_signal) - - frames.append(Frame( - samples=samples, - stream_id=stream_id, - is_first=is_first, - is_last=is_last, - options=frame_options, - )) - - return frames - - def get_generations(self, frames: List[Frame]) -> List[Tuple]: - """ - Generate speech for the requests. 
- - Uses StreamingS2SPipeline.generate_step() which updates internal state, - then extracts results from per-stream S2SStreamingState objects. - - Zero-length first frames are prefill-only: generate_step handles them - internally and returns early; this method returns empty results for them. - - Returns a list of tuples, where each tuple contains: - - generated audio tensor - - generated text string (incremental, only new text since last response) - - generated ASR text string (incremental, only new ASR text since last response) - """ - _t_generate_step = time.time() - self.pipeline.generate_step(frames) - _t_generate_step_done = time.time() - - _t_extract = time.time() - generations = [] - - for frame in frames: - stream_id = frame.stream_id - - # Prefill-only frames don't produce audio/text output - if frame.is_first and frame.samples.numel() == 0: - generations.append((torch.empty(1, 0), "", "")) - continue - - state = self.pipeline.get_or_create_state(stream_id) - audio = state.audio_buffer - - full_text = state.output_text_str - full_asr_text = state.output_asr_text_str - - if stream_id not in self.text_positions: - self.text_positions[stream_id] = 0 - last_position = self.text_positions[stream_id] - incremental_text = full_text[last_position:] - self.text_positions[stream_id] = len(full_text) - - if stream_id not in self.asr_text_positions: - self.asr_text_positions[stream_id] = 0 - last_asr_position = self.asr_text_positions[stream_id] - incremental_asr_text = full_asr_text[last_asr_position:] - self.asr_text_positions[stream_id] = len(full_asr_text) - - generations.append((audio, incremental_text, incremental_asr_text)) - - state.clear_audio_buffer() - - if frame.is_last: - self.pipeline.delete_state(stream_id) - if stream_id in self.text_positions: - del self.text_positions[stream_id] - if stream_id in self.asr_text_positions: - del self.asr_text_positions[stream_id] - _t_extract_done = time.time() - - logging.info(f"get_generations breakdown: 
generate_step={(_t_generate_step_done - _t_generate_step)*1000:.2f}ms, " - f"extract+cleanup={(_t_extract_done - _t_extract)*1000:.2f}ms") - - return generations - - def execute(self, requests: Iterable) -> List[pb_utils.InferenceResponse]: - """Execute the model and return the responses. - - Clients MUST send a prefill request (zero-length audio with - ``sequence_start=True``) before streaming audio. The prefill - initializes the TTS speaker embedding and system prompt for the - session. Sending audio on the first request without a prefill - will produce degraded speaker voice quality. - - Returns: - - output_audio: float32 array of generated audio samples - - output_text: UTF-8 encoded string of generated text (agent's response) - - output_asr_text: UTF-8 encoded string of ASR text (user's transcribed speech) - """ - start_time = time.time() - - _t_to_frames = time.time() - frames = self.triton_requests_to_frames(requests) - _t_to_frames_done = time.time() - - _t_generations = time.time() - generations = self.get_generations(frames) - _t_generations_done = time.time() - - responses = [] - for audio, text, asr_text in generations: - if isinstance(audio, torch.Tensor): - audio_np = audio.detach().cpu().float().numpy() - if audio_np.ndim == 1: - audio_np = audio_np.reshape(1, -1) - else: - audio_np = np.zeros((1, 0), dtype=np.float32) - - text_np = np.array([text.encode('utf-8')], dtype=object) - asr_text_np = np.array([asr_text.encode('utf-8')], dtype=object) - - responses.append(pb_utils.InferenceResponse(output_tensors=[ - pb_utils.Tensor("output_audio", audio_np), - pb_utils.Tensor("output_text", text_np), - pb_utils.Tensor("output_asr_text", asr_text_np), - ])) - - end_time = time.time() - logging.info(f"TritonPythonModel.execute time: {end_time - start_time:.2f} seconds") - logging.info(f"execute() breakdown: triton_requests_to_frames={(_t_to_frames_done - _t_to_frames)*1000:.2f}ms, " - f"get_generations={(_t_generations_done - _t_generations)*1000:.2f}ms") - 
- return responses diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/config.pbtxt b/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/config.pbtxt deleted file mode 100644 index 113b6d96a006..000000000000 --- a/examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/config.pbtxt +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: "voicechat" -default_model_filename: "infer_streaming.py" -backend: "python" -max_batch_size: 0 - -input { - name: "audio_signal" - data_type: TYPE_FP32 - dims: [-1, -1] -} - -input { - name: "system_prompt" - data_type: TYPE_STRING - dims: [-1] - optional: true -} - -output { - name: "output_text" - data_type: TYPE_STRING - dims: [-1] -} - -output { - name: "output_asr_text" - data_type: TYPE_STRING - dims: [-1] -} - -output [ - { - name: "output_audio" - data_type: TYPE_FP32 - dims: [-1, -1] - } -] - -sequence_batching { - max_sequence_idle_microseconds: 30000000 - oldest - { - max_candidate_sequences: 1 - } - control_input [ - { - name: "START" - control [ - { - kind: CONTROL_SEQUENCE_START - fp32_false_true: [ 0, 1 ] - } - ] - }, - { - name: "END" - control [ - { - kind: CONTROL_SEQUENCE_END - fp32_false_true: [ 0, 1 ] - } - ] - }, - { - name: "CORRID" - control [ - { - kind: CONTROL_SEQUENCE_CORRID - data_type: TYPE_UINT64 - } - ] - } - ] -} - -instance_group [{ kind: KIND_GPU, gpus: [0] }] diff --git a/examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh b/examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh deleted file mode 100755 index d42b035fde48..000000000000 --- a/examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh +++ /dev/null @@ -1,103 +0,0 @@ -#!/bin/bash -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Start Triton Inference Server for S2S voicechat model. -# -# Shares the same s2s_streaming.yaml config used by s2s_streaming_infer.py. -# Fields marked ??? in the YAML are resolved from environment variables below. -# -# Usage: -# S2S_MODEL_PATH=/path/to/hf_checkpoint \ -# S2S_SPEAKER_NAME=MySpeaker \ -# ./start_triton.sh -# -# Environment variables (required): -# S2S_MODEL_PATH - Path to the HF-format checkpoint directory -# -# Environment variables (speaker identity — set at least one): -# S2S_SPEAKER_REFERENCE - Path to a speaker reference .wav file -# S2S_SPEAKER_NAME - Registered speaker name from the checkpoint -# -# Environment variables (optional): -# S2S_ENGINE_TYPE - Engine type (default: native) -# S2S_DETERMINISTIC - "true"/"false": deterministic mode (default: false) -# S2S_USE_LLM_CACHE - "true"/"false": LLM KV cache (default: true) -# S2S_USE_TTS_SUBWORD_CACHE - "true"/"false": TTS subword cache (default: false) -# S2S_SYSTEM_PROMPT - LLM system prompt text (default: none) -# S2S_TTS_SYSTEM_PROMPT - TTS system prompt (default: none) -# S2S_CHUNK_SIZE_IN_SECS - Chunk size in seconds, multiple of 0.08 (default: 0.08) -# S2S_BUFFER_SIZE_IN_SECS - Audio buffer size in seconds (default: 5.6) -# S2S_TRITON_CONFIG_PATH - Override the YAML config file path -# MODEL_REPO_DIR - Override the Triton model repository path - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -# All variables below are exported so they are visible to the Triton Python -# backend (infer_streaming.py reads them via os.environ). 
- -# ======================== -# Model path (required) -# ======================== -export S2S_MODEL_PATH="${S2S_MODEL_PATH:?Please set S2S_MODEL_PATH to the HF-format checkpoint directory}" - -# ======================== -# Speaker identity (set at least one) -# ======================== -export S2S_SPEAKER_REFERENCE="${S2S_SPEAKER_REFERENCE:-}" -export S2S_SPEAKER_NAME="${S2S_SPEAKER_NAME:-}" -if [ -z "${S2S_SPEAKER_REFERENCE}" ] && [ -z "${S2S_SPEAKER_NAME}" ]; then - echo "ERROR: Set at least one of S2S_SPEAKER_REFERENCE or S2S_SPEAKER_NAME" - exit 1 -fi - -# ======================== -# Optional overrides -# ======================== -export S2S_ENGINE_TYPE="${S2S_ENGINE_TYPE:-native}" -export S2S_DETERMINISTIC="${S2S_DETERMINISTIC:-}" -export S2S_USE_LLM_CACHE="${S2S_USE_LLM_CACHE:-}" -export S2S_USE_TTS_SUBWORD_CACHE="${S2S_USE_TTS_SUBWORD_CACHE:-}" -export S2S_SYSTEM_PROMPT="${S2S_SYSTEM_PROMPT:-}" -export S2S_TTS_SYSTEM_PROMPT="${S2S_TTS_SYSTEM_PROMPT:-}" -export S2S_CHUNK_SIZE_IN_SECS="${S2S_CHUNK_SIZE_IN_SECS:-0.08}" -export S2S_BUFFER_SIZE_IN_SECS="${S2S_BUFFER_SIZE_IN_SECS:-5.6}" -export S2S_TRITON_CONFIG_PATH="${S2S_TRITON_CONFIG_PATH:-${SCRIPT_DIR}/../conf/s2s_streaming.yaml}" -export MODEL_REPO_DIR="${MODEL_REPO_DIR:-${SCRIPT_DIR}/model_repo_s2s}" - - -echo "=== S2S Triton Server ===" -echo " S2S_MODEL_PATH: ${S2S_MODEL_PATH}" -echo " S2S_SPEAKER_REFERENCE: ${S2S_SPEAKER_REFERENCE:-}" -echo " S2S_SPEAKER_NAME: ${S2S_SPEAKER_NAME:-}" -echo " S2S_ENGINE_TYPE: ${S2S_ENGINE_TYPE}" -echo " S2S_DETERMINISTIC: ${S2S_DETERMINISTIC:-}" -echo " S2S_USE_LLM_CACHE: ${S2S_USE_LLM_CACHE:-}" -echo " S2S_USE_TTS_SUBWORD_CACHE: ${S2S_USE_TTS_SUBWORD_CACHE:-}" -echo " S2S_CHUNK_SIZE_IN_SECS: ${S2S_CHUNK_SIZE_IN_SECS}" -echo " S2S_BUFFER_SIZE_IN_SECS: ${S2S_BUFFER_SIZE_IN_SECS}" -echo " S2S_SYSTEM_PROMPT: ${S2S_SYSTEM_PROMPT:-}" -echo " S2S_TTS_SYSTEM_PROMPT: ${S2S_TTS_SYSTEM_PROMPT:-}" -echo " MODEL_REPO_DIR: ${MODEL_REPO_DIR}" -echo " S2S_TRITON_CONFIG_PATH: 
${S2S_TRITON_CONFIG_PATH}" -echo "=========================" - -TRITON_BIN="${TRITON_BIN:-/opt/tritonserver/bin/tritonserver}" -if [ ! -x "${TRITON_BIN}" ]; then - echo "ERROR: Triton server not found at ${TRITON_BIN}" - echo " Are you running inside a Triton container?" - exit 1 -fi - -"${TRITON_BIN}" --model-repository="${MODEL_REPO_DIR}" From b7673a40b55a19760c771974633a358f03753a86 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 1 Apr 2026 06:36:06 +0000 Subject: [PATCH 28/40] add missing __init__.py Signed-off-by: Elena Rastorgueva --- .../speechlm2/inference/vllm/scripts/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 nemo/collections/speechlm2/inference/vllm/scripts/__init__.py diff --git a/nemo/collections/speechlm2/inference/vllm/scripts/__init__.py b/nemo/collections/speechlm2/inference/vllm/scripts/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/vllm/scripts/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
From df2e3bb3e9870f6ccb158d9bdf6d5f196e51852d Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 1 Apr 2026 19:14:56 +0000 Subject: [PATCH 29/40] use built-in type hints (X | None, dict, list) instead of typing imports Signed-off-by: Elena Rastorgueva --- .../inference/model_wrappers/decode_state.py | 18 ++--- .../inference/model_wrappers/model_factory.py | 71 +++++++++---------- .../nemotron_voicechat_inference_wrapper.py | 16 ++--- .../model_wrappers/perception_cache.py | 49 +++++++------ .../pipelines/s2s_pipeline_interface.py | 4 +- .../pipelines/streaming_s2s_pipeline.py | 33 +++++---- .../streaming/state/s2s_context_manager.py | 14 ++-- .../inference/streaming/state/s2s_state.py | 14 ++-- .../speechlm2/inference/utils/audio_data.py | 19 +++-- .../inference/utils/pipeline_utils.py | 28 ++++---- .../scripts/convert_nemotronllm_checkpoint.py | 16 ++--- .../inference/vllm/streaming_llm_engine.py | 19 +++-- .../speechlm2/models/duplex_stt_model.py | 6 +- .../speechlm2/models/nemotron_voicechat.py | 16 ++--- 14 files changed, 151 insertions(+), 172 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py index 9f34a35a55b2..bde02df57105 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py @@ -32,7 +32,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, List, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch @@ -51,13 +51,13 @@ class StreamingDecodeState: frame_idx: int gen_text: torch.Tensor gen_asr_text: torch.Tensor - gen_function_text: Optional[torch.Tensor] - input_embeds_history: List[torch.Tensor] + gen_function_text: torch.Tensor | None + input_embeds_history: list[torch.Tensor] llm_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache tts_past_key_values: Any - 
tts_code: Optional[torch.Tensor] - subword_mask: Optional[torch.Tensor] - perception_cache: Optional["PerceptionCacheState"] = None + tts_code: torch.Tensor | None + subword_mask: torch.Tensor | None + perception_cache: "PerceptionCacheState" | None = None tts_codec_cache: Any = None llm_cache_position_offset: int = 0 @@ -75,6 +75,6 @@ class InferenceStepResult: asr_predicted_text_tokens: torch.Tensor predicted_text_strs: list[str] asr_predicted_text_strs: list[str] - decoded_audio: Optional[torch.Tensor] = None - function_predicted_text_tokens: Optional[torch.Tensor] = None - debug: Optional[dict] = None + decoded_audio: torch.Tensor | None = None + function_predicted_text_tokens: torch.Tensor | None = None + debug: dict | None = None diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index 4d808d3c16fe..1fcb4517390e 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -33,11 +33,10 @@ """ from abc import ABC, abstractmethod -from typing import Optional, Dict, Any, Union, Set +from typing import Any import math import os import torch -from transformers import DynamicCache from dataclasses import dataclass, fields from nemo.utils import logging @@ -54,7 +53,7 @@ class ModelInterface(ABC): def __init__( self, - special_token_ids: Optional[Set[int]] = None, + special_token_ids: set[int] | None = None, top_p: float = 1.0, repetition_penalty: float = 1.0, temperature: float = 1.0, @@ -82,7 +81,7 @@ def __init__( # Pre-built tensor for special-token filtering in repetition penalty. # Lazily moved to the right device on first use (see _sample_text_token). 
- self._special_ids_tensor: Optional[torch.Tensor] = ( + self._special_ids_tensor: torch.Tensor | None = ( torch.tensor(sorted(self.special_token_ids), dtype=torch.long) if self.special_token_ids else None ) @@ -198,9 +197,9 @@ def _sample_text_token( def __call__( self, input_embeds: torch.Tensor, - cache: Optional[Any] = None, + cache: Any | None = None, **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Perform model inference. @@ -218,7 +217,7 @@ def __call__( pass @abstractmethod - def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'ModelInterface': + def to(self, device_or_dtype: torch.device | torch.dtype) -> 'ModelInterface': """Move model to specified device or convert to specified dtype.""" pass @@ -268,9 +267,9 @@ def __init__( gpu_memory_utilization: float = 0.8, trust_remote_code: bool = True, dtype: str = "bfloat16", - engine_path: Optional[str] = None, - pretrained_llm: Optional[str] = None, - special_token_ids: Optional[Set[int]] = None, + engine_path: str | None = None, + pretrained_llm: str | None = None, + special_token_ids: set[int] | None = None, top_p: float = 1.0, repetition_penalty: float = 1.0, temperature: float = 1.0, @@ -359,7 +358,7 @@ def __init__( logging.info("vLLM engine ready!") @staticmethod - def _get_special_token_ids_from_vllm_tokenizer(tokenizer) -> Set[int]: + def _get_special_token_ids_from_vllm_tokenizer(tokenizer) -> set[int]: """ Extract special token IDs from a vLLM tokenizer. Looks for: '' (bos), '' (eos), '' (pad). @@ -400,9 +399,9 @@ def _generate_request_id(self) -> str: def __call__( self, input_embeds: torch.Tensor, - request_id: Optional[str] = "request_id_1", + request_id: str | None = "request_id_1", **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Perform inference using vLLM streaming engine. 
@@ -430,10 +429,10 @@ def __call__( async def _async_inference( self, - inputs: Union[torch.Tensor, list[torch.Tensor]], + inputs: torch.Tensor | list[torch.Tensor], request_id: str, **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Async inference using the streaming engine. @@ -475,10 +474,10 @@ async def _process_inputs_to_outputs( input_embeds: torch.Tensor, request_id: str, decode_steps: int = 1, - prompt_token_ids: Optional[list] = None, - generated_tokens: Optional[torch.Tensor] = None, + prompt_token_ids: list | None = None, + generated_tokens: torch.Tensor | None = None, current_step: int = 0 - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Process embeddings sequentially to generate text and ASR tokens. @@ -556,7 +555,7 @@ async def _process_inputs_to_outputs( return ans - def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'VllmLLMModel': + def to(self, device_or_dtype: torch.device | torch.dtype) -> 'VllmLLMModel': """ Move model to specified device or convert to specified dtype. @@ -611,7 +610,7 @@ def restart_request(self, request_id: str) -> bool: self.engine.start_generation(request_id=request_id) ) - def get_request_status(self, request_id: Optional[str] = None) -> Dict[str, Any]: + def get_request_status(self, request_id: str | None = None) -> dict[str, Any]: """ Get status of a specific request or all requests. 
@@ -637,7 +636,7 @@ def __del__(self): @dataclass class TTSGenerationResult: codes: torch.Tensor # Generated acoustic tokens - past_key_values: Optional[Any] # Updated cache (if applicable) + past_key_values: Any # Updated cache (if applicable) def __getitem__(self, item: str | int): """Allows for accessing attributes by key or index.""" @@ -676,9 +675,9 @@ def _convert_ckpt(self, save_path: str): def __call__( self, - inputs: Optional[Dict[str, torch.Tensor]] = None, - request_id: Optional[str] = None, - prompt_token_ids: Optional[list] = None, + inputs: dict[str, torch.Tensor] | None = None, + request_id: str | None = None, + prompt_token_ids: list | None = None, **kwargs ) -> TTSGenerationResult: """ @@ -729,10 +728,10 @@ def __call__( async def _process_inputs_to_outputs( self, - inputs: Dict[str, torch.Tensor], + inputs: dict[str, torch.Tensor], request_id: str, - prompt_token_ids: Optional[list] = None, - ) -> Dict[str, Any]: + prompt_token_ids: list | None = None, + ) -> dict[str, Any]: """ Process embeddings sequentially to generate text and ASR tokens. @@ -832,7 +831,7 @@ class NativeModel(ModelInterface): def __init__( self, model, - special_token_ids: Optional[Set[int]] = None, + special_token_ids: set[int] | None = None, top_p: float = 1.0, repetition_penalty: float = 1.0, temperature: float = 1.0, @@ -888,13 +887,13 @@ def __init__( def __call__( self, input_embeds: torch.Tensor, - cache: Optional[Any] = None, - cache_position: Optional[torch.Tensor] = None, - generated_tokens: Optional[torch.Tensor] = None, + cache: Any | None = None, + cache_position: torch.Tensor | None = None, + generated_tokens: torch.Tensor | None = None, current_step: int = 0, return_logits: bool = False, **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Perform inference using the native model. 
@@ -953,7 +952,7 @@ def __call__( return ans @staticmethod - def _extract_special_token_ids_from_nemo(model) -> Set[int]: + def _extract_special_token_ids_from_nemo(model) -> set[int]: """ Extract special token IDs from NeMo model's tokenizer. @@ -987,7 +986,7 @@ def _extract_special_token_ids_from_nemo(model) -> Set[int]: return special_ids - def to(self, device_or_dtype: Union[torch.device, torch.dtype]) -> 'NativeModel': + def to(self, device_or_dtype: torch.device | torch.dtype) -> 'NativeModel': """Move underlying model to device or convert dtype.""" self.model = self.model.to(device_or_dtype) return self @@ -1025,8 +1024,8 @@ def __getattr__(self, name: str): def create_model( model=None, engine_type: str = "native", - vllm_config: Optional[Dict[str, Any]] = None, - special_token_ids: Optional[Set[int]] = None, + vllm_config: dict[str, Any] | None = None, + special_token_ids: set[int] | None = None, top_p: float = 1.0, repetition_penalty: float = 1.0, temperature: float = 1.0, diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 1cec405f283a..a563529a5410 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -17,8 +17,6 @@ import os import time import types -from typing import Optional, Tuple - import torch import torchaudio from omegaconf import OmegaConf, DictConfig @@ -186,7 +184,7 @@ def __init__(self, model_cfg: DictConfig): "use_perception_cudagraph requires use_perception_cache to be enabled. " "Please also set use_perception_cache=True." 
) - self.perception_cache_mgr: Optional[PerceptionCacheManager] = None + self.perception_cache_mgr: PerceptionCacheManager | None = None self._use_perception_cudagraph = use_perception_cudagraph self._initialize_model() @@ -385,7 +383,7 @@ def _get_asr_bos_embedding(self) -> torch.Tensor: def _prepare_system_prompt_embeddings( self, system_prompt: str, - ) -> Tuple[Optional[torch.Tensor], int]: + ) -> tuple[torch.Tensor | None, int]: if not system_prompt or not system_prompt.strip(): return None, 0 @@ -546,7 +544,7 @@ def infer_one_step( num_frames_per_chunk: int, state: StreamingDecodeState, *, - request_id: Optional[str] = None, + request_id: str | None = None, has_prompt: bool = False, return_debug: bool = False, ) -> InferenceStepResult: @@ -854,7 +852,7 @@ def _decode_audio( state: StreamingDecodeState, frame_idx: int, num_frames_per_chunk: int, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """Decode accumulated TTS codes into a waveform. Returns the decoded audio tensor or *None* when ``decode_audio`` @@ -894,8 +892,8 @@ def _run_perception( audio_input: torch.Tensor, frame_idx: int, num_frames_per_chunk: int, - perception_cache: Optional[PerceptionCacheState], - ) -> Tuple[torch.Tensor, Optional[PerceptionCacheState]]: + perception_cache: PerceptionCacheState | None, + ) -> tuple[torch.Tensor, PerceptionCacheState | None]: """Run the perception encoder and return (source_encoded, updated_cache).""" start_perception = time.time() @@ -947,7 +945,7 @@ def _tokens_to_strings(self, token_ids: torch.Tensor) -> list[str]: result.append(_decode_tokens_with_specials(toks, self.tokenizer, keep_pad=False)) return result - def abort_request(self, request_id: Optional[str]) -> bool: + def abort_request(self, request_id: str | None) -> bool: """ Abort an in-flight vLLM streaming request if the backend supports it. 
""" diff --git a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py index 32fcc868e93a..f4fed6f69da5 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py @@ -23,7 +23,6 @@ import copy import time from dataclasses import dataclass -from typing import Optional, Tuple import torch from omegaconf import OmegaConf @@ -38,9 +37,9 @@ class PerceptionCacheState: Holds the cache tensors for the ASR encoder used in the perception module. This enables cache-aware streaming inference without needing the full audio buffer. """ - cache_last_channel: Optional[torch.Tensor] = None - cache_last_time: Optional[torch.Tensor] = None - cache_last_channel_len: Optional[torch.Tensor] = None + cache_last_channel: torch.Tensor | None = None + cache_last_time: torch.Tensor | None = None + cache_last_channel_len: torch.Tensor | None = None def is_initialized(self) -> bool: """Check if the cache has been initialized.""" @@ -55,34 +54,34 @@ class PerceptionCUDAGraphState: Also holds static buffers for inputs/outputs to enable graph replay. 
""" # CUDA graphs - graph_first: Optional[torch.cuda.CUDAGraph] = None - graph_subsequent: Optional[torch.cuda.CUDAGraph] = None + graph_first: torch.cuda.CUDAGraph | None = None + graph_subsequent: torch.cuda.CUDAGraph | None = None # Static input buffers (for copying data before graph replay) - static_mel_first: Optional[torch.Tensor] = None - static_mel_subsequent: Optional[torch.Tensor] = None - static_mel_len_first: Optional[torch.Tensor] = None - static_mel_len_subsequent: Optional[torch.Tensor] = None + static_mel_first: torch.Tensor | None = None + static_mel_subsequent: torch.Tensor | None = None + static_mel_len_first: torch.Tensor | None = None + static_mel_len_subsequent: torch.Tensor | None = None # Static cache input buffers - static_cache_channel_in: Optional[torch.Tensor] = None - static_cache_time_in: Optional[torch.Tensor] = None - static_cache_channel_len_in: Optional[torch.Tensor] = None + static_cache_channel_in: torch.Tensor | None = None + static_cache_time_in: torch.Tensor | None = None + static_cache_channel_len_in: torch.Tensor | None = None # Static output buffers (results are written here during replay) - static_encoded_first: Optional[torch.Tensor] = None - static_encoded_subsequent: Optional[torch.Tensor] = None - static_encoded_len_first: Optional[torch.Tensor] = None - static_encoded_len_subsequent: Optional[torch.Tensor] = None + static_encoded_first: torch.Tensor | None = None + static_encoded_subsequent: torch.Tensor | None = None + static_encoded_len_first: torch.Tensor | None = None + static_encoded_len_subsequent: torch.Tensor | None = None # Static cache output buffers - SEPARATE for first and subsequent graphs # (each graph writes to its own output tensors during replay) - static_cache_channel_out_first: Optional[torch.Tensor] = None - static_cache_time_out_first: Optional[torch.Tensor] = None - static_cache_channel_len_out_first: Optional[torch.Tensor] = None - static_cache_channel_out_subsequent: Optional[torch.Tensor] = 
None - static_cache_time_out_subsequent: Optional[torch.Tensor] = None - static_cache_channel_len_out_subsequent: Optional[torch.Tensor] = None + static_cache_channel_out_first: torch.Tensor | None = None + static_cache_time_out_first: torch.Tensor | None = None + static_cache_channel_len_out_first: torch.Tensor | None = None + static_cache_channel_out_subsequent: torch.Tensor | None = None + static_cache_time_out_subsequent: torch.Tensor | None = None + static_cache_channel_len_out_subsequent: torch.Tensor | None = None def is_captured(self) -> bool: """Check if graphs have been captured.""" @@ -108,7 +107,7 @@ def __init__(self, model, device: torch.device, dtype: torch.dtype, use_cudagrap self.subsampling_factor = None self.input_features = None self.sampling_frames = None - self.cudagraph_state: Optional[PerceptionCUDAGraphState] = None + self.cudagraph_state: PerceptionCUDAGraphState | None = None def setup(self) -> bool: """Setup cache-aware streaming for the perception encoder. @@ -330,7 +329,7 @@ def step( frame_idx: int, num_frames_per_chunk: int, perception_cache: PerceptionCacheState, - ) -> Tuple[torch.Tensor, PerceptionCacheState]: + ) -> tuple[torch.Tensor, PerceptionCacheState]: """ Perform cache-aware perception encoding for streaming inference. diff --git a/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py b/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py index 7fe2b7b6143b..fdd796f36029 100644 --- a/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py +++ b/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Any +from typing import Any class S2SPipelineInterface: @@ -28,7 +28,7 @@ class S2SPipelineInterface: def __init__(self) -> None: # Pool that holds per-stream state, keyed by ``stream_id`` - self._state_pool: Dict[int, Any] = {} + self._state_pool: dict[int, Any] = {} # ------------------------------------------------------------------ # State helpers diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 5bb97dcbe739..bd7ae118956d 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -18,7 +18,6 @@ import torch import librosa -from typing import List, Optional from torch import Tensor import soundfile as sf from omegaconf import DictConfig @@ -101,7 +100,7 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper # System prompt configuration # ------------------------------------------------------------------ s2s_cfg = cfg.get("s2s", {}) - self.system_prompt: Optional[str] = getattr(s2s_cfg, "system_prompt", None) + self.system_prompt: str | None = getattr(s2s_cfg, "system_prompt", None) if self.system_prompt: logging.info(f"System prompt configured: {self.system_prompt[:100]}{'...' 
if len(self.system_prompt) > 100 else ''}") @@ -152,7 +151,7 @@ def create_state(self) -> S2SStreamingState: # ------------------------------------------------------------------ # Output helpers # ------------------------------------------------------------------ - def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[bool], text_pieces: List[str], asr_text_pieces: List[str] = None): + def log_output(self, frames: list[Frame], audio_wave: Tensor, ready_feats: list[bool], text_pieces: list[str], asr_text_pieces: list[str] = None): """Append generated audio waveform and text to per-stream state.""" for idx, frame in enumerate(frames): if not ready_feats[idx]: @@ -177,7 +176,7 @@ def log_output(self, frames: List[Frame], audio_wave: Tensor, ready_feats: List[ state.append_step_output(sample_audio, text=piece, asr_text=asr_piece) - def inner_generate_step(self, frames: List[Frame], buffers: List[Tensor], ready_feats: List[bool]): + def inner_generate_step(self, frames: list[Frame], buffers: list[Tensor], ready_feats: list[bool]): """Generate speech for chunks in *batch* using a shared ContextManager.""" if len(frames) == 0: return @@ -337,7 +336,7 @@ def warmup(self, system_prompt: str | None = None) -> None: logging.info(f"Pipeline warmup complete in {time.time() - t0:.3f}s") - def generate_step(self, frames: List[Frame]): + def generate_step(self, frames: list[Frame]): """Main streaming API similar to *transcribe_step* in recognizers. 
If the batch contains a single zero-length first frame with a system @@ -371,8 +370,8 @@ def generate_step(self, frames: List[Frame]): # ------------------------------------------------------------------ def _finalize_and_save_finished_streams( self, - frames: List[Frame], - audio_filepaths: List[str], + frames: list[Frame], + audio_filepaths: list[str], saved_paths_by_stream: dict[int, str], ) -> None: """Finalize any streams that ended in this batch and save their outputs.""" @@ -471,9 +470,9 @@ def reset_session(self) -> None: # ------------------------------------------------------------------ def run( self, - audio_filepaths: List[str], - options: List[S2SRequestOptions] | None = None, - progress_bar: Optional[ProgressBar] = None, + audio_filepaths: list[str], + options: list[S2SRequestOptions] | None = None, + progress_bar: ProgressBar | None = None, ) -> PipelineOutput: """Stream all *audio_filepaths* through the pipeline and save outputs. @@ -519,7 +518,7 @@ def run( # run() helpers # ------------------------------------------------------------------ - def _maybe_prefill(self, frames: List[Frame]) -> None: + def _maybe_prefill(self, frames: list[Frame]) -> None: """If the first frame of a new stream carries a system prompt, emit a zero-length prefill frame through ``generate_step`` before inference begins. This is the unified prefill protocol used by both ``run()`` @@ -541,9 +540,9 @@ def _maybe_prefill(self, frames: List[Frame]) -> None: def _apply_padding( self, - frames: List[Frame], + frames: list[Frame], streamer: ContinuousBatchedFrameStreamer, - ) -> tuple[List[Frame], dict[int, float]]: + ) -> tuple[list[Frame], dict[int, float]]: """If padding is configured, intercept last frames so the bufferer and context stay alive for the silence-padding phase. 
Returns the (possibly modified) frames and a dict mapping ``stream_id`` to the @@ -576,7 +575,7 @@ def _generate_silence_padding( self, pad_targets: dict[int, float], chunk_samples: int, - audio_filepaths: List[str], + audio_filepaths: list[str], saved_paths_by_stream: dict[int, str], ) -> None: """Generate silence-padding frames for streams that need them. @@ -603,7 +602,7 @@ def _generate_silence_padding( def _build_pipeline_output( self, - audio_filepaths: List[str], + audio_filepaths: list[str], saved_paths_by_stream: dict[int, str], ) -> PipelineOutput: """Assemble final ``PipelineOutput`` from accumulated per-stream state. @@ -691,7 +690,7 @@ def _build_pipeline_output( debug_data=debug_data if debug_data else None, ) - def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = None) -> Optional[torch.Tensor]: + def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = None) -> torch.Tensor | None: """Prefill the system prompt for a new stream. This prepares the system prompt embeddings and processes them through @@ -711,7 +710,7 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non Using prefill OUTPUT codes causes audio quality issues (mumbling). Returns: - Optional[torch.Tensor]: The TTS prefill output codes if vLLM EarTTS prefill + torch.Tensor | None: The TTS prefill output codes if vLLM EarTTS prefill happened, None otherwise. These are returned for logging/debugging but should NOT be used to update context.tts_code. """ diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py index 9ea360f8ff1f..4bb0c6dc1e0b 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -13,8 +13,6 @@ # limitations under the License. 
from queue import Queue -from typing import Any, Dict, List, Optional, Tuple - import torch from nemo.utils import logging @@ -43,12 +41,12 @@ def __init__( def reset(self) -> None: """Reset all bookkeeping for a new streaming session.""" - self.streamidx2slotidx: Dict[int, int] = {} - self.slotidx2streamidx: Dict[int, int] = {} + self.streamidx2slotidx: dict[int, int] = {} + self.slotidx2streamidx: dict[int, int] = {} self.free_slots = Queue(self.num_slots) for i in range(self.num_slots): self.free_slots.put(i) - self.slot_contexts: List[Optional[StreamingDecodeState]] = [None] * self.num_slots + self.slot_contexts: list[StreamingDecodeState | None] = [None] * self.num_slots def _create_context(self) -> StreamingDecodeState: """Allocate a fresh context backed by the realtime inference model.""" @@ -91,7 +89,7 @@ def reset_slot(self, slot_idx: int) -> None: def update_context( self, - stream_ids: List[int], + stream_ids: list[int], step_result: InferenceStepResult, num_frames: int, ) -> None: @@ -130,7 +128,7 @@ def update_context( if context.subword_mask is not None: context.subword_mask[:, start_idx:end_idx] = True - def reset_slots(self, stream_ids: List[int], eos_flags: List[bool]) -> None: + def reset_slots(self, stream_ids: list[int], eos_flags: list[bool]) -> None: """Release contexts for streams that signalled end-of-stream.""" if len(stream_ids) != len(eos_flags): raise ValueError("stream_ids and eos_flags must have the same length") @@ -138,7 +136,7 @@ def reset_slots(self, stream_ids: List[int], eos_flags: List[bool]) -> None: if eos_flag and stream_id in self.streamidx2slotidx: self.reset_slot(self.streamidx2slotidx[stream_id]) - def get_context(self, stream_ids: List[int]) -> Tuple[StreamingDecodeState, Dict[int, int]]: + def get_context(self, stream_ids: list[int]) -> tuple[StreamingDecodeState, dict[int, int]]: """Return the cached context associated with the provided stream ids.""" if len(stream_ids) == 0: return self._create_context(), {} diff 
--git a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py index 76fd9aa8682c..a0fb2a9c5082 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py @@ -13,11 +13,9 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import List, Optional import torch from nemo.collections.asr.inference.utils.text_segment import Word -from nemo.utils import logging @dataclass @@ -47,14 +45,14 @@ class S2SStreamingState: # Accumulated ASR (user) text output_asr_text_str: str = "" # Word-level timings for the agent response - output_words: List[Word] = field(default_factory=list) + output_words: list[Word] = field(default_factory=list) # Snapshots of full token-ID tensors, saved from StreamingDecodeState # before the decode context is destroyed at end-of-stream. # Used for post-hoc tokens_to_str conversion. 
- final_gen_text: Optional[torch.Tensor] = None - final_gen_asr_text: Optional[torch.Tensor] = None - final_gen_function_text: Optional[torch.Tensor] = None + final_gen_text: torch.Tensor | None = None + final_gen_asr_text: torch.Tensor | None = None + final_gen_function_text: torch.Tensor | None = None final_total_frames: int = 0 def __post_init__(self) -> None: @@ -110,7 +108,7 @@ def save_token_tensors( gen_text: torch.Tensor, gen_asr_text: torch.Tensor, total_frames: int, - gen_function_text: Optional[torch.Tensor] = None, + gen_function_text: torch.Tensor | None = None, ) -> None: """Snapshot the full token-ID tensors from the decode context before it is destroyed.""" self.final_gen_text = gen_text[:, :total_frames].clone().cpu() @@ -121,7 +119,7 @@ def save_token_tensors( if gen_function_text is not None else None ) - def get_token_tensors(self) -> Optional[tuple]: + def get_token_tensors(self) -> tuple | None: """Return (gen_text, gen_asr_text, total_frames, gen_function_text) or None if not saved.""" if self.final_gen_text is None: return None diff --git a/nemo/collections/speechlm2/inference/utils/audio_data.py b/nemo/collections/speechlm2/inference/utils/audio_data.py index c73f6bdd6266..623ec36e9cab 100644 --- a/nemo/collections/speechlm2/inference/utils/audio_data.py +++ b/nemo/collections/speechlm2/inference/utils/audio_data.py @@ -14,11 +14,8 @@ """Audio data loading and output serialization for S2S inference scripts.""" -from __future__ import annotations - import json import os -from typing import List import soundfile as sf @@ -31,7 +28,7 @@ def prepare_audio_data( audio_file: str, default_system_prompt: str | None = None, sort_by_duration: bool = True, -) -> tuple[List[str], List[S2SRequestOptions], List[str | None]]: +) -> tuple[list[str], list[S2SRequestOptions], list[str | None]]: """Load audio filepaths and per-stream options from a folder, single file, or manifest. 
When the input is a JSON manifest, each line may contain:: @@ -48,8 +45,8 @@ def prepare_audio_data( if not os.path.isabs(audio_file): audio_file = os.path.abspath(audio_file) - options: List[S2SRequestOptions] = [] - ground_truths: List[str | None] = [] + options: list[S2SRequestOptions] = [] + ground_truths: list[str | None] = [] if os.path.isdir(audio_file): filepaths = [os.path.join(audio_file, x) for x in os.listdir(audio_file) if x.endswith(".wav")] @@ -86,7 +83,7 @@ def prepare_audio_data( return filepaths, options, ground_truths -def calculate_duration(audio_filepaths: List[str]) -> float: +def calculate_duration(audio_filepaths: list[str]) -> float: """Calculate total duration of the given audio files in seconds.""" total_dur = 0 for audio_filepath in audio_filepaths: @@ -96,7 +93,7 @@ def calculate_duration(audio_filepaths: List[str]) -> float: def calculate_padded_duration( - audio_filepaths: List[str], + audio_filepaths: list[str], pad_audio_to_sec: float | None = None, pad_silence_ratio: float | None = None, pad_audio_by_sec: float | None = None, @@ -118,11 +115,11 @@ def calculate_padded_duration( def dump_output( - audio_filepaths: List[str], + audio_filepaths: list[str], output: PipelineOutput, output_dir: str, - options: List[S2SRequestOptions], - ground_truths: List[str | None], + options: list[S2SRequestOptions], + ground_truths: list[str | None], ) -> None: """Dump inference results to output_processed.json, output_raw.json, and per-file CTM. diff --git a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py index 7dfee3b6ed0f..0425fe1cdd6b 100644 --- a/nemo/collections/speechlm2/inference/utils/pipeline_utils.py +++ b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py @@ -13,8 +13,6 @@ # limitations under the License. 
import re -from typing import List, Optional - import torch from whisper_normalizer.english import EnglishTextNormalizer @@ -50,19 +48,19 @@ class PipelineOutput: def __init__( self, - texts: Optional[List[str]] = None, - words: Optional[List[List[Word]]] = None, - asr_texts: Optional[List[str]] = None, - texts_with_timestamps: Optional[List[str]] = None, - asr_texts_with_timestamps: Optional[List[str]] = None, - raw_texts: Optional[List[str]] = None, - raw_asr_texts: Optional[List[str]] = None, - token_texts: Optional[List[torch.Tensor | None]] = None, - token_asr_texts: Optional[List[torch.Tensor | None]] = None, - token_function_texts: Optional[List[torch.Tensor | None]] = None, - token_lengths: Optional[List[int | None]] = None, - audio_filepaths: Optional[List[str | None]] = None, - debug_data: Optional[List[list]] = None, + texts: list[str] | None = None, + words: list[list[Word]] | None = None, + asr_texts: list[str] | None = None, + texts_with_timestamps: list[str] | None = None, + asr_texts_with_timestamps: list[str] | None = None, + raw_texts: list[str] | None = None, + raw_asr_texts: list[str] | None = None, + token_texts: list[torch.Tensor | None] | None = None, + token_asr_texts: list[torch.Tensor | None] | None = None, + token_function_texts: list[torch.Tensor | None] | None = None, + token_lengths: list[int | None] | None = None, + audio_filepaths: list[str | None] | None = None, + debug_data: list[list] | None = None, ): if texts is None and words is None: raise ValueError("At least one of the 'texts' or 'words' should be provided.") diff --git a/nemo/collections/speechlm2/inference/vllm/scripts/convert_nemotronllm_checkpoint.py b/nemo/collections/speechlm2/inference/vllm/scripts/convert_nemotronllm_checkpoint.py index 47d2463cebf0..e9f1c53c12e6 100644 --- a/nemo/collections/speechlm2/inference/vllm/scripts/convert_nemotronllm_checkpoint.py +++ b/nemo/collections/speechlm2/inference/vllm/scripts/convert_nemotronllm_checkpoint.py @@ -30,15 +30,13 @@ 
import json import os from pathlib import Path -from typing import Dict, List, Optional - import torch from safetensors.torch import load_file, save_file from transformers import AutoConfig, AutoTokenizer from nemo.utils import logging -def load_checkpoint(checkpoint_path: str) -> Dict[str, torch.Tensor]: +def load_checkpoint(checkpoint_path: str) -> dict[str, torch.Tensor]: """ Load checkpoint from safetensors or PyTorch format. @@ -67,9 +65,9 @@ def load_checkpoint(checkpoint_path: str) -> Dict[str, torch.Tensor]: def filter_tensors( - state_dict: Dict[str, torch.Tensor], - prefixes_to_keep: List[str] -) -> Dict[str, torch.Tensor]: + state_dict: dict[str, torch.Tensor], + prefixes_to_keep: list[str] +) -> dict[str, torch.Tensor]: """ Filter tensors to keep only those with specified prefixes. @@ -95,9 +93,9 @@ def filter_tensors( def convert_nemo_to_hf_format( checkpoint_path: str, output_dir: str, - config_path: Optional[str] = None, - pretrained_llm: Optional[str] = None, - tensors_to_keep: Optional[List[str]] = None, + config_path: str | None = None, + pretrained_llm: str | None = None, + tensors_to_keep: list[str] | None = None, dtype: str = "float32", ) -> None: """ diff --git a/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py b/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py index 8384dd112ef7..431714b79fb7 100644 --- a/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py +++ b/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py @@ -21,14 +21,13 @@ import json import torch import asyncio -from typing import Optional, Dict, Any, AsyncGenerator, Tuple +from typing import Any, AsyncGenerator from dataclasses import dataclass from enum import Enum from vllm.v1.engine.async_llm import AsyncLLM from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.config.model import CustomInputSpec from vllm.attention.selector import _cached_get_attn_backend from nemo.utils import 
logging @@ -44,8 +43,8 @@ class StreamStatus(Enum): class GenerationResult: token_id: int is_finished: bool - custom_outputs: Optional[Dict[str, torch.Tensor]] = None - finish_reason: Optional[str] = None + custom_outputs: dict[str, torch.Tensor] | None = None + finish_reason: str | None = None total_tokens: int = 0 @@ -55,7 +54,7 @@ class RequestState: request_id: str status: StreamStatus generated_tokens: list - generation_iterator: Optional[AsyncGenerator] = None + generation_iterator: AsyncGenerator | None = None class LLMStreamingEngine: @@ -96,10 +95,10 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init # Engine state - self.engine: Optional[AsyncLLM] = None + self.engine: AsyncLLM | None = None # Request state tracking - supports multiple concurrent requests - self.requests: Dict[str, RequestState] = {} + self.requests: dict[str, RequestState] = {} # Default sampling parameters default_sampling = { @@ -197,8 +196,8 @@ async def start_generation( return True async def generate_next_token(self, input_tensors: list[torch.Tensor], - prompt_token_ids: Optional[list[int]] = None, - request_id: str = "speech_stream") -> Optional[GenerationResult]: + prompt_token_ids: list[int] | None = None, + request_id: str = "speech_stream") -> GenerationResult | None: """ Generate the next token using the provided input embedding. @@ -369,7 +368,7 @@ async def shutdown(self): # Clear all request states self.requests.clear() - def get_status(self, request_id: Optional[str] = None) -> Dict[str, Any]: + def get_status(self, request_id: str | None = None) -> dict[str, Any]: """Get current status information. 
Args: diff --git a/nemo/collections/speechlm2/models/duplex_stt_model.py b/nemo/collections/speechlm2/models/duplex_stt_model.py index 3e9607d15456..4d7c322036eb 100644 --- a/nemo/collections/speechlm2/models/duplex_stt_model.py +++ b/nemo/collections/speechlm2/models/duplex_stt_model.py @@ -16,8 +16,6 @@ import re import warnings from pathlib import Path -from typing import Optional, Union - import torch from lightning import LightningModule from omegaconf import DictConfig @@ -138,9 +136,9 @@ def __init__(self, cfg: dict) -> None: def save_pretrained( self, - save_directory: Union[str, Path], + save_directory: str | Path, **kwargs, - ) -> Optional[str]: + ) -> str | None: """Save model and also export LLM artifacts (config + tokenizer) for offline inference.""" result = super().save_pretrained(save_directory, **kwargs) diff --git a/nemo/collections/speechlm2/models/nemotron_voicechat.py b/nemo/collections/speechlm2/models/nemotron_voicechat.py index 0e9e7d2d8d5b..fea99f9f39d8 100644 --- a/nemo/collections/speechlm2/models/nemotron_voicechat.py +++ b/nemo/collections/speechlm2/models/nemotron_voicechat.py @@ -16,8 +16,6 @@ import os import warnings from pathlib import Path -from typing import Optional, Union - import torch from huggingface_hub import CONFIG_NAME from lightning import LightningModule @@ -137,9 +135,9 @@ def __init__(self, cfg: dict) -> None: def save_pretrained( self, - save_directory: Union[str, Path], + save_directory: str | Path, **kwargs, - ) -> Optional[str]: + ) -> str | None: """Save model and export LLM artifacts (tokenizer + perception config) for offline inference.""" result = super().save_pretrained(save_directory, **kwargs) @@ -181,13 +179,13 @@ def _from_pretrained( cls, *, model_id: str, - revision: Optional[str], - cache_dir: Optional[Union[str, Path]], + revision: str | None, + cache_dir: str | Path | None, force_download: bool, - proxies: Optional[dict], - resume_download: Optional[bool], + proxies: dict | None, + resume_download: 
bool | None, local_files_only: bool, - token: Union[str, bool, None], + token: str | bool | None, map_location: str = "cpu", strict: bool = False, **model_kwargs, From c3c0d7ef033c2d18dcd6018661aa554a73be88da Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 1 Apr 2026 21:54:08 +0000 Subject: [PATCH 30/40] use nemo_asr.metrics.wer.word_error_rate for wer calc Signed-off-by: Elena Rastorgueva --- .../s2s_streaming_infer.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py index b521a5fc318f..18a29bfe2d18 100644 --- a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -27,10 +27,9 @@ from time import time import hydra -import torch -from jiwer import wer as compute_wer from omegaconf import DictConfig +from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder from nemo.collections.speechlm2.inference.utils.audio_data import ( calculate_duration, @@ -68,22 +67,20 @@ def main(cfg: DictConfig): rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf') logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)") - # Compute WER when ground-truth texts are available + # Compute WER when ground-truth texts are available (micro-average, + # matching the offline eval in speechlm2.parts.metrics.asr_cer_wer) asr_texts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths) - wer_scores = [] + all_refs, all_hyps = [], [] for gt, asr_text in zip(ground_truths, asr_texts): if gt and asr_text: cleaned_gt = clean_pred_text(gt) cleaned_pred = clean_pred_text(asr_text) if cleaned_gt.strip() and cleaned_pred.strip(): - wer_scores.append(compute_wer(cleaned_gt, cleaned_pred)) - if wer_scores: - 
avg_wer = sum(wer_scores) / len(wer_scores) - logging.info( - f"WER: avg={avg_wer:.4f} ({avg_wer * 100:.2f}%), " - f"n={len(wer_scores)}, " - f"min={min(wer_scores):.4f}, max={max(wer_scores):.4f}" - ) + all_refs.append(cleaned_gt) + all_hyps.append(cleaned_pred) + if all_refs: + wer = word_error_rate(hypotheses=all_hyps, references=all_refs) + logging.info(f"WER: {wer:.4f} ({wer * 100:.2f}%), n={len(all_refs)}") output_dir = cfg.get("output_dir", "./generated") dump_output(audio_filepaths, output, output_dir, options, ground_truths) From 770efb48d2e819d8e54cfaad69e037afc21b0d35 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Wed, 1 Apr 2026 22:05:19 +0000 Subject: [PATCH 31/40] use SimpleTimer in s2s_streaming_infer.py script Signed-off-by: Elena Rastorgueva --- .../nemo_inference_pipelines/s2s_streaming_infer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py index 18a29bfe2d18..46e5394588ee 100644 --- a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -24,8 +24,6 @@ streaming.buffer_size_in_secs=5.6 """ -from time import time - import hydra from omegaconf import DictConfig @@ -39,6 +37,7 @@ ) from nemo.collections.speechlm2.inference.utils.pipeline_utils import clean_pred_text from nemo.utils import logging +from nemo.utils.timers import SimpleTimer @hydra.main(config_path="./conf", config_name="s2s_streaming", version_base=None) @@ -51,9 +50,11 @@ def main(cfg: DictConfig): pipeline = S2SPipelineBuilder.build_pipeline(cfg) - start = time() + timer = SimpleTimer() + timer.start() output = pipeline.run(audio_filepaths, options=options) - exec_dur = time() - start + timer.stop() + exec_dur = timer.total_sec() logging.info(f"Generated {len(audio_filepaths)} files in {exec_dur:.2f}s") # Log RTFX (accounts 
for padding when configured) From 89f818f6d61a6cfcd04cca89be90fed0044b9d01 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 3 Apr 2026 02:15:32 +0000 Subject: [PATCH 32/40] simplify flow for prefill and use per-stream options, including sampling params Signed-off-by: Elena Rastorgueva --- docs/source/speechlm2/streaming_inference.rst | 87 +++++- .../inference/model_wrappers/model_factory.py | 40 ++- .../nemotron_voicechat_inference_wrapper.py | 18 ++ .../pipelines/s2s_pipeline_interface.py | 15 +- .../pipelines/streaming_s2s_pipeline.py | 261 ++++++++++-------- .../streaming/framing/s2s_request_options.py | 51 +++- .../inference/streaming/state/s2s_state.py | 6 + 7 files changed, 330 insertions(+), 148 deletions(-) diff --git a/docs/source/speechlm2/streaming_inference.rst b/docs/source/speechlm2/streaming_inference.rst index 2313cd553315..d23e53cd1dd8 100644 --- a/docs/source/speechlm2/streaming_inference.rst +++ b/docs/source/speechlm2/streaming_inference.rst @@ -97,11 +97,44 @@ over chunks and calls a single step method: What Happens Inside One Step ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. code-block:: text + + generate_step(frames) + │ + ├─ for each frame where is_first=True: + │ │ + │ └─ _init_state(stream_id, options) + │ 1. augment_with_defaults() ← fill None fields from YAML + │ 2. create_state(options) ← pipeline-level state + │ 3. create context_manager ← KV caches, decode slots + │ 4. prefill system prompt ← populate LLM KV cache + │ + ├─ any frames with audio? + │ │ + │ NO → return (server prefill-only request) + │ │ + │ YES ↓ + │ + └─ generate_step_for_frames() + 1. audio buffering + 2. perception encoder + 3. per-frame LLM loop + 4. per-frame TTS + 5. codec decode + 6. state updates + output accumulation + Each call to ``generate_step(frames)`` performs: -1. 
**Prefill detection** -- A zero-length first frame with a system prompt - triggers ``prefill_for_new_stream()``, which initializes the LLM KV cache - with the system prompt and the TTS speaker embedding. +1. **Stream init on** ``is_first`` -- If a frame has ``is_first=True``, the + private ``_init_state()`` method runs: per-stream options are merged with + pipeline defaults (via ``S2SRequestOptions.augment_with_defaults()``), + a fresh ``S2SStreamingState`` is created, the context manager is + allocated, and the LLM KV cache is prefilled with the system prompt and + TTS speaker embedding. This mirrors ASR's ``init_state()`` called inside + ``transcribe_step()``. If the frame carries no audio (zero-length + samples), the method returns after init — this is the recommended + pattern for latency-sensitive server deployments (see + :ref:`init-and-latency` below). 2. **Audio buffering** -- ``BatchedAudioBufferer`` (reused from ASR infrastructure) maintains a sliding window of ``buffer_size_in_secs``. @@ -188,6 +221,9 @@ S2S Model Settings (``s2s``) * - ``temperature`` - ``0.3`` - Sampling temperature. + * - ``repetition_penalty`` + - ``1.1`` + - Repetition penalty applied to previously generated tokens. * - ``deterministic`` - ``false`` - Force deterministic mode (native engine only). @@ -246,28 +282,59 @@ Server Integration ------------------ The same ``generate_step()`` method used by ``run()`` can be called directly -from a custom server. The zero-length Frame protocol handles prefill: +from a custom server: .. code-block:: python - # 1. Prefill system prompt (zero-length frame) - prefill_frame = Frame( + # 1. 
Init stream (empty audio so prefill completes before recording) + init_frame = Frame( samples=torch.empty(0), stream_id=stream_id, is_first=True, is_last=False, - options=S2SRequestOptions(system_prompt=prompt), + options=S2SRequestOptions(system_prompt=prompt, top_p=0.9), ) - pipeline.generate_step([prefill_frame]) + pipeline.generate_step([init_frame]) + # -> client can now start recording # 2. Stream audio chunks - for chunk in audio_source: + for i, chunk in enumerate(audio_source): frame = Frame( samples=chunk, stream_id=stream_id, - is_first=(i == 0), is_last=(i == last), + is_first=False, is_last=(i == last), ) pipeline.generate_step([frame]) +Per-stream options (``system_prompt``, ``top_p``, ``temperature``, +``repetition_penalty``) are attached to the ``is_first`` frame via +``S2SRequestOptions``. Any field left as ``None`` falls back to the +pipeline-level YAML default through ``augment_with_defaults()``. + +.. _init-and-latency: + +Init and Latency +^^^^^^^^^^^^^^^^ + +When ``generate_step`` sees ``is_first``, it always runs stream +initialization (context creation, KV-cache prefill). If the frame also +carries audio, inference runs immediately after init in the same call. + +For **latency-sensitive** server deployments (real-time voice chat), +prefill can take hundreds of milliseconds or even multiple seconds. +Clients should send ``is_first`` with **empty audio**, wait for the +response confirming init is done, and only then start recording the +user's microphone. This prevents audio from queuing up during the +expensive prefill phase. + +For **batch/offline** usage (CLI ``run()``), there is no real-time +constraint. The first frame carries both ``is_first`` and real audio, +so init and first-chunk processing happen in one call with no extra +round-trip. + +The pipeline makes no distinction between these cases — it initializes +on ``is_first`` and processes whatever audio is present. The latency +trade-off is entirely the caller's choice. 
+ Batch Size ---------- diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py index 1fcb4517390e..1b9399745e80 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py @@ -91,6 +91,7 @@ def _sample_text_token( logits: torch.Tensor, generated_tokens: torch.Tensor, current_step: int, + sampling_params: dict[str, float] | None = None, ) -> torch.Tensor: """ Sample text tokens with optional top-p sampling and repetition penalty. @@ -100,10 +101,17 @@ def _sample_text_token( logits: Logits tensor of shape (B, V) for vocabulary V. generated_tokens: Previously generated tokens of shape (B, T). current_step: Current decoding step (used to slice generated_tokens). + sampling_params: Optional per-request overrides for ``top_p``, + ``temperature``, and ``repetition_penalty``. Missing keys + fall back to ``self.*`` (the pipeline-level defaults). Returns: Sampled token ids of shape (B,). 
""" + top_p = sampling_params.get("top_p", self.top_p) if sampling_params else self.top_p + temperature = sampling_params.get("temperature", self.temperature) if sampling_params else self.temperature + rep_penalty = sampling_params.get("repetition_penalty", self.repetition_penalty) if sampling_params else self.repetition_penalty + B, V = logits.shape device = logits.device @@ -111,11 +119,11 @@ def _sample_text_token( greedy_tokens = logits.argmax(dim=-1) # (B,) # If no sampling needed (all disabled), return greedy - if self.top_p >= 1.0 and self.repetition_penalty == 1.0 and (self.temperature == 1.0 or self.temperature == 0.0): + if top_p >= 1.0 and rep_penalty == 1.0 and (temperature == 1.0 or temperature == 0.0): return greedy_tokens # temperature=0 means greedy - if self.temperature == 0.0: + if temperature == 0.0: return greedy_tokens # For each batch, if greedy is special token, use greedy; otherwise sample @@ -134,7 +142,7 @@ def _sample_text_token( batch_logits = logits[b].clone() # (V,) # Apply repetition penalty (vectorized, no Python loop) - if self.repetition_penalty != 1.0 and current_step > 0: + if rep_penalty != 1.0 and current_step > 0: unique_prev = generated_tokens[b, :current_step].unique() # Exclude special tokens from penalty if self._special_ids_tensor is not None: @@ -151,13 +159,13 @@ def _sample_text_token( # (same as the standard repetition_penalty convention) batch_logits[unique_prev] = torch.where( prev_logits > 0, - prev_logits / self.repetition_penalty, - prev_logits * self.repetition_penalty, + prev_logits / rep_penalty, + prev_logits * rep_penalty, ) # Apply temperature scaling - if self.temperature != 1.0: - batch_logits = batch_logits / self.temperature + if temperature != 1.0: + batch_logits = batch_logits / temperature # Fall back to greedy if logits are non-finite before top-p # (top-p intentionally introduces -inf, so check must happen first) @@ -173,13 +181,13 @@ def _sample_text_token( continue # Apply top-p (nucleus) 
sampling - if self.top_p < 1.0: + if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True) sorted_probs = torch.softmax(sorted_logits, dim=-1) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # Remove tokens with cumulative prob > top_p, keeping at least one - sorted_indices_to_remove = cumulative_probs > self.top_p + sorted_indices_to_remove = cumulative_probs > top_p # Shift to keep the first token that exceeds threshold sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() sorted_indices_to_remove[0] = False @@ -476,7 +484,8 @@ async def _process_inputs_to_outputs( decode_steps: int = 1, prompt_token_ids: list | None = None, generated_tokens: torch.Tensor | None = None, - current_step: int = 0 + current_step: int = 0, + sampling_params: dict[str, float] | None = None, ) -> dict[str, Any]: """ Process embeddings sequentially to generate text and ASR tokens. @@ -489,6 +498,7 @@ async def _process_inputs_to_outputs( generated_tokens: Previously generated tokens [batch, num_generated]. Required for repetition_penalty. If None, creates empty tensor. current_step: Current decoding step. Used for repetition penalty. + sampling_params: Optional per-request overrides for sampling. 
""" if decode_steps == 0: @@ -528,7 +538,10 @@ async def _process_inputs_to_outputs( text_logits = result.custom_outputs["text_logits"] if result else None predicted_token = text_token_ids[-1] - if self.top_p < 1.0 or self.repetition_penalty != 1.0 or (self.temperature != 1.0 and self.temperature != 0.0): + eff_top_p = sampling_params.get("top_p", self.top_p) if sampling_params else self.top_p + eff_temp = sampling_params.get("temperature", self.temperature) if sampling_params else self.temperature + eff_rep = sampling_params.get("repetition_penalty", self.repetition_penalty) if sampling_params else self.repetition_penalty + if eff_top_p < 1.0 or eff_rep != 1.0 or (eff_temp != 1.0 and eff_temp != 0.0): # Use provided generated_tokens or create empty tensor batch_size = text_logits.shape[0] if generated_tokens is None: @@ -541,6 +554,7 @@ async def _process_inputs_to_outputs( logits=text_logits, generated_tokens=gen_tokens, current_step=current_step, + sampling_params=sampling_params, ) ans = { @@ -892,6 +906,7 @@ def __call__( generated_tokens: torch.Tensor | None = None, current_step: int = 0, return_logits: bool = False, + sampling_params: dict[str, float] | None = None, **kwargs ) -> dict[str, Any]: """ @@ -904,6 +919,8 @@ def __call__( generated_tokens: Previously generated tokens [batch, num_generated]. Required for repetition_penalty. If None, creates empty tensor. current_step: Current decoding step. Used for repetition penalty. + sampling_params: Optional per-request overrides for sampling + (top_p, temperature, repetition_penalty). 
**kwargs: Additional arguments passed to the model Returns: @@ -932,6 +949,7 @@ def __call__( logits=text_logits, generated_tokens=gen_tokens, current_step=current_step, + sampling_params=sampling_params, ) # ASR tokens use greedy decoding (no sampling) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index a563529a5410..12703265af26 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -547,12 +547,25 @@ def infer_one_step( request_id: str | None = None, has_prompt: bool = False, return_debug: bool = False, + sampling_params: dict[str, float] | None = None, ) -> InferenceStepResult: """Run one streaming inference step: perception -> LLM -> TTS -> audio decode. All mutable decode state (caches, gen_text, gen_asr_text, code, etc.) is updated **in-place** on *state*. The returned :class:`InferenceStepResult` carries only per-step outputs needed by the pipeline. + + Args: + audio_input (torch.Tensor): Raw audio tensor for this chunk, shape ``(1, samples)``. + num_frames_per_chunk (int): Number of 80 ms frames in this chunk. + state (StreamingDecodeState): Mutable decode state (KV caches, token workspaces, etc.). + request_id (str | None): Unique ID for this stream (used by vLLM engines). + has_prompt (bool): Whether the LLM KV cache already contains a prefilled + system prompt. Affects the first-frame embedding (PAD vs BOS). + return_debug (bool): If True, attach per-step debug info to the result. + sampling_params (dict[str, float] | None): Optional per-stream sampling overrides + (``top_p``, ``temperature``, ``repetition_penalty``). + Keys that are absent fall back to the pipeline-level defaults. 
""" effective_request_id = request_id or self.request_id frame_idx = state.frame_idx @@ -603,6 +616,7 @@ def infer_one_step( ans = self._run_llm_step( input_emb, state, frame_offset, effective_request_id, current_frame_idx, use_llm_cache, return_debug, new_input_embeds, + sampling_params=sampling_params, ) if return_debug and "text_logits" in ans: @@ -741,6 +755,7 @@ def _run_llm_step( use_llm_cache: bool, return_debug: bool, new_input_embeds: list, + sampling_params: dict[str, float] | None = None, ) -> dict: """Run one LLM forward pass (native cache, vLLM, or full-history). @@ -756,6 +771,7 @@ def _run_llm_step( request_id=request_id, generated_tokens=state.gen_text, current_step=current_frame_idx, + sampling_params=sampling_params, ) else: cache_pos = torch.tensor( @@ -768,6 +784,7 @@ def _run_llm_step( generated_tokens=state.gen_text, current_step=current_frame_idx, return_logits=return_debug, + sampling_params=sampling_params, ) state.llm_cache = ans["cache"] else: @@ -779,6 +796,7 @@ def _run_llm_step( generated_tokens=state.gen_text, current_step=current_frame_idx, return_logits=return_debug, + sampling_params=sampling_params, ) if self._profile_timing: diff --git a/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py b/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py index fdd796f36029..8948739f205b 100644 --- a/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py +++ b/nemo/collections/speechlm2/inference/pipelines/s2s_pipeline_interface.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions + class S2SPipelineInterface: """Base class for all streaming S2S pipelines. 
@@ -42,17 +46,22 @@ def delete_state(self, stream_id: int) -> None: if stream_id in self._state_pool: del self._state_pool[stream_id] - def create_state(self): # noqa: D401 (keep same signature as recognizers) + def create_state(self, options: S2SRequestOptions | None = None): """Create and return a *new*, *empty* state object. + Args: + options: Per-stream request options (system prompt, sampling + overrides, etc.). Stored on the state so they can be + consulted throughout the stream's lifetime. + Must be implemented by concrete pipelines. """ raise NotImplementedError("`create_state()` must be implemented in a subclass.") - def get_or_create_state(self, stream_id: int): + def get_or_create_state(self, stream_id: int, options: S2SRequestOptions | None = None): """Return existing state for *stream_id* or create a new one via :py:meth:`create_state`.""" if stream_id not in self._state_pool: - self._state_pool[stream_id] = self.create_state() + self._state_pool[stream_id] = self.create_state(options) return self._state_pool[stream_id] # ------------------------------------------------------------------ diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index bd7ae118956d..c473e5713989 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -97,13 +97,17 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper ) # ------------------------------------------------------------------ - # System prompt configuration + # System prompt & sampling defaults (from YAML s2s block) # ------------------------------------------------------------------ s2s_cfg = cfg.get("s2s", {}) self.system_prompt: str | None = getattr(s2s_cfg, "system_prompt", None) if self.system_prompt: logging.info(f"System prompt configured: 
{self.system_prompt[:100]}{'...' if len(self.system_prompt) > 100 else ''}") + self._default_top_p: float | None = getattr(s2s_cfg, "top_p", None) + self._default_temperature: float | None = getattr(s2s_cfg, "temperature", None) + self._default_repetition_penalty: float | None = getattr(s2s_cfg, "repetition_penalty", None) + # Context manager self.context_manager = S2SContextManager( s2s_model=self.s2s_model, @@ -138,25 +142,79 @@ def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper # ------------------------------------------------------------------ # State helpers # ------------------------------------------------------------------ - def create_state(self) -> S2SStreamingState: - """Create new empty state.""" + def create_state(self, options: S2SRequestOptions | None = None) -> S2SStreamingState: + """Create new empty state with optional per-stream options.""" dtype = getattr(self.s2s_model, "dtype", torch.float32) return S2SStreamingState( device=self.device, dtype=dtype, output_sample_rate=self.output_sample_rate, + options=options or S2SRequestOptions(), + ) + + def _init_state(self, stream_id: int, options: S2SRequestOptions | None = None) -> None: + """Initialize a new stream: resolve defaults, create state, create context, prefill. + + This is the S2S equivalent of ASR's ``init_state()`` in ``BasePipeline``. + Called automatically by :meth:`generate_step` when a frame has + ``is_first=True``. + + The method always runs stream initialization (state creation, + context-manager allocation, KV-cache prefill). If the triggering + frame also carries audio, :meth:`generate_step` will process it + immediately after this method returns. For latency-sensitive + deployments (real-time voice chat), callers should send the first + frame with **empty audio** so that prefill completes before the + user starts speaking — this prevents audio from queuing up during + the expensive prefill phase. 
+ """ + # Normalize empty-string prompts to None so augment_with_defaults + # fills in the YAML default instead of treating "" as "set". + raw_opts = options or S2SRequestOptions() + if raw_opts.system_prompt is not None and not raw_opts.system_prompt.strip(): + raw_opts = S2SRequestOptions( + system_prompt=None, + top_p=raw_opts.top_p, + temperature=raw_opts.temperature, + repetition_penalty=raw_opts.repetition_penalty, + ) + opts = raw_opts.augment_with_defaults( + default_system_prompt=self.system_prompt, + default_top_p=self._default_top_p, + default_temperature=self._default_temperature, + default_repetition_penalty=self._default_repetition_penalty, + ) + self.get_or_create_state(stream_id, options=opts) + + self.context_manager = S2SContextManager( + s2s_model=self.s2s_model, + num_slots=self.batch_size, + max_len=self.max_len, ) + # Prefill can take hundreds of ms, or even tens of seconds if the + # prompt is long and the model is not warmed up. + prompt = opts.system_prompt + start_prefill = time.time() + with torch.no_grad(), torch.inference_mode(): + self._prefill_system_prompt(stream_id, prompt) + torch.cuda.synchronize() + logging.info(f"_init_state: stream_id={stream_id}, prefill={1000*(time.time()-start_prefill):.1f}ms") + + # Will tell generate_step_for_frames whether the KV cache already contains + # a system prompt, so it can choose the right first-frame embedding + # (PAD tokens if prefilled, BOS tokens if not). Consumed and + # cleared on the first audio frame. 
+ self._stream_has_prompt = bool(prompt) + # ------------------------------------------------------------------ # Output helpers # ------------------------------------------------------------------ - def log_output(self, frames: list[Frame], audio_wave: Tensor, ready_feats: list[bool], text_pieces: list[str], asr_text_pieces: list[str] = None): + def log_output(self, frames: list[Frame], audio_wave: Tensor, text_pieces: list[str], asr_text_pieces: list[str] = None): """Append generated audio waveform and text to per-stream state.""" for idx, frame in enumerate(frames): - if not ready_feats[idx]: - continue - state = self.get_or_create_state(frame.stream_id) + state = self.get_state(frame.stream_id) # audio_wave is [B, S]; take sample idx (None when decode_audio=False) sample_audio = audio_wave[idx:idx+1, ...] if audio_wave is not None else None # Determine text piece for this index @@ -176,40 +234,34 @@ def log_output(self, frames: list[Frame], audio_wave: Tensor, ready_feats: list[ state.append_step_output(sample_audio, text=piece, asr_text=asr_piece) - def inner_generate_step(self, frames: list[Frame], buffers: list[Tensor], ready_feats: list[bool]): - """Generate speech for chunks in *batch* using a shared ContextManager.""" + def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]): + """Generate speech for audio Frames using a shared ContextManager. + + This is the S2S equivalent of ASR's ``transcribe_step_for_frames`` + in ``BasePipeline``. Like its ASR counterpart, it is never called + directly — :meth:`generate_step` (the public API, analogous to + ``transcribe_step``) handles stream init and then delegates here + for the actual audio processing. + + Stream initialization (state, context, prefill) is always handled + by :meth:`_init_state` *before* this method is called. 
+ """ if len(frames) == 0: return stream_ids = [f.stream_id for f in frames] eos_flags = [f.is_last for f in frames] - bos_flags = [f.is_first for f in frames] - logging.debug(f"stream_ids={stream_ids} bos_flags={bos_flags} eos_flags={eos_flags}") + logging.debug(f"stream_ids={stream_ids} eos_flags={eos_flags}") if len(frames) != 1: raise NotImplementedError("NemotronVoicechatInferenceWrapper currently supports batch_size == 1") - # If this is the first audio frame and prefill was already done via a - # zero-length prefill frame, skip context init -- it's already set up. - # Otherwise (no system prompt), create a fresh context_manager. - has_prompt = False - if bos_flags[0]: - if self._stream_has_prompt: - logging.debug(f"Prefill already done for stream {stream_ids[0]}, skipping context init") - else: - logging.debug(f"No prefill for stream {stream_ids[0]}, creating fresh context_manager") - self.context_manager = S2SContextManager( - s2s_model=self.s2s_model, - num_slots=self.batch_size, - max_len=self.max_len, - ) - has_prompt = self._stream_has_prompt self._stream_has_prompt = False - + request_id = self._request_id_for_stream(stream_ids[0]) - + context, _ = self.context_manager.get_context(stream_ids) audio_buffer = buffers[0] @@ -217,6 +269,20 @@ def inner_generate_step(self, frames: list[Frame], buffers: list[Tensor], ready_ audio_buffer = audio_buffer.unsqueeze(0) audio_buffer = audio_buffer.to(self.s2s_model.device, dtype=self.s2s_model.dtype) + # Sampling overrides were resolved by _init_state via augment_with_defaults + # and stored on state.options. Build the dict for infer_one_step. + pipeline_state = self.get_state(stream_ids[0]) + if pipeline_state is None: + raise RuntimeError( + f"No state initialized for stream {stream_ids[0]}. " + "Clients must send an is_first=True frame before streaming audio." 
+ ) + sampling_params = { + k: getattr(pipeline_state.options, k) + for k in ("top_p", "temperature", "repetition_penalty") + if getattr(pipeline_state.options, k, None) is not None + } + result = self.s2s_model.infer_one_step( audio_input=audio_buffer, num_frames_per_chunk=self.num_frames_per_chunk, @@ -224,6 +290,7 @@ def inner_generate_step(self, frames: list[Frame], buffers: list[Tensor], ready_ request_id=request_id, has_prompt=has_prompt, return_debug=self.collect_debug, + sampling_params=sampling_params or None, ) if self.collect_debug and result.debug is not None: @@ -259,57 +326,19 @@ def inner_generate_step(self, frames: list[Frame], buffers: list[Tensor], ready_ # It will be cleaned up in close_session() # Log audio and attach text to state - self.log_output(frames, result.decoded_audio, ready_feats, result.predicted_text_strs, result.asr_predicted_text_strs) - - def prefill_for_new_stream(self, stream_id: int, system_prompt: str | None = None) -> bool: - """Prepare the pipeline for a new stream by resetting context and prefilling the system prompt. - - This is the public API for prefill-only calls (e.g. from a server backend) - that need to initialize TTS speaker embeddings and/or inject a system prompt - into the LLM KV cache *without* processing any audio. - - Args: - stream_id: Unique identifier for the new stream. - system_prompt: System prompt text. If *None*, falls back to - the YAML-configured ``self.system_prompt``. - - Returns: - True if a system prompt was prefilled, False otherwise. 
- """ - t0 = time.time() - if system_prompt is None: - system_prompt = self.system_prompt - - self.context_manager = S2SContextManager( - s2s_model=self.s2s_model, - num_slots=self.batch_size, - max_len=self.max_len, - ) - t_ctx = time.time() - - with torch.no_grad(), torch.inference_mode(): - self._prefill_system_prompt(stream_id, system_prompt) - t_prefill = time.time() - - self._stream_has_prompt = bool(system_prompt) - logging.debug(f"prefill_for_new_stream: context_manager={1000*(t_ctx-t0):.1f}ms, " - f"_prefill_system_prompt={1000*(t_prefill-t_ctx):.1f}ms, " - f"total={1000*(t_prefill-t0):.1f}ms, has_prompt={self._stream_has_prompt}") - return self._stream_has_prompt + self.log_output(frames, result.decoded_audio, result.predicted_text_strs, result.asr_predicted_text_strs) _WARMUP_FALLBACK_PROMPT = "Mock system prompt for warmup." def warmup(self, system_prompt: str | None = None) -> None: - """Run a throwaway prefill cycle to warm up the inference engine. + """Run a throwaway inference cycle to warm up the entire pipeline. - The very first prefill incurs one-time overhead (e.g. CUDA graph - compilation, memory pool allocation, DynamicCache initialization). - Calling this once during startup moves that cost out of the - critical path so the first real client request is fast. - - The method performs a full prefill (TTS speaker embedding + LLM - system prompt), then aborts the request and resets all pipeline - state so the next real stream starts cleanly. + The very first call through each stage incurs one-time overhead + (e.g. CUDA graph compilation, memory pool allocation, + DynamicCache initialization, torch.compile). Sending a silence + frame with ``is_first=True`` exercises the full path — prefill, + perception, LLM decode, TTS, and codec — so the first real + client request is fast. Args: system_prompt: Prompt text to use for warmup. 
Falls back to @@ -323,47 +352,58 @@ def warmup(self, system_prompt: str | None = None) -> None: logging.info(f"No system prompt configured — using fallback prompt for warmup: \"{prompt}\"") warmup_stream_id = -1 + chunk_samples = int(self.chunk_size_in_secs * self.input_sample_rate) - logging.info("Running pipeline warmup prefill...") + logging.info("Running pipeline warmup (prefill + one silence chunk)...") t0 = time.time() - self.prefill_for_new_stream(warmup_stream_id, prompt) + warmup_frame = Frame( + samples=torch.zeros(chunk_samples), + stream_id=warmup_stream_id, + is_first=True, + is_last=True, + options=S2SRequestOptions(system_prompt=prompt), + ) + self.generate_step([warmup_frame]) - # Tear down the warmup request so the engine is clean for real traffic - self._abort_stream_request(warmup_stream_id) - self.context_manager.reset() + # Tear down everything so the engine is clean for real traffic + self.reset_session() self._stream_has_prompt = False logging.info(f"Pipeline warmup complete in {time.time() - t0:.3f}s") def generate_step(self, frames: list[Frame]): - """Main streaming API similar to *transcribe_step* in recognizers. - - If the batch contains a single zero-length first frame with a system - prompt in ``options``, this is treated as a **prefill-only** request: - the context manager and system prompt are initialized but no audio - inference runs. This is the unified protocol used by both the CLI - (``run()``) and server backends. + """Main streaming API — handles both init and audio processing. + + Mirrors ASR's ``transcribe_step``: on ``is_first`` frames, the + stream is initialized via :meth:`_init_state` (state creation, + context-manager allocation, KV-cache prefill). If the frame also + carries audio, it is processed in the same call. If there is no + audio (e.g. a server prefill-only request), the method returns + after init. 
+ + For latency-sensitive deployments, send the ``is_first`` frame + with **empty audio** so that the expensive prefill completes + before the user starts speaking. For batch/offline usage the + first frame can carry real audio — init and first-chunk + processing simply happen back-to-back in one call. """ - # Detect prefill-only frame: is_first + zero-length audio - if (len(frames) == 1 - and frames[0].is_first - and frames[0].samples.numel() == 0): - opts = frames[0].options - prompt = None - if opts is not None and hasattr(opts, "system_prompt"): - prompt = opts.system_prompt - self.prefill_for_new_stream(frames[0].stream_id, prompt) + # Init phase — like ASR's `if request.is_first: self.init_state(...)` + for frame in frames: + if frame.is_first: + self._init_state(frame.stream_id, frame.options) + + # Audio phase — skip if no audio (e.g. server prefill-only request) + non_empty_frames = [f for f in frames if f.samples.numel() > 0] + if not non_empty_frames: return - buffers, left_paddings = self.bufferer.update(frames) + buffers, left_paddings = self.bufferer.update(non_empty_frames) # This is a workaround for the fact that the audio buffer does left # padding, but the rest of the code requires no padding at all. 
buffers = [b[lp:] for b, lp in zip(buffers, left_paddings)] - ready_feats = [True] * len(frames) - with torch.no_grad(), torch.inference_mode(): - self.inner_generate_step(frames, buffers, ready_feats) + self.generate_step_for_frames(non_empty_frames, buffers) # ------------------------------------------------------------------ # Finalization helpers @@ -483,7 +523,7 @@ def run( raise ValueError("progress_bar must be an instance of ProgressBar.") if options is None: - options = [S2SRequestOptions(system_prompt=self.system_prompt) for _ in audio_filepaths] + options = [S2SRequestOptions() for _ in audio_filepaths] streamer = ContinuousBatchedFrameStreamer( n_frames_per_stream=1, @@ -504,7 +544,6 @@ def run( self.open_session() for frames in streamer: - self._maybe_prefill(frames) frames, pad_targets = self._apply_padding(frames, streamer) self.generate_step(frames) self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream) @@ -518,26 +557,6 @@ def run( # run() helpers # ------------------------------------------------------------------ - def _maybe_prefill(self, frames: list[Frame]) -> None: - """If the first frame of a new stream carries a system prompt, emit a - zero-length prefill frame through ``generate_step`` before inference - begins. This is the unified prefill protocol used by both ``run()`` - and server backends. 
- """ - if (len(frames) == 1 - and frames[0].is_first - and frames[0].options is not None - and hasattr(frames[0].options, "system_prompt") - and frames[0].options.system_prompt): - prefill_frame = Frame( - samples=torch.empty(0), - stream_id=frames[0].stream_id, - is_first=True, - is_last=False, - options=frames[0].options, - ) - self.generate_step([prefill_frame]) - def _apply_padding( self, frames: list[Frame], diff --git a/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py index 88012f41c941..999985f02610 100644 --- a/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py +++ b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py @@ -12,16 +12,61 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from dataclasses import dataclass +from typing import Any -@dataclass(slots=True) +@dataclass(frozen=True, slots=True) class S2SRequestOptions: - """Per-stream options for S2S inference. + """Immutable per-stream options for S2S inference. Attached to the first ``Frame`` of each stream via the ``options`` field so that the pipeline can read per-stream configuration at the - start of every new audio stream. + start of every new audio stream. Frozen so that options cannot be + accidentally modified after the stream is initialised. + + All fields default to ``None``, which means "use the pipeline-level + default". Call :meth:`augment_with_defaults` to fill ``None`` fields + with pipeline-level values — this mirrors the + ``ASRRequestOptions.augment_with_defaults()`` pattern used by the + ASR streaming pipelines. 
""" system_prompt: str | None = None + + top_p: float | None = None # (0, 1] + temperature: float | None = None # >= 0 + repetition_penalty: float | None = None # > 0 + + def __post_init__(self) -> None: + if self.top_p is not None and not (0.0 < self.top_p <= 1.0): + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}") + if self.temperature is not None and self.temperature < 0.0: + raise ValueError(f"temperature must be >= 0, got {self.temperature}") + if self.repetition_penalty is not None and self.repetition_penalty <= 0.0: + raise ValueError(f"repetition_penalty must be > 0, got {self.repetition_penalty}") + + @staticmethod + def _with_default(value: Any, default: Any) -> Any: + """Return *value* when it is not ``None``, otherwise *default*.""" + return default if value is None else value + + def augment_with_defaults( + self, + default_system_prompt: str | None = None, + default_top_p: float | None = None, + default_temperature: float | None = None, + default_repetition_penalty: float | None = None, + ) -> S2SRequestOptions: + """Return a new options instance with ``None`` fields filled from defaults. + + This is the S2S equivalent of ``ASRRequestOptions.augment_with_defaults``. 
+ """ + return S2SRequestOptions( + system_prompt=self._with_default(self.system_prompt, default_system_prompt), + top_p=self._with_default(self.top_p, default_top_p), + temperature=self._with_default(self.temperature, default_temperature), + repetition_penalty=self._with_default(self.repetition_penalty, default_repetition_penalty), + ) diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py index a0fb2a9c5082..c9c300aa3db2 100644 --- a/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_state.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from dataclasses import dataclass, field import torch from nemo.collections.asr.inference.utils.text_segment import Word +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions @dataclass @@ -37,6 +40,9 @@ class S2SStreamingState: dtype: torch.dtype output_sample_rate: int + # Per-stream request options (system prompt, sampling overrides, etc.) 
+ options: S2SRequestOptions = field(default_factory=S2SRequestOptions) + # Growing audio buffer — shape (1, T), appended each step audio_buffer: torch.Tensor = field(init=False) From 3e9e3e1cf197e9919ce4dfe207b6aaedb3121d0e Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 3 Apr 2026 02:18:55 +0000 Subject: [PATCH 33/40] perception_cache: check all three fields in is_initialized Signed-off-by: Elena Rastorgueva --- .../speechlm2/inference/model_wrappers/perception_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py index f4fed6f69da5..b3250b3131d6 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py @@ -43,7 +43,7 @@ class PerceptionCacheState: def is_initialized(self) -> bool: """Check if the cache has been initialized.""" - return self.cache_last_channel is not None + return None not in [self.cache_last_channel, self.cache_last_time, self.cache_last_channel_len] @dataclass From 84eeec539d516383568274d376aad0f6b7a745ee Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 3 Apr 2026 18:07:56 +0000 Subject: [PATCH 34/40] move silence padding from pipeline run-loop into streamer classes Signed-off-by: Elena Rastorgueva --- .../s2s_streaming_infer.py | 14 +-- .../pipelines/streaming_s2s_pipeline.py | 89 +---------------- .../framing/silence_padded_frame_streamer.py | 75 ++++++++++++++ .../framing/silence_padded_stream.py | 97 +++++++++++++++++++ .../speechlm2/inference/utils/audio_data.py | 20 ++-- 5 files changed, 190 insertions(+), 105 deletions(-) create mode 100644 nemo/collections/speechlm2/inference/streaming/framing/silence_padded_frame_streamer.py create mode 100644 nemo/collections/speechlm2/inference/streaming/framing/silence_padded_stream.py diff --git 
a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py index 46e5394588ee..8f1b27565488 100644 --- a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -30,8 +30,7 @@ from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder from nemo.collections.speechlm2.inference.utils.audio_data import ( - calculate_duration, - calculate_padded_duration, + calculate_duration_incl_padding, dump_output, prepare_audio_data, ) @@ -57,14 +56,9 @@ def main(cfg: DictConfig): exec_dur = timer.total_sec() logging.info(f"Generated {len(audio_filepaths)} files in {exec_dur:.2f}s") - # Log RTFX (accounts for padding when configured) - pad_to = cfg.get("pad_audio_to_sec", None) - pad_ratio = cfg.get("pad_silence_ratio", None) - pad_by = cfg.get("pad_audio_by_sec", None) - if pad_to or pad_ratio or pad_by: - data_dur = calculate_padded_duration(audio_filepaths, pad_to, pad_ratio, pad_by) - else: - data_dur = calculate_duration(audio_filepaths) + data_dur = calculate_duration_incl_padding( + audio_filepaths, cfg.get("pad_audio_to_sec"), cfg.get("pad_silence_ratio"), cfg.get("pad_audio_by_sec"), + ) rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf') logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)") diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index c473e5713989..88600ad44514 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -25,7 +25,6 @@ from nemo.collections.asr.inference.streaming.framing.request import Frame from nemo.collections.asr.inference.utils.enums 
import RequestType -from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedFrameStreamer from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import BatchedAudioBufferer from nemo.collections.asr.inference.utils.progressbar import ProgressBar from nemo.collections.speechlm2.inference.pipelines.s2s_pipeline_interface import S2SPipelineInterface @@ -34,6 +33,7 @@ from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from nemo.collections.speechlm2.inference.streaming.state.s2s_context_manager import S2SContextManager from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions +from nemo.collections.speechlm2.inference.streaming.framing.silence_padded_frame_streamer import SilencePaddedContinuousBatchedFrameStreamer from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput from nemo.utils import logging @@ -525,100 +525,31 @@ def run( if options is None: options = [S2SRequestOptions() for _ in audio_filepaths] - streamer = ContinuousBatchedFrameStreamer( + streamer = SilencePaddedContinuousBatchedFrameStreamer( n_frames_per_stream=1, frame_size_in_secs=self.chunk_size_in_secs, sample_rate=self.input_sample_rate, batch_size=self.batch_size, pad_last_frame=True, + pad_to_sec=self.pad_audio_to_sec, + pad_by_sec=self.pad_audio_by_sec, + pad_ratio=self.pad_silence_ratio, ) streamer.set_audio_filepaths(audio_filepaths, options) streamer.set_progress_bar(progress_bar) - # Ensure output directory exists os.makedirs(self.output_dir, exist_ok=True) - - # Track saved paths by stream id to preserve input order saved_paths_by_stream: dict[int, str] = {} - chunk_samples = int(self.chunk_size_in_secs * self.input_sample_rate) self.open_session() for frames in streamer: - frames, pad_targets = self._apply_padding(frames, streamer) self.generate_step(frames) self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream) 
- self._generate_silence_padding(pad_targets, chunk_samples, audio_filepaths, saved_paths_by_stream) output = self._build_pipeline_output(audio_filepaths, saved_paths_by_stream) self.close_session() return output - # ------------------------------------------------------------------ - # run() helpers - # ------------------------------------------------------------------ - - def _apply_padding( - self, - frames: list[Frame], - streamer: ContinuousBatchedFrameStreamer, - ) -> tuple[list[Frame], dict[int, float]]: - """If padding is configured, intercept last frames so the bufferer and - context stay alive for the silence-padding phase. Returns the - (possibly modified) frames and a dict mapping ``stream_id`` to the - remaining seconds of silence to append. - """ - pad_targets: dict[int, float] = {} - if not (self.pad_audio_to_sec or self.pad_silence_ratio or self.pad_audio_by_sec): - return frames, pad_targets - - processed_frames = [] - for frame in frames: - if frame.is_last: - elapsed = streamer.elapsed_durations[frame.stream_id] - remaining = self._padding_remaining_secs(elapsed) - if remaining > 0: - processed_frames.append(Frame( - samples=frame.samples, - stream_id=frame.stream_id, - is_first=frame.is_first, - is_last=False, - length=frame.length, - options=frame.options, - )) - pad_targets[frame.stream_id] = remaining - continue - processed_frames.append(frame) - return processed_frames, pad_targets - - def _generate_silence_padding( - self, - pad_targets: dict[int, float], - chunk_samples: int, - audio_filepaths: list[str], - saved_paths_by_stream: dict[int, str], - ) -> None: - """Generate silence-padding frames for streams that need them. - - Must run in the same iteration as the real last frame to avoid the next - stream's setup destroying this stream's context. 
- """ - for stream_id, remaining_secs in pad_targets.items(): - num_pad_frames = max(1, round(remaining_secs / self.chunk_size_in_secs)) - for i in range(num_pad_frames): - is_last = (i == num_pad_frames - 1) - silence_frame = Frame( - samples=torch.zeros(chunk_samples), - stream_id=stream_id, - is_first=False, - is_last=is_last, - length=chunk_samples, - ) - self.generate_step([silence_frame]) - if is_last: - self._finalize_and_save_finished_streams( - [silence_frame], audio_filepaths, saved_paths_by_stream - ) - def _build_pipeline_output( self, audio_filepaths: list[str], @@ -814,16 +745,6 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non return tts_output_code - def _padding_remaining_secs(self, elapsed_secs: float) -> float: - """Return how many seconds of silence padding are still needed.""" - if self.pad_audio_to_sec is not None: - return max(0.0, self.pad_audio_to_sec - elapsed_secs) - if self.pad_silence_ratio is not None: - return elapsed_secs * self.pad_silence_ratio - if self.pad_audio_by_sec is not None: - return self.pad_audio_by_sec - return 0.0 - def _request_id_for_stream(self, stream_id: int) -> str: return str(stream_id) diff --git a/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_frame_streamer.py b/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_frame_streamer.py new file mode 100644 index 000000000000..edca88bfd31f --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_frame_streamer.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.inference.streaming.framing.mono_stream import MonoStream +from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedFrameStreamer +from nemo.collections.speechlm2.inference.streaming.framing.silence_padded_stream import SilencePaddedStream + + +class SilencePaddedContinuousBatchedFrameStreamer(ContinuousBatchedFrameStreamer): + """``ContinuousBatchedFrameStreamer`` that optionally wraps each + ``MonoStream`` in a :class:`SilencePaddedStream` so extra silence + frames are yielded transparently at the end of each audio file. + + When no padding is configured the behaviour is identical to the base + class. 
+ """ + + def __init__( + self, + *, + pad_to_sec: float | None = None, + pad_by_sec: float | None = None, + pad_ratio: float | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.pad_to_sec = pad_to_sec + self.pad_by_sec = pad_by_sec + self.pad_ratio = pad_ratio + + @property + def _needs_padding(self) -> bool: + return any(x is not None for x in (self.pad_to_sec, self.pad_by_sec, self.pad_ratio)) + + def add_stream(self) -> None: + if self.stream_id >= self.n_audio_files: + return + + inner = MonoStream( + self.sample_rate, + self.frame_size_in_secs, + stream_id=self.stream_id, + pad_last_frame=self.pad_last_frame, + ) + + if self._needs_padding: + stream = SilencePaddedStream( + inner, + chunk_size_in_secs=self.frame_size_in_secs, + pad_to_sec=self.pad_to_sec, + pad_by_sec=self.pad_by_sec, + pad_ratio=self.pad_ratio, + ) + else: + stream = inner + + audio_filepath = self.audio_filepaths[self.stream_id] + self.sid2filepath[self.stream_id] = audio_filepath + self.elapsed_durations[self.stream_id] = 0.0 + stream.load_audio(audio_filepath, self.options[self.stream_id]) + + self.multi_streamer.add_stream(stream, stream_id=self.stream_id) + self.stream_id += 1 + self.update_progress_bar() diff --git a/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_stream.py b/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_stream.py new file mode 100644 index 000000000000..e5f6d3511efa --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_stream.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo.collections.asr.inference.streaming.framing.mono_stream import MonoStream +from nemo.collections.asr.inference.streaming.framing.request import Frame +from nemo.collections.asr.inference.streaming.framing.stream import Stream + + +class SilencePaddedStream(Stream): + """Wraps a ``MonoStream`` and appends silence frames after the real audio + to reach a target duration. + + The pipeline's ``run()`` loop sees a single, longer stream — no frame + mutation or side-channel silence injection is needed. ``MultiStream`` + keeps the stream alive until the final silence frame sets ``is_last=True``. 
+ """ + + def __init__( + self, + inner: MonoStream, + chunk_size_in_secs: float, + pad_to_sec: float | None = None, + pad_by_sec: float | None = None, + pad_ratio: float | None = None, + ): + super().__init__(inner.stream_id) + self.inner = inner + self.chunk_size_in_secs = chunk_size_in_secs + self.pad_to_sec = pad_to_sec + self.pad_by_sec = pad_by_sec + self.pad_ratio = pad_ratio + self._inner_exhausted = False + self._silence_frames_remaining = 0 + + def load_audio(self, audio, options=None): + self.inner.load_audio(audio, options) + audio_secs = self.inner.n_samples / self.inner.rate + remaining = self._padding_secs(audio_secs) + self._silence_frames_remaining = ( + max(1, round(remaining / self.chunk_size_in_secs)) if remaining > 0 else 0 + ) + + def _padding_secs(self, elapsed: float) -> float: + if self.pad_to_sec is not None: + return max(0.0, self.pad_to_sec - elapsed) + if self.pad_ratio is not None: + return elapsed * self.pad_ratio + if self.pad_by_sec is not None: + return self.pad_by_sec + return 0.0 + + def __iter__(self): + self.inner.__iter__() + self._inner_exhausted = False + return self + + def __next__(self) -> list[Frame]: + if not self._inner_exhausted: + frames = next(self.inner) + frame = frames[0] + if frame.is_last and self._silence_frames_remaining > 0: + modified = Frame( + samples=frame.samples, + stream_id=frame.stream_id, + is_first=frame.is_first, + is_last=False, + length=frame.length, + options=frame.options, + ) + self._inner_exhausted = True + return [modified] + return frames + + if self._silence_frames_remaining > 0: + self._silence_frames_remaining -= 1 + return [Frame( + samples=torch.zeros(self.inner.frame_size), + stream_id=self.stream_id, + is_first=False, + is_last=(self._silence_frames_remaining == 0), + length=self.inner.frame_size, + )] + + raise StopIteration diff --git a/nemo/collections/speechlm2/inference/utils/audio_data.py b/nemo/collections/speechlm2/inference/utils/audio_data.py index 
623ec36e9cab..a054b18b8d8a 100644 --- a/nemo/collections/speechlm2/inference/utils/audio_data.py +++ b/nemo/collections/speechlm2/inference/utils/audio_data.py @@ -83,22 +83,20 @@ def prepare_audio_data( return filepaths, options, ground_truths -def calculate_duration(audio_filepaths: list[str]) -> float: - """Calculate total duration of the given audio files in seconds.""" - total_dur = 0 - for audio_filepath in audio_filepaths: - sound = sf.SoundFile(audio_filepath) - total_dur += sound.frames / sound.samplerate - return total_dur - - -def calculate_padded_duration( +def calculate_duration_incl_padding( audio_filepaths: list[str], pad_audio_to_sec: float | None = None, pad_silence_ratio: float | None = None, pad_audio_by_sec: float | None = None, ) -> float: - """Calculate total duration including silence padding for RTFX reporting.""" + """Calculate total duration of the given audio files in seconds. + + Optionally accounts for silence padding appended after each file. + At most one padding argument may be set; when none are set this + returns the raw audio duration. 
+ """ + if sum(x is not None for x in [pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec]) > 1: + raise ValueError("Set at most one of: pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec") total = 0.0 for fp in audio_filepaths: sound = sf.SoundFile(fp) From 800bcc2f259bb705e820f4f6afaad0b6ecf26d5d Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 3 Apr 2026 19:22:59 +0000 Subject: [PATCH 35/40] return incremental GenerateStepOutput from generate_step Signed-off-by: Elena Rastorgueva --- docs/source/speechlm2/streaming_inference.rst | 15 ++- .../speechlm2/inference/__init__.py | 2 +- .../pipelines/streaming_s2s_pipeline.py | 127 +++++++++++++----- 3 files changed, 102 insertions(+), 42 deletions(-) diff --git a/docs/source/speechlm2/streaming_inference.rst b/docs/source/speechlm2/streaming_inference.rst index d23e53cd1dd8..153dee7255bd 100644 --- a/docs/source/speechlm2/streaming_inference.rst +++ b/docs/source/speechlm2/streaming_inference.rst @@ -111,7 +111,7 @@ What Happens Inside One Step │ ├─ any frames with audio? │ │ - │ NO → return (server prefill-only request) + │ NO → return empty outputs (server prefill-only request) │ │ │ YES ↓ │ @@ -122,6 +122,7 @@ What Happens Inside One Step 4. per-frame TTS 5. codec decode 6. state updates + output accumulation + 7. return list[GenerateStepOutput] Each call to ``generate_step(frames)`` performs: @@ -282,10 +283,14 @@ Server Integration ------------------ The same ``generate_step()`` method used by ``run()`` can be called directly -from a custom server: +from a custom server. It returns a list of ``GenerateStepOutput`` objects +(one per input frame) carrying the **incremental** audio and text produced +by this step — no need to diff against accumulated state: .. code-block:: python + from nemo.collections.speechlm2.inference import GenerateStepOutput + # 1. 
Init stream (empty audio so prefill completes before recording) init_frame = Frame( samples=torch.empty(0), @@ -296,14 +301,16 @@ from a custom server: pipeline.generate_step([init_frame]) # -> client can now start recording - # 2. Stream audio chunks + # 2. Stream audio chunks and consume incremental outputs for i, chunk in enumerate(audio_source): frame = Frame( samples=chunk, stream_id=stream_id, is_first=False, is_last=(i == last), ) - pipeline.generate_step([frame]) + outputs = pipeline.generate_step([frame]) + for out in outputs: + send_to_client(out.audio, out.text, out.asr_text) Per-stream options (``system_prompt``, ``top_p``, ``temperature``, ``repetition_penalty``) are attached to the ``is_first`` frame via diff --git a/nemo/collections/speechlm2/inference/__init__.py b/nemo/collections/speechlm2/inference/__init__.py index 03abea1da74d..575d7a95e8bc 100644 --- a/nemo/collections/speechlm2/inference/__init__.py +++ b/nemo/collections/speechlm2/inference/__init__.py @@ -17,6 +17,6 @@ InferenceStepResult, StreamingDecodeState, ) -from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import GenerateStepOutput, StreamingS2SPipeline from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 88600ad44514..03c418903ac9 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -15,6 +15,7 @@ import copy import os import time +from dataclasses import dataclass import torch import librosa @@ -38,6 +39,25 @@ from nemo.utils import logging 
+@dataclass +class GenerateStepOutput: + """Output of a single :meth:`StreamingS2SPipeline.generate_step` call + for one stream. + + Analogous to :class:`TranscribeStepOutput` in the ASR pipelines, + this carries the **incremental** (new-this-step) audio and text so + that callers don't have to diff against accumulated state. + + The underlying :class:`S2SStreamingState` still accumulates + everything for batch/offline use. + """ + + stream_id: int + audio: torch.Tensor + text: str = "" + asr_text: str = "" + + class StreamingS2SPipeline(S2SPipelineInterface): """ Streaming S2S pipeline. @@ -208,33 +228,7 @@ def _init_state(self, stream_id: int, options: S2SRequestOptions | None = None) self._stream_has_prompt = bool(prompt) - # ------------------------------------------------------------------ - # Output helpers - # ------------------------------------------------------------------ - def log_output(self, frames: list[Frame], audio_wave: Tensor, text_pieces: list[str], asr_text_pieces: list[str] = None): - """Append generated audio waveform and text to per-stream state.""" - for idx, frame in enumerate(frames): - state = self.get_state(frame.stream_id) - # audio_wave is [B, S]; take sample idx (None when decode_audio=False) - sample_audio = audio_wave[idx:idx+1, ...] 
if audio_wave is not None else None - # Determine text piece for this index - piece = None - if text_pieces and idx < len(text_pieces): - candidate = text_pieces[idx] - if isinstance(candidate, str) and candidate: - piece = candidate - - # Determine ASR text piece - asr_piece = None - if asr_text_pieces and idx < len(asr_text_pieces): - candidate = asr_text_pieces[idx] - if isinstance(candidate, str) and candidate: - asr_piece = candidate - - state.append_step_output(sample_audio, text=piece, asr_text=asr_piece) - - - def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]): + def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]) -> list[GenerateStepOutput]: """Generate speech for audio Frames using a shared ContextManager. This is the S2S equivalent of ASR's ``transcribe_step_for_frames`` @@ -247,7 +241,7 @@ def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]): by :meth:`_init_state` *before* this method is called. """ if len(frames) == 0: - return + return [] stream_ids = [f.stream_id for f in frames] eos_flags = [f.is_last for f in frames] @@ -325,8 +319,30 @@ def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]): # Note: We keep the state in _state_pool until finalization to save audio # It will be cleaned up in close_session() - # Log audio and attach text to state - self.log_output(frames, result.decoded_audio, result.predicted_text_strs, result.asr_predicted_text_strs) + # Split the batch-level InferenceStepResult into per-frame outputs. + # Two things happen for each frame: + # 1. The incremental audio/text is appended to S2SStreamingState + # (the pipeline-level accumulator that persists across steps — + # used by run() to build the final PipelineOutput). + # 2. A GenerateStepOutput is built with the same incremental data + # and returned to the caller (used by server integrations to + # stream partial results to clients without diffing state). 
+ outputs: list[GenerateStepOutput] = [] + for idx, frame in enumerate(frames): + state = self.get_state(frame.stream_id) + audio = result.decoded_audio[idx:idx+1] if result.decoded_audio is not None else None + text = result.predicted_text_strs[idx] if result.predicted_text_strs else "" + asr_text = result.asr_predicted_text_strs[idx] if result.asr_predicted_text_strs else "" + + state.append_step_output(audio, text=text, asr_text=asr_text) + + outputs.append(GenerateStepOutput( + stream_id=frame.stream_id, + audio=audio if audio is not None else torch.empty(1, 0), + text=text, + asr_text=asr_text, + )) + return outputs _WARMUP_FALLBACK_PROMPT = "Mock system prompt for warmup." @@ -372,15 +388,21 @@ def warmup(self, system_prompt: str | None = None) -> None: logging.info(f"Pipeline warmup complete in {time.time() - t0:.3f}s") - def generate_step(self, frames: list[Frame]): + def generate_step(self, frames: list[Frame]) -> list[GenerateStepOutput]: """Main streaming API — handles both init and audio processing. Mirrors ASR's ``transcribe_step``: on ``is_first`` frames, the stream is initialized via :meth:`_init_state` (state creation, context-manager allocation, KV-cache prefill). If the frame also - carries audio, it is processed in the same call. If there is no - audio (e.g. a server prefill-only request), the method returns - after init. + carries input audio, it is processed in the same call. If there + is no input audio (e.g. a server prefill-only request), the + method returns after init without running inference. + + Returns one :class:`GenerateStepOutput` per input frame carrying + the **incremental** output audio and text produced by this step. + The output audio tensor may be empty when no waveform is produced + (prefill-only frames with no input audio, or when + ``decode_audio=False``). 
For latency-sensitive deployments, send the ``is_first`` frame with **empty audio** so that the expensive prefill completes @@ -389,21 +411,48 @@ def generate_step(self, frames: list[Frame]): processing simply happen back-to-back in one call. """ # Init phase — like ASR's `if request.is_first: self.init_state(...)` + # Known limitation: _init_state runs prefill synchronously, so with + # batch_size > 1 a long system prompt on one stream will block + # decoding on all other streams in the same batch. for frame in frames: if frame.is_first: self._init_state(frame.stream_id, frame.options) - # Audio phase — skip if no audio (e.g. server prefill-only request) + # Decode phase. + # Although generate_step_for_frames currently enforces batch_size==1, + # this method already handles mixed batches (a mix of prefill-only + # and audio-carrying frames) so it is ready for batch_size > 1. + # We run inference only on the non-empty subset, then stitch the + # outputs back into a list aligned 1:1 with the original *frames*. non_empty_frames = [f for f in frames if f.samples.numel() > 0] if not non_empty_frames: - return + # All frames are prefill-only — nothing to decode. + return [ + GenerateStepOutput(stream_id=f.stream_id, audio=torch.empty(1, 0)) + for f in frames + ] buffers, left_paddings = self.bufferer.update(non_empty_frames) # This is a workaround for the fact that the audio buffer does left # padding, but the rest of the code requires no padding at all. buffers = [b[lp:] for b, lp in zip(buffers, left_paddings)] with torch.no_grad(), torch.inference_mode(): - self.generate_step_for_frames(non_empty_frames, buffers) + step_outputs = self.generate_step_for_frames(non_empty_frames, buffers) + + # Fast path: every frame had audio, so step_outputs is already 1:1. + if len(non_empty_frames) == len(frames): + return step_outputs + + # Slow path (batch_size > 1 with a mix of prefill and audio frames): + # fill in empty outputs for the prefill-only streams. 
+ output_by_stream: dict[int, GenerateStepOutput] = {o.stream_id: o for o in step_outputs} + return [ + output_by_stream.get( + f.stream_id, + GenerateStepOutput(stream_id=f.stream_id, audio=torch.empty(1, 0)), + ) + for f in frames + ] # ------------------------------------------------------------------ # Finalization helpers @@ -543,6 +592,10 @@ def run( self.open_session() for frames in streamer: + # generate_step returns per-step GenerateStepOutput objects + # (useful for server integrations that stream partial results + # to clients). Here we rely on the accumulated state instead, + # which _finalize_and_save_finished_streams reads on is_last. self.generate_step(frames) self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream) From 88543cfe885faead85d892a34ed82ce8085b1948 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Thu, 9 Apr 2026 23:30:07 +0000 Subject: [PATCH 36/40] refactor: split model_factory into backend/ modules; unify vLLM engine classes Signed-off-by: Elena Rastorgueva --- docs/source/speechlm2/streaming_inference.rst | 112 +- .../conf/s2s_streaming.yaml | 2 + .../model_wrappers/backend/__init__.py | 23 + .../model_wrappers/backend/interface.py | 255 ++++ .../backend/pytorch/__init__.py | 13 + .../model_wrappers/backend/pytorch/eartts.py | 134 ++ .../model_wrappers/backend/pytorch/model.py | 235 ++++ .../model_wrappers/backend/vllm/__init__.py | 13 + .../model_wrappers/backend/vllm/base.py | 303 +++++ .../model_wrappers/backend/vllm/eartts.py | 191 +++ .../model_wrappers/backend/vllm/llm.py | 167 +++ .../inference/model_wrappers/factory.py | 146 ++ .../inference/model_wrappers/model_factory.py | 1170 ----------------- .../nemotron_voicechat_inference_wrapper.py | 172 ++- .../pipelines/streaming_s2s_pipeline.py | 47 +- .../inference/vllm/streaming_llm_engine.py | 220 ++-- .../speechlm2/inference/vllm/vllm_patch.py | 59 - 17 files changed, 1805 insertions(+), 1457 deletions(-) create mode 100644 
nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/interface.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/__init__.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/eartts.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/model.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/__init__.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py create mode 100644 nemo/collections/speechlm2/inference/model_wrappers/factory.py delete mode 100644 nemo/collections/speechlm2/inference/model_wrappers/model_factory.py delete mode 100644 nemo/collections/speechlm2/inference/vllm/vllm_patch.py diff --git a/docs/source/speechlm2/streaming_inference.rst b/docs/source/speechlm2/streaming_inference.rst index 153dee7255bd..7532511476e5 100644 --- a/docs/source/speechlm2/streaming_inference.rst +++ b/docs/source/speechlm2/streaming_inference.rst @@ -24,7 +24,10 @@ The streaming inference stack has four layers: ▼ Model Wrapper NemotronVoicechatInferenceWrapper │ - infer_one_step() - │ - perception / LLM / TTS / codec + │ - perception + │ - model_llm_interface (PyTorchLLM or VLLMLLM) + │ - model_eartts_interface (PyTorchEarTTS or VLLMEarTTS) + │ - codec decode ▼ Model NemotronVoiceChat - DuplexSTTModel + DuplexEARTTS @@ -176,6 +179,113 @@ The pipeline maintains two separate state objects per stream: ``close_session()`` so the final ``PipelineOutput`` can be assembled. 
+Inference Backends +^^^^^^^^^^^^^^^^^^ + +NemotronVoiceChat has two inference components that each need a backend: + +- **LLM** (DuplexSTT backbone) -- takes audio embeddings from the perception + encoder and predicts text tokens, ASR tokens, and optional function-call + tokens at each frame. +- **TTS** (EarTTS) -- takes the predicted text token and produces audio codec + codes (RVQ acoustic tokens). + +Each component can run on **native PyTorch** or **vLLM**, selected by the +``engine_type`` config value: + +.. list-table:: + :header-rows: 1 + :widths: 35 30 30 + + * - ``engine_type`` + - LLM backend + - TTS backend + * - ``native`` + - ``PyTorchLLM`` + - ``PyTorchEarTTS`` + * - ``vllm_llm`` + - ``VLLMLLM`` + - ``PyTorchEarTTS`` + * - ``vllm_eartts`` + - ``PyTorchLLM`` + - ``VLLMEarTTS`` + * - ``vllm_llm_vllm_eartts`` + - ``VLLMLLM`` + - ``VLLMEarTTS`` + +All four backend classes implement the same ``ModelInterface`` ABC (defined in +``inference.model_wrappers.backend.interface``), so the inference wrapper +(``NemotronVoicechatInferenceWrapper``) can treat them uniformly via two +attributes: + +- ``model_llm_interface`` -- the LLM backend +- ``model_eartts_interface`` -- the TTS backend + +The backend classes live under ``inference.model_wrappers.backend/``: + +.. code-block:: text + + backend/ + interface.py # ModelInterface ABC + shared sampling + pytorch/ + model.py # PyTorchLLM (wraps DuplexSTT forward pass) + eartts.py # PyTorchEarTTS (wraps DuplexEARTTS.infer_codes_one_step) + vllm/ + base.py # VLLMModelBase (engine lifecycle, async loop) + llm.py # VLLMLLM (DuplexSTT via vLLM) + eartts.py # VLLMEarTTS (EarTTS via vLLM) + factory.py # create_model() dispatches to the right class + +``ModelInterface`` provides shared utilities used by the LLM backends: +top-p (nucleus) sampling, repetition penalty, and temperature scaling. +These are applied **post-hoc** on the returned logits -- vLLM internally +runs with ``skip_sampling=True`` and greedy decoding. 
+ +Each backend also exposes lifecycle methods that the wrapper calls uniformly: + +- ``prefill_prompt(embeddings, ...)`` -- Warm up KV cache (native) or + prefill the vLLM engine with system-prompt embeddings before streaming. +- ``compile()`` -- Apply ``torch.compile`` to the TTS backbone (native + only; no-op for vLLM). +- ``setup_subword_cache(cfg)`` -- Enable the TTS subword embedding cache + (native only; no-op for vLLM). + +The ``factory.create_model()`` function is the single entry point that +dispatches to the correct class based on a per-component ``engine_type`` +string (``native_llm``, ``native_eartts``, ``vllm_llm``, ``vllm_eartts``). + +vLLM Integration Details +"""""""""""""""""""""""" + +When ``engine_type`` includes ``vllm``, the pipeline loads vLLM engines +**in-process** alongside the native PyTorch components -- there is no +disaggregated multi-server setup. Each vLLM component runs as an +``AsyncLLM`` engine in the same Python process, sharing GPU memory with the +native perception encoder and codec decoder. + +The vLLM engines manage their own KV caches via PagedAttention. Both +``VLLMLLM`` and ``VLLMEarTTS`` inherit from ``VLLMModelBase``, which wraps a +``CustomInputAsyncVLLMEngine`` (defined in ``inference.vllm.streaming_llm_engine``) +and provides: + +- An internal ``asyncio`` event loop for blocking synchronous calls +- Request lifecycle management (start, abort, restart) +- Automatic checkpoint conversion to vLLM format on first use + +``CustomInputAsyncVLLMEngine`` is a thin wrapper around vLLM's ``AsyncLLM`` +that adds support for custom input tensor specifications (multi-tensor +inputs like audio embeddings, subword IDs, speaker latents). The +``engine_kind`` parameter (``"llm"`` or ``"eartts"``) selects EarTTS-specific +runtime settings (TRITON attention backend, TF32 precision, guidance scale) +without introducing inheritance between TTS and LLM engine classes. 
+ +This requires a custom vLLM fork with NemotronVoiceChat model support: + +.. code-block:: bash + + pip install git+https://github.com/vklimkov-nvidia/vllm@vklimkov/voicechat + + Configuration ------------- diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml index ac50445ef97c..ef1177642a93 100644 --- a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -33,6 +33,7 @@ s2s: vllm_llm_config: model_path: ${s2s.model_path} # Inherits from s2s.model_path max_model_len: 8192 # Maximum sequence length for vLLM + max_num_batched_tokens: 768 # Max tokens per forward pass (prefill chunk size) gpu_memory_utilization: 0.35 # GPU memory utilization (0.0-1.0) dtype: bfloat16 # Data type for vLLM inference engine_path: null # Path to vLLM engine (null = auto-convert if needed) @@ -41,6 +42,7 @@ s2s: vllm_tts_config: model_path: ${s2s.vllm_llm_config.model_path} # Inherits from s2s.model_path max_model_len: ${s2s.vllm_llm_config.max_model_len} + max_num_batched_tokens: ${s2s.vllm_llm_config.max_num_batched_tokens} gpu_memory_utilization: ${s2s.vllm_llm_config.gpu_memory_utilization} dtype: float32 # EarTTS requires float32 for proper audio quality (bfloat16 causes hallucinations) engine_path: null diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py new file mode 100644 index 000000000000..111b94eccc7b --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface +from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.model import PyTorchLLM +from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.eartts import PyTorchEarTTS, TTSGenerationResult + +try: + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.llm import VLLMLLM + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.eartts import VLLMEarTTS +except ImportError: + pass diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/interface.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/interface.py new file mode 100644 index 000000000000..ba69eb34740a --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/interface.py @@ -0,0 +1,255 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base interface for S2S model inference backends. 
+ +NemotronVoiceChat has two main inference components that can each run on +either native PyTorch or vLLM: + +- **LLM** (inside DuplexSTT): takes audio embeddings, produces text and ASR + tokens. Wrapped as ``model_llm_interface`` in the inference wrapper. +- **TTS** (EarTTS): takes text tokens, produces audio codec codes. + Wrapped as ``model_eartts_interface`` in the inference wrapper. + +This module defines the abstract ``ModelInterface`` that both backends +implement, with shared sampling utilities (top-p, repetition penalty, +temperature). +""" + +from abc import ABC, abstractmethod +from typing import Any +import math +import torch + +from nemo.utils import logging + + +class ModelInterface(ABC): + """ + Abstract base class for LLM and TTS inference backends. + + Concrete implementations wrap either the LLM component (DuplexSTT backbone + that predicts text/ASR tokens from audio embeddings) or the TTS component + (EarTTS that generates audio codec codes from text tokens). Each component + can run on native PyTorch or vLLM. + + Provides shared sampling utilities (top-p, repetition penalty, temperature) + and lifecycle methods (compile, prefill, subword cache) so the inference + wrapper can treat all backends uniformly. + """ + + def __init__( + self, + special_token_ids: set[int] | None = None, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + ): + """ + Initialize base interface with sampling parameters. + + Args: + special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. + These tokens will use greedy decoding and won't be penalized. + top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 + repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 + temperature: Temperature for sampling. 1.0 = no change, <1.0 = sharper, >1.0 = flatter. + 0.0 = greedy (argmax). 
Default: 1.0 + """ + if not math.isfinite(temperature): + raise ValueError(f"temperature must be finite, got {temperature}") + if temperature < 0.0: + raise ValueError(f"temperature must be >= 0.0, got {temperature}") + + self.special_token_ids = special_token_ids if special_token_ids is not None else set() + self.top_p = top_p + self.repetition_penalty = repetition_penalty + self.temperature = temperature + + # Pre-built tensor for special-token filtering in repetition penalty. + # Lazily moved to the right device on first use (see _sample_text_token). + self._special_ids_tensor: torch.Tensor | None = ( + torch.tensor(sorted(self.special_token_ids), dtype=torch.long) + if self.special_token_ids else None + ) + + def _sample_text_token( + self, + logits: torch.Tensor, + generated_tokens: torch.Tensor, + current_step: int, + sampling_params: dict[str, float] | None = None, + ) -> torch.Tensor: + """ + Sample text tokens with optional top-p sampling and repetition penalty. + Special tokens (pad, eos, bos) bypass sampling - if they have highest probability, return them directly. + + Args: + logits: Logits tensor of shape (B, V) for vocabulary V. + generated_tokens: Previously generated tokens of shape (B, T). + current_step: Current decoding step (used to slice generated_tokens). + sampling_params: Optional per-request overrides for ``top_p``, + ``temperature``, and ``repetition_penalty``. Missing keys + fall back to ``self.*`` (the pipeline-level defaults). + + Returns: + Sampled token ids of shape (B,). 
+ """ + top_p = sampling_params.get("top_p", self.top_p) if sampling_params else self.top_p + temperature = sampling_params.get("temperature", self.temperature) if sampling_params else self.temperature + rep_penalty = sampling_params.get("repetition_penalty", self.repetition_penalty) if sampling_params else self.repetition_penalty + + B, V = logits.shape + device = logits.device + + # First check greedy tokens (on original logits) + greedy_tokens = logits.argmax(dim=-1) # (B,) + + # If no sampling needed (all disabled), return greedy + if top_p >= 1.0 and rep_penalty == 1.0 and (temperature == 1.0 or temperature == 0.0): + return greedy_tokens + + # temperature=0 means greedy + if temperature == 0.0: + return greedy_tokens + + # For each batch, if greedy is special token, use greedy; otherwise sample + sampled_tokens = greedy_tokens.clone() + + # Ensure cached special-token tensor is on the right device (once). + if self._special_ids_tensor is not None and self._special_ids_tensor.device != device: + self._special_ids_tensor = self._special_ids_tensor.to(device) + + for b in range(B): + # If greedy token is a special token, keep it (no sampling) + if greedy_tokens[b].item() in self.special_token_ids: + continue + + # Not a special token - apply repetition penalty and sampling + batch_logits = logits[b].clone() # (V,) + + # Apply repetition penalty (vectorized, no Python loop) + if rep_penalty != 1.0 and current_step > 0: + unique_prev = generated_tokens[b, :current_step].unique() + # Exclude special tokens from penalty + if self._special_ids_tensor is not None: + ids_t = self._special_ids_tensor + if ids_t.device != unique_prev.device: + ids_t = ids_t.to(unique_prev.device) + unique_prev = unique_prev[~torch.isin(unique_prev, ids_t)] + + if unique_prev.numel() > 0: + if unique_prev.device != batch_logits.device: + unique_prev = unique_prev.to(batch_logits.device) + prev_logits = batch_logits[unique_prev] + # Positive logits are divided, negative logits are 
multiplied + # (same as the standard repetition_penalty convention) + batch_logits[unique_prev] = torch.where( + prev_logits > 0, + prev_logits / rep_penalty, + prev_logits * rep_penalty, + ) + + # Apply temperature scaling + if temperature != 1.0: + batch_logits = batch_logits / temperature + + # Fall back to greedy if logits are non-finite before top-p + # (top-p intentionally introduces -inf, so check must happen first) + if not torch.isfinite(batch_logits).all(): + logging.warning( + f"_sample_text_token: logits contain NaN or inf at step {current_step}, batch {b}: " + f"nan={batch_logits.isnan().sum().item()}, " + f"inf={batch_logits.isinf().sum().item()}, " + f"min={batch_logits[~batch_logits.isnan()].min().item() if not batch_logits.isnan().all() else 'all_nan'}, " + f"max={batch_logits[~batch_logits.isnan()].max().item() if not batch_logits.isnan().all() else 'all_nan'}" + ) + sampled_tokens[b] = greedy_tokens[b] + continue + + # Apply top-p (nucleus) sampling + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + + # Remove tokens with cumulative prob > top_p, keeping at least one + sorted_indices_to_remove = cumulative_probs > top_p + # Shift to keep the first token that exceeds threshold + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() + sorted_indices_to_remove[0] = False + + # Set to -inf + indices_to_remove = sorted_indices[sorted_indices_to_remove] + batch_logits[indices_to_remove] = float('-inf') + + probs = torch.softmax(batch_logits, dim=-1) + sampled_tokens[b] = torch.multinomial(probs, num_samples=1).item() + + return sampled_tokens + + @abstractmethod + def __call__( + self, + input_embeds: torch.Tensor, + cache: Any | None = None, + **kwargs + ) -> dict[str, Any]: + """ + Perform model inference. 
+ + Args: + input_embeds: Input embeddings tensor of shape [batch, seq_len, hidden_dim] + cache: Optional cache object (e.g., DynamicCache for transformers) + **kwargs: Additional model-specific arguments + + Returns: + Dictionary containing: + - 'text_logits': Logits for text generation [batch, seq_len, vocab_size] + - 'cache': Updated cache object (if cache was provided) + - Additional model-specific outputs + """ + pass + + @abstractmethod + def to(self, device_or_dtype: torch.device | torch.dtype) -> 'ModelInterface': + """Move model to specified device or convert to specified dtype.""" + pass + + @abstractmethod + def eval(self) -> 'ModelInterface': + """Set model to evaluation mode.""" + pass + + @property + @abstractmethod + def device(self) -> torch.device: + """Get the device of the model.""" + pass + + def compile(self, **kwargs) -> None: + """Apply torch.compile optimizations. No-op by default; override in subclasses.""" + pass + + def setup_subword_cache(self, cfg) -> None: + """Enable TTS subword embedding cache. No-op by default; override in subclasses.""" + pass + + def prefill_prompt(self, embeddings, **kwargs): + """Prefill the model with prompt embeddings before streaming begins. + + Override in subclasses to implement engine-specific prefill logic. + """ + raise NotImplementedError(f"{type(self).__name__} does not implement prefill_prompt") diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/__init__.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/eartts.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/eartts.py new file mode 100644 index 000000000000..fe4c4f6df4f9 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/eartts.py @@ -0,0 +1,134 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Native PyTorch backend for the TTS (EarTTS) component of NemotronVoiceChat. + +Wraps the DuplexEARTTS model for direct PyTorch inference, providing the +same ``ModelInterface`` API as ``VLLMEarTTS`` so the inference wrapper can +treat both backends uniformly. + +Used as ``model_eartts_interface`` in the inference wrapper when +``engine_type="native_eartts"``. 
"""
Native PyTorch backend for the TTS (EarTTS) component of NemotronVoiceChat.

Wraps the DuplexEARTTS model for direct PyTorch inference, exposing the same
``ModelInterface`` API as ``VLLMEarTTS`` so the inference wrapper can treat
both backends uniformly.

Used as ``model_eartts_interface`` in the inference wrapper when
``engine_type="native_eartts"``.
"""

from dataclasses import dataclass, fields
from typing import Any

import torch

from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface
from nemo.utils import logging


@dataclass
class TTSGenerationResult:
    """Result from a single TTS generation step (shared by PyTorch and vLLM backends)."""

    codes: torch.Tensor  # generated acoustic tokens
    past_key_values: Any  # updated cache (if the backend keeps one)

    def __getitem__(self, item: str | int):
        """Allow access to fields both by name and by positional index."""
        if isinstance(item, str):
            return getattr(self, item)
        return getattr(self, fields(self)[item].name)


class PyTorchEarTTS(ModelInterface):
    """
    Native PyTorch backend for the TTS (EarTTS) component.

    Conforms DuplexEARTTS per-frame codec generation
    (``infer_codes_one_step``) to the ``ModelInterface`` contract.
    """

    def __init__(self, tts_model):
        """
        Args:
            tts_model: A ``DuplexEARTTS`` instance (``NemotronVoiceChat.tts_model``).
        """
        super().__init__()
        self.tts_model = tts_model

    def __call__(self, inputs: dict, **kwargs) -> TTSGenerationResult:
        """
        Generate one frame of audio codec codes.

        Args:
            inputs: Keyword arguments forwarded to
                ``DuplexEARTTS.infer_codes_one_step`` (current_subword_id,
                prev_subword_id, current_subword_mask, prev_audio_tokens,
                past_key_values, guidance_enabled, etc.)

        Returns:
            TTSGenerationResult carrying the new codes and the updated cache.
        """
        new_codes, new_cache = self.tts_model.infer_codes_one_step(**inputs)
        return TTSGenerationResult(codes=new_codes, past_key_values=new_cache)

    def prefill_prompt(self, init_inputs, prompt_token_ids=None, request_id=None, **kwargs):
        """
        Warm up TTS with speaker embedding / initialization inputs.

        Calls the inner ``tts_model`` directly (the actual EarTTS nn.Module,
        not the DuplexEARTTS wrapper).

        Args:
            init_inputs: Dict of initial TTS inputs from ``get_init_inputs()``.
            prompt_token_ids: Ignored here (vLLM-only parameter).
            request_id: Ignored here (vLLM-only parameter).

        Returns:
            Model outputs (with ``past_key_values`` and ``codes``).
        """
        return self.tts_model.tts_model(**init_inputs)

    def compile(self, **kwargs) -> None:
        """Apply torch.compile to the TTS backbone when one is present."""
        inner = getattr(self.tts_model, 'tts_model', None)
        if inner is None or not hasattr(inner, 'backbone'):
            return
        mode = kwargs.get('mode', 'default')
        logging.info(f"Compiling TTS backbone with torch.compile(mode='{mode}')...")
        inner.backbone = torch.compile(inner.backbone, mode=mode)
        logging.info(" TTS backbone compiled")

    def setup_subword_cache(self, cfg) -> None:
        """Enable the TTS subword embedding cache when the config asks for it."""
        from omegaconf import OmegaConf

        inner = getattr(self.tts_model, 'tts_model', None)
        if inner is None or not hasattr(inner, 'config'):
            return
        if not bool(cfg.get("use_tts_subword_cache", False)):
            return
        OmegaConf.update(inner.config, "use_tts_subword_cache", True)
        logging.info("TTS speedup enabled: use_tts_subword_cache")
        embedder = getattr(inner, 'embed_subword', None)
        if embedder is not None and hasattr(embedder, 'use_tts_subword_cache'):
            embedder.use_tts_subword_cache = True

    def to(self, device_or_dtype: torch.device | torch.dtype) -> 'PyTorchEarTTS':
        """Move the wrapped TTS model to a device or convert its dtype."""
        self.tts_model = self.tts_model.to(device_or_dtype)
        return self

    def eval(self) -> 'PyTorchEarTTS':
        """Switch the wrapped TTS model to eval mode."""
        self.tts_model.eval()
        return self

    @property
    def device(self) -> torch.device:
        """Device of the wrapped TTS model (CPU when it has no parameters)."""
        try:
            return next(self.tts_model.parameters()).device
        except StopIteration:
            return torch.device('cpu')
"""
Native PyTorch backend for the LLM component of NemotronVoiceChat.

Wraps the DuplexSTT model (which contains the Nemotron LLM backbone) for
direct PyTorch inference with top-p sampling and repetition penalty support.

Used as ``model_llm_interface`` in the inference wrapper when
``engine_type="native_llm"``.
"""

from typing import Any

import torch

from nemo.utils import logging
from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface


class PyTorchLLM(ModelInterface):
    """
    Native PyTorch backend for the LLM (DuplexSTT) component.

    Runs the DuplexSTT forward pass (``stt_model()``) to predict text/ASR
    tokens under the ``ModelInterface`` contract, with optional top-p
    sampling and repetition penalty applied to the text head.
    """

    def __init__(
        self,
        model,
        special_token_ids: set[int] | None = None,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
    ):
        """
        Wrap an existing model.

        Args:
            model: The DuplexS2SExternalSpeechDecoderModel instance.
            special_token_ids: Token ids (pad/eos/bos) that bypass sampling
                and are never penalized. When None, the ids are extracted
                from ``model.stt_model.tokenizer`` (its bos/eos/pad tokens);
                a hard-coded default set is used if extraction fails.
                You can also provide them manually, e.g.
                {tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id}.
            top_p: Top-p (nucleus) sampling threshold. 1.0 disables it.
            repetition_penalty: Penalty for repeated tokens. 1.0 disables it;
                1.2 is the recommended value when enabling.
            temperature: Sampling temperature. 1.0 = no change, <1.0 = sharper,
                >1.0 = flatter, 0.0 = greedy (argmax).
        """
        # Fallback bos/eos/pad ids used when tokenizer extraction yields nothing.
        fallback_special_ids = {1, 2, 12}

        if special_token_ids is None:
            special_token_ids = self._extract_special_token_ids_from_nemo(model)
            if not special_token_ids:
                special_token_ids = fallback_special_ids

        super().__init__(
            special_token_ids=special_token_ids,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        self.model = model

        logging.debug(f"Special token IDs: {self.special_token_ids}")

        sampling_active = top_p < 1.0 or repetition_penalty != 1.0 or temperature not in (0.0, 1.0)
        if sampling_active and not self.special_token_ids:
            import warnings

            warnings.warn(
                "Sampling is enabled but special_token_ids is empty. "
                "Could not auto-extract from model.tokenizer. "
                "Please provide special_token_ids manually to ensure special tokens use greedy decoding. "
                "Otherwise, EOS tokens may be randomly sampled and generation may not stop properly!"
            )

    def __call__(
        self,
        input_embeds: torch.Tensor,
        cache: Any | None = None,
        cache_position: torch.Tensor | None = None,
        generated_tokens: torch.Tensor | None = None,
        current_step: int = 0,
        return_logits: bool = False,
        sampling_params: dict[str, float] | None = None,
        **kwargs
    ) -> dict[str, Any]:
        """
        Run one LLM step over the given embeddings.

        Args:
            input_embeds: Input embeddings [batch, seq_len, hidden_dim].
            cache: Optional DynamicCache or HybridMambaAttentionDynamicCache.
            cache_position: Optional position tensor for Nemotron models.
            generated_tokens: Previously generated tokens [batch, num_generated];
                required for repetition_penalty. An empty tensor is used when None.
            current_step: Current decoding step (slices generated_tokens
                for the repetition penalty).
            return_logits: Also return the raw text/ASR(/function) logits.
            sampling_params: Optional per-request overrides for sampling
                (top_p, temperature, repetition_penalty).
            **kwargs: Additional arguments passed to the underlying model.

        Returns:
            Dict with 'predicted_token', 'asr_predicted_token', and 'cache';
            'function_predicted_token' when the model exposes a function head;
            and the raw logits when ``return_logits`` is set.

        Raises:
            TypeError: If the model does not return a dict.
            KeyError: If the model output lacks 'text_logits'.
        """
        outputs = self.model.stt_model(input_embeds, cache=cache, cache_position=cache_position, **kwargs)

        if not isinstance(outputs, dict):
            raise TypeError(f"Model returned {type(outputs)}, expected dict")
        if 'text_logits' not in outputs:
            raise KeyError("Model output must contain 'text_logits' key")

        last_text_logits = outputs["text_logits"][:, -1]  # [batch, vocab_size]

        if generated_tokens is None:
            generated_tokens = torch.empty(
                last_text_logits.shape[0], 0, device=last_text_logits.device, dtype=torch.long
            )

        result = {
            "predicted_token": self._sample_text_token(
                logits=last_text_logits,
                generated_tokens=generated_tokens,
                current_step=current_step,
                sampling_params=sampling_params,
            ),
            # The ASR head is always decoded greedily (no sampling).
            "asr_predicted_token": outputs["asr_logits"][:, -1].argmax(dim=-1),
            "cache": outputs.get("cache", None),
        }
        if return_logits:
            result["text_logits"] = outputs["text_logits"]
            result["asr_logits"] = outputs.get("asr_logits")
            if "function_logits" in outputs:
                result["function_logits"] = outputs["function_logits"]
        if "function_logits" in outputs:
            result["function_predicted_token"] = outputs["function_logits"][:, -1].argmax(dim=-1)
        return result

    @staticmethod
    def _extract_special_token_ids_from_nemo(model) -> set[int]:
        """
        Pull bos/eos/pad token ids from the NeMo model's tokenizer.

        NeMo tokenizers expose token *strings* (bos_token, eos_token,
        pad_token) rather than ids, so each string is resolved through the
        tokenizer's ``token_to_id`` method.

        Args:
            model: The DuplexS2SExternalSpeechDecoderModel instance.

        Returns:
            The set of resolved ids; empty when extraction is not possible.
        """
        try:
            tokenizer = model.stt_model.tokenizer
        except AttributeError:
            logging.debug("Cannot extract special token IDs: model has no stt_model.tokenizer")
            return set()

        ids: set[int] = set()
        for attr in ('bos_token', 'eos_token', 'pad_token'):
            token = getattr(tokenizer, attr, None)
            if token is None or not hasattr(tokenizer, 'token_to_id'):
                continue
            token_id = tokenizer.token_to_id(token)
            if isinstance(token_id, int):
                ids.add(token_id)
        return ids

    def to(self, device_or_dtype: torch.device | torch.dtype) -> 'PyTorchLLM':
        """Move the wrapped model to a device or convert its dtype."""
        self.model = self.model.to(device_or_dtype)
        return self

    def eval(self) -> 'PyTorchLLM':
        """Switch the wrapped model to eval mode."""
        self.model.eval()
        return self

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model (CPU when it has no parameters)."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def prefill_prompt(self, embeddings, cache=None, cache_position=None, **kwargs):
        """
        Warm up the KV cache by running prompt embeddings through the LLM.

        Args:
            embeddings: Prompt embeddings [batch, seq_len, hidden_dim].
            cache: KV cache object updated in place.
            cache_position: Position tensor for the prompt tokens.

        Returns:
            Dict with the updated 'cache'.

        Raises:
            TypeError: If the model does not return a dict.
        """
        outputs = self.model.stt_model(embeddings, cache=cache, cache_position=cache_position, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError(f"Model returned {type(outputs)}, expected dict")
        return {"cache": outputs.get("cache", cache)}

    def __getattr__(self, name: str):
        """
        Fall through to the wrapped model for unknown attributes.

        This gives transparent access to model attributes such as
        perception or tokenizer; the wrapper's own storage names are
        guarded to avoid infinite recursion.
        """
        if name in ('model', '__dict__', '__class__'):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        return getattr(self.model, name)
diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py new file mode 100644 index 000000000000..80a967857cd6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py @@ -0,0 +1,303 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base class for vLLM-backed inference backends. + +Provides shared vLLM engine lifecycle management (init, abort, restart, +shutdown) used by both the LLM backend (``VLLMLLM`` -- DuplexSTT text +token prediction) and the TTS backend (``VLLMEarTTS`` -- EarTTS audio +codec generation). +""" + +from abc import abstractmethod +from typing import Any +import os +import torch + +from nemo.utils import logging +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface + + +class VLLMModelBase(ModelInterface): + """ + Base class for vLLM-backed model interfaces. + + Wraps a CustomInputAsyncVLLMEngine to provide streaming inference while + conforming to the ModelInterface contract. Supports two usage modes: + + 1. **Blocking component** (default): Call the model synchronously via + ``__call__``. The async engine runs on an internal event loop. + This is the mode used by the S2S streaming inference pipeline + (``StreamingS2SPipeline`` / ``NemotronVoicechatInferenceWrapper``). + + 2. 
    **Async standalone server**: Use ``_async_inference()`` directly
    from your own async event loop for concurrent multi-stream serving
    (e.g., a WebSocket server handling multiple streams concurrently).

    Subclasses must implement:
    - ``_convert_ckpt(save_path)``: checkpoint conversion to vLLM format
    - ``__call__(...)``: synchronous inference entry point
    - ``_process_inputs_to_outputs(...)``: async core inference logic

    Requires vLLM from https://github.com/vklimkov-nvidia/vllm (branch vklimkov/voicechat).
    """

    def __init__(
        self,
        model_path: str,
        max_model_len: int = 1024,
        max_num_batched_tokens: int = 768,
        gpu_memory_utilization: float = 0.8,
        trust_remote_code: bool = True,
        dtype: str = "bfloat16",
        engine_path: str | None = None,
        pretrained_llm: str | None = None,
        special_token_ids: set[int] | None = None,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        model_type: str = "llm",
        **sampling_kwargs
    ):
        """
        Initialize vLLM model with engine creation, event loop setup, and warm-up.

        Args:
            model_path: Path to the vLLM-compatible model checkpoint
            max_model_len: Maximum sequence length
            max_num_batched_tokens: Maximum tokens per vLLM forward pass.
                Controls prefill chunk size and max concurrent decode streams.
            gpu_memory_utilization: GPU memory utilization ratio (0.0-1.0)
            trust_remote_code: Whether to trust remote code in model
            dtype: Data type for embeddings (e.g., "bfloat16", "float16")
            engine_path: Optional path to pre-converted vLLM model
            pretrained_llm: Optional path to pretrained LLM for conversion
            special_token_ids: Set of special token IDs (pad, eos, bos) that bypass
                sampling and always use greedy decoding.
            top_p: Top-p (nucleus) sampling threshold. Default: 1.0 (disabled).
            repetition_penalty: Penalty for repeated tokens. Default: 1.0 (disabled).
            temperature: Sampling temperature. Default: 1.0 (no scaling).
            model_type: Type of model for vLLM engine ("llm", "eartts", etc.)
            **sampling_kwargs: Additional vLLM sampling parameters.

        Note:
            vLLM internally runs greedy decoding (temperature=0, ignore_eos=True).
            Text sampling (top_p, repetition_penalty, temperature) is applied
            post-hoc by ``ModelInterface._sample_text_token`` on the logits
            returned by vLLM, not by vLLM's own sampler.
        """
        # Sampling knobs are owned by the base ModelInterface; vLLM itself is
        # forced to greedy decoding below.
        super().__init__(
            special_token_ids=special_token_ids,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        # Imported lazily so importing this module does not require asyncio/vLLM
        # machinery to be set up at module-import time.
        import asyncio
        from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import create_engine

        self.model_path = model_path
        self.pretrained_llm = pretrained_llm
        self._dtype = dtype
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Force greedy decoding in vLLM by setting temperature=0 if not specified
        # (post-hoc sampling happens in _sample_text_token, see class docstring).
        if 'temperature' not in sampling_kwargs:
            sampling_kwargs['temperature'] = 0.0

        # Without an explicit engine_path, cache the converted checkpoint under
        # /tmp keyed by the source checkpoint directory name and model_type, and
        # only run the (expensive) conversion when that cache entry is missing.
        if engine_path is None:
            dir_name = os.path.basename(os.path.normpath(model_path))
            engine_path = "/tmp/" + dir_name + f"_vllm_converted_{model_type}"
            if os.path.exists(engine_path):
                logging.info(f"Found existing vLLM converted model at {engine_path}")
            else:
                self._convert_ckpt(save_path=engine_path)

        self.engine = create_engine(
            engine_type=model_type,
            model_path=engine_path,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            **sampling_kwargs
        )
        # Monotonic counter used only by _generate_request_id.
        self._request_counter = 0

        # Reuse the caller's event loop when one exists; otherwise create and
        # install a fresh one. All sync entry points bridge into async engine
        # calls via self._loop.run_until_complete.
        # NOTE(review): asyncio.get_event_loop() is deprecated outside a running
        # loop on Python 3.10+ (and raises DeprecationWarning/RuntimeError
        # depending on version) — confirm the targeted Python versions.
        try:
            self._loop = asyncio.get_event_loop()
        except RuntimeError:
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

        logging.info("Initializing vLLM engine (this may take a moment)...")
        self._loop.run_until_complete(self.engine.initialize())

        # Auto-derive special token IDs from the engine's tokenizer only when
        # the caller did not supply any explicitly.
        if self.engine.engine.tokenizer is not None and not self.special_token_ids:
            self.special_token_ids = self._get_special_token_ids_from_vllm_tokenizer(self.engine.engine.tokenizer)

        logging.debug(f"Special token IDs: {self.special_token_ids}")
        logging.info("vLLM engine ready!")

    @abstractmethod
    def _convert_ckpt(self, save_path: str):
        """Convert existing checkpoint to vLLM format and save."""
        pass

    @staticmethod
    def _get_special_token_ids_from_vllm_tokenizer(tokenizer) -> set[int]:
        """
        Extract special token IDs from a vLLM tokenizer.
        Looks for: '' (bos), '' (eos), '' (pad).

        Args:
            tokenizer: A vLLM CachedTokenizer instance.

        Returns:
            Set of special token IDs.
        """
        special_ids = set()
        # NOTE(review): these three token strings appear garbled — all three are
        # the empty string, so every iteration looks up the same token. The
        # docstring says they should be bos/eos/pad markers (likely tag-like
        # strings that were stripped somewhere); confirm against the original
        # checkpoint's tokenizer special tokens.
        for token in ('', '', ''):
            try:
                tid = tokenizer.convert_tokens_to_ids(token)
                if isinstance(tid, int):
                    special_ids.add(tid)
            except (KeyError, AttributeError):
                logging.debug(f"Token '{token}' not found in vLLM tokenizer, skipping")
        return special_ids

    def _generate_request_id(self) -> str:
        """Generate a unique request ID."""
        # Not thread-safe; assumes requests are created from a single thread.
        self._request_counter += 1
        return f"vllm_request_{self._request_counter}"

    async def _async_inference(
        self,
        inputs: torch.Tensor | list[torch.Tensor] | dict,
        request_id: str,
        **kwargs
    ) -> dict[str, Any]:
        """
        Async inference using the streaming engine.

        Checks request status (starting or restarting as needed) and
        delegates to the subclass ``_process_inputs_to_outputs``.

        Args:
            inputs: Model inputs (tensor for LLM, dict for EarTTS)
            request_id: Unique request identifier
            **kwargs: Passed through to ``_process_inputs_to_outputs``

        Returns:
            Dictionary with model-specific outputs
        """
        from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import StreamStatus

        # Unknown request id -> brand new generation stream.
        if request_id not in self.engine.requests:
            await self.engine.start_generation(request_id=request_id)
        else:
            request_state = self.engine.requests[request_id]
            # A finished/aborted stream cannot accept more input; tear it down
            # (best-effort) and start a fresh stream under the same id.
            if request_state.status in (StreamStatus.FINISHED, StreamStatus.ABORTED):
                logging.warning(
                    f"Request {request_id} was {request_state.status.value}. "
                    f"Generated {len(request_state.generated_tokens)} tokens before stopping. "
                    "Cleaning up and restarting..."
                )
                try:
                    await self.engine.abort_generation(request_id)
                except Exception:
                    # Abort is best-effort here; the restart below is what matters.
                    pass
                await self.engine.start_generation(request_id=request_id)

        return await self._process_inputs_to_outputs(inputs, request_id, **kwargs)

    @abstractmethod
    async def _process_inputs_to_outputs(self, inputs, request_id: str, **kwargs) -> dict[str, Any]:
        """Process model inputs and return outputs. Subclasses must implement."""
        pass

    def to(self, device_or_dtype: torch.device | torch.dtype) -> 'VLLMModelBase':
        """
        Move model to specified device or convert to specified dtype.

        Note: vLLM manages device placement internally, this is for compatibility.
        """
        # Only the bookkeeping device attribute is updated; dtype arguments are
        # ignored because the engine's dtype is fixed at construction time.
        if isinstance(device_or_dtype, torch.device):
            self._device = device_or_dtype
        return self

    def eval(self) -> 'VLLMModelBase':
        """Set model to evaluation mode (vLLM is always in eval mode)."""
        return self

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self._device

    def abort_request(self, request_id: str) -> bool:
        """
        Abort a specific generation request.

        Args:
            request_id: Request identifier to abort

        Returns:
            bool: True if abort was successful
        """
        # Synchronous bridge into the async engine API.
        return self._loop.run_until_complete(
            self.engine.abort_generation(request_id)
        )

    def restart_request(self, request_id: str) -> bool:
        """
        Restart a finished or aborted generation request.

        Args:
            request_id: Request identifier to restart

        Returns:
            bool: True if restart was successful
        """
        # Abort first so start_generation sees a clean slate for this id.
        if request_id in self.engine.requests:
            self.abort_request(request_id)

        return self._loop.run_until_complete(
            self.engine.start_generation(request_id=request_id)
        )

    def get_request_status(self, request_id: str | None = None) -> dict[str, Any]:
        """
        Get status of a specific request or all requests.

        Args:
            request_id: Optional request ID. If None, returns all requests.

        Returns:
            Status dictionary
        """
        return self.engine.get_status(request_id)

    def shutdown(self):
        """Shutdown the vLLM engine and cleanup resources."""
        self._loop.run_until_complete(self.engine.shutdown())

    def __del__(self):
        """Cleanup on deletion."""
        # Best-effort: during interpreter teardown the loop/engine may already
        # be gone, so all failures are deliberately swallowed.
        try:
            self.shutdown()
        except Exception:
            pass
# See the License for the specific language governing permissions and
# limitations under the License.

"""
vLLM backend for the TTS (EarTTS) component of NemotronVoiceChat.

EarTTS generates audio codec codes from text tokens: given the current
subword ID, mask, previous audio codes, and speaker latent, it predicts
the next frame of RVQ acoustic tokens. This module wraps it in a
CustomInputAsyncVLLMEngine for accelerated inference.

Used as ``model_eartts_interface`` in the inference wrapper.
"""

from typing import Any
import os
import torch

from nemo.utils import logging
from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.base import VLLMModelBase
from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.eartts import TTSGenerationResult


class VLLMEarTTS(VLLMModelBase):
    """
    vLLM backend for the TTS (EarTTS) component.

    Accepts dictionary inputs with codes, subword IDs, masks, and speaker
    latents, and returns ``TTSGenerationResult`` with generated acoustic tokens.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Hidden size of the speaker latent; learned lazily from the first
        # prefill that carries "audio_prompt_lantent", or read from the
        # converted checkpoint's config (see _process_inputs_to_outputs).
        self._speaker_latent_dim = None
        logging.info("VLLMEarTTS initialized with EARTTS-specific settings.")

    def _convert_ckpt(self, save_path: str):
        """Convert EARTTS checkpoint to vLLM format.

        Expects ``self.model_path`` to contain ``config.json`` and
        ``model.safetensors``.
        """
        from nemo.collections.speechlm2.inference.vllm.scripts.convert_eartts_checkpoint import convert
        ckpt_dir = os.path.normpath(self.model_path)
        config_file = os.path.join(ckpt_dir, "config.json")
        model_ckpt = os.path.join(ckpt_dir, "model.safetensors")
        convert(save_path, config_file, model_ckpt)

    def __call__(
        self,
        inputs: dict[str, torch.Tensor] | None = None,
        request_id: str | None = None,
        prompt_token_ids: list | None = None,
        **kwargs
    ) -> TTSGenerationResult:
        """
        Perform TTS inference using vLLM streaming engine.

        Supports two calling conventions:
        1. model(inputs_dict, request_id="id") - pass dict as first positional arg
        2. model(**inputs_dict) - unpack dict as keyword arguments

        Args:
            inputs: Optional dict of model inputs (if None, uses **kwargs)
            request_id: Optional request identifier
            prompt_token_ids: Optional list of prompt token IDs for prefill
            **kwargs: Model inputs as keyword arguments (used if inputs is None)

        Returns:
            TTSGenerationResult containing generated acoustic tokens and cache
        """
        if inputs is not None:
            input_dict = inputs
        else:
            # Convention 2: the whole kwargs dict is the model input; pull a
            # possibly-embedded request_id out of it first.
            if request_id is None:
                request_id = kwargs.pop('request_id', None)
            input_dict = kwargs

        # Fall back to a fixed id: single-stream usage shares one request.
        if request_id is None:
            request_id = 'tts_request_id_1'

        # Synchronous bridge into the async engine (loop owned by the base class).
        result = self._loop.run_until_complete(
            self._async_inference(input_dict, request_id, prompt_token_ids=prompt_token_ids)
        )

        return result

    async def _process_inputs_to_outputs(
        self,
        inputs: dict[str, torch.Tensor],
        request_id: str,
        prompt_token_ids: list | None = None,
    ) -> dict[str, Any]:
        """
        Process tensor inputs to generate acoustic tokens via vLLM engine.

        Args:
            inputs: Dictionary with code, context_hidden_state, subword_ids,
                subword_mask, and optional non_prompt_mask / audio_prompt_lantent.
                NOTE(review): "audio_prompt_lantent" (sic) is the key spelling
                used consistently by the producer of these inputs — do not
                "fix" it here without changing the caller too.
            request_id: Request identifier.
            prompt_token_ids: Optional prompt token IDs for prefill.

        Returns:
            TTSGenerationResult with generated acoustic tokens.
        """

        assert inputs["context_hidden_state"] is None, "EARTTS vllm model does not support context_hidden_state input"

        # assumes inputs["code"] is (1, T, 31): batch dim squeezed away here — TODO confirm
        codes = inputs["code"].squeeze(0)  # T x 31
        if codes.shape[0] > 1:
            # In prefill stage, shift acoustic tokens for vLLM to replicate
            # the NeMo logic for teacher-forced input construction.
            # (Drop the last frame, prepend a zero frame.)
            codes = torch.nn.functional.pad(codes[:-1], [0, 0, 1, 0])
        input_tensors = [
            codes,
            inputs["subword_ids"].squeeze(0),
            inputs["subword_mask"].squeeze(0),
        ]

        if "non_prompt_mask" in inputs:
            # Apply edge detection to match native model's BOS placement logic:
            # BOS should only be applied at the FIRST position where non_prompt_mask is True
            non_prompt_mask = inputs["non_prompt_mask"].squeeze(0)  # T
            # Rising-edge detector: True only where mask turns on.
            padded_prev = torch.nn.functional.pad(non_prompt_mask[:-1], [1, 0], value=False)
            bos_mask = (non_prompt_mask & (~padded_prev)).to(dtype=getattr(torch, self._dtype))
            input_tensors.append(bos_mask)

        else:
            current_subword_id = input_tensors[1]
            # Use a tiny epsilon instead of exact 0 so the vLLM model's
            # (bos_mask == 0) check is False during decoding. This prevents
            # use_audio_prompt_frozen_projection from incorrectly applying the
            # speaker-prompt projection to every decoding step. The epsilon is
            # small enough that bos_mask * bos_emb remains negligible.
            bos_mask = torch.full_like(current_subword_id, 1e-20, dtype=getattr(torch, self._dtype))
            input_tensors.append(bos_mask)

        # Pass speaker_latent: the pre-extracted speaker embedding.
        # During prefill with speaker_name: audio_prompt_lantent is [1, T, hidden_size]
        # During decode or speaker_reference: pass zeros so the model falls back
        # to computing the latent from acoustic tokens.
        if "audio_prompt_lantent" in inputs and inputs["audio_prompt_lantent"] is not None:
            speaker_latent = inputs["audio_prompt_lantent"].squeeze(0)  # T x hidden_size
            # Remember the latent width so later latent-less steps can zero-fill.
            self._speaker_latent_dim = speaker_latent.shape[-1]
            input_tensors.append(speaker_latent.to(dtype=getattr(torch, self._dtype)))
        else:
            if self._speaker_latent_dim is None:
                # No latent seen yet in this process: recover hidden_size from
                # the converted checkpoint's config under the /tmp cache path
                # (must match the path scheme used by VLLMModelBase.__init__).
                import json as _json
                dir_name = os.path.basename(os.path.normpath(self.model_path))
                converted_config_path = os.path.join("/tmp", dir_name + "_vllm_converted_eartts", "config.json")
                if os.path.exists(converted_config_path):
                    with open(converted_config_path) as _f:
                        self._speaker_latent_dim = _json.load(_f)["hidden_size"]
                else:
                    raise RuntimeError(
                        f"Cannot determine speaker_latent_dim: converted config not found at {converted_config_path}. "
                        "Run a prefill with audio_prompt_lantent first, or ensure the converted checkpoint exists."
                    )
            num_tokens = codes.shape[0]
            speaker_latent = torch.zeros(num_tokens, self._speaker_latent_dim, dtype=getattr(torch, self._dtype))
            input_tensors.append(speaker_latent)

        result = await self.engine.generate_next_token(input_tensors, prompt_token_ids=prompt_token_ids, request_id=request_id)
        acoustic_tokens = result.custom_outputs["acoustic_tokens"]  # T x 31
        # Only the newest frame is emitted per step; earlier frames are history.
        step_acoustic_tokens = acoustic_tokens[-1:]  # 1 x 31
        # NOTE(review): .cuda() hard-codes GPU placement regardless of
        # self._device — confirm CPU-only runs are out of scope here.
        return TTSGenerationResult(
            codes=step_acoustic_tokens.unsqueeze(0).cuda(),  # Add batch dim back: 1 x 1 x 31
            past_key_values=None  # vLLM manages cache internally
        )

    def prefill_prompt(self, init_inputs, prompt_token_ids, request_id: str, **kwargs):
        """Prefill vLLM EarTTS engine with speaker embedding context.

        Args:
            init_inputs: Dict of initial TTS inputs (codes, subword_ids, etc.).
            prompt_token_ids: List of prompt token IDs for the speaker embedding.
            request_id: Unique request identifier.

        Returns:
            TTSGenerationResult with codes from the prefill step, or None.
        """
        import copy
        # Deep-copy so the in-place squeezes/pads above never mutate the
        # caller's init_inputs.
        inputs_copy = copy.deepcopy(init_inputs)
        return self(inputs_copy, request_id=request_id, prompt_token_ids=prompt_token_ids)
+ """ + + def _convert_ckpt(self, save_path: str): + """Convert existing DuplexSTT checkpoint to vLLM-compatible HF format.""" + from nemo.collections.speechlm2.inference.vllm.scripts.convert_nemotronllm_checkpoint import convert_nemo_to_hf_format + + convert_nemo_to_hf_format( + checkpoint_path=self.model_path, + output_dir=save_path, + pretrained_llm=self.pretrained_llm, + dtype=self._dtype + ) + logging.info(f"Converted model saved to {save_path}") + + def __call__( + self, + input_embeds: torch.Tensor, + request_id: str | None = "request_id_1", + **kwargs + ) -> dict[str, Any]: + """ + Perform inference using vLLM streaming engine. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + request_id: Unique request identifier for this generation + **kwargs: Additional arguments (decode_steps, generated_tokens, etc.) + + Returns: + Dictionary containing predicted_token, asr_predicted_token, cache, + is_finished, and request_id. + """ + result = self._loop.run_until_complete( + self._async_inference(input_embeds, request_id, **kwargs) + ) + return result + + async def _process_inputs_to_outputs( + self, + input_embeds: torch.Tensor, + request_id: str, + decode_steps: int = 1, + prompt_token_ids: list | None = None, + generated_tokens: torch.Tensor | None = None, + current_step: int = 0, + sampling_params: dict[str, float] | None = None, + ) -> dict[str, Any]: + """ + Process embeddings sequentially to generate text and ASR tokens. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + request_id: Request identifier + decode_steps: Number of decoding steps to perform; 0 means prefill only + prompt_token_ids: Optional list of prompt token IDs for prefill + generated_tokens: Previously generated tokens [batch, num_generated]. + Required for repetition_penalty. If None, creates empty tensor. + current_step: Current decoding step. Used for repetition penalty. + sampling_params: Optional per-request overrides for sampling. 
+ """ + + if decode_steps == 0: + input_embeds = input_embeds.flatten(0, 1) # [seq_len, hidden_dim] + result = await self.engine.generate_next_token([input_embeds], + prompt_token_ids, + request_id=request_id) + return True if result is not None else False + + text_token_ids = [] + asr_token_ids = [] + result = None + for i in range(decode_steps): + single_embed = input_embeds[:, i:i+1, :].squeeze(1) # [batch, hidden_dim] + + result = await self.engine.generate_next_token([single_embed], request_id=request_id) + if result is None: + break + + text_token_ids.append(result.token_id) + asr_token_ids.append(result.custom_outputs["asr_tokens"]) + + if result.is_finished: + break + + assert len(text_token_ids) <= decode_steps, "Generated more tokens than input embeddings" + is_finished = False + if text_token_ids: + is_finished = len(text_token_ids) < decode_steps or (result and result.is_finished) + + text_logits = result.custom_outputs["text_logits"] if result else None + + batch_size = text_logits.shape[0] + if generated_tokens is None: + gen_tokens = torch.empty(batch_size, 0, device=text_logits.device, dtype=torch.long) + else: + gen_tokens = generated_tokens + + predicted_token = self._sample_text_token( + logits=text_logits, + generated_tokens=gen_tokens, + current_step=current_step, + sampling_params=sampling_params, + ) + + ans = { + "predicted_token": predicted_token, + "asr_predicted_token": asr_token_ids[-1], + "cache": None, # vLLM manages cache internally + "is_finished": is_finished, + "request_id": request_id + } + if result and result.custom_outputs and "function_tokens" in result.custom_outputs: + ans["function_predicted_token"] = result.custom_outputs["function_tokens"] + return ans + + def prefill_prompt(self, embeddings: torch.Tensor, request_id: str, **kwargs) -> bool: + """Prefill vLLM LLM engine with prompt embeddings in a single bulk step. + + Args: + embeddings: Prompt embeddings [batch, seq_len, hidden_dim]. 
+ request_id: Unique request identifier. + + Returns: + True if prefill succeeded. + """ + return self._loop.run_until_complete( + self._async_inference(embeddings, request_id, decode_steps=0) + ) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/factory.py b/nemo/collections/speechlm2/inference/model_wrappers/factory.py new file mode 100644 index 000000000000..b67d9aa523d0 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/factory.py @@ -0,0 +1,146 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Factory for creating LLM and TTS inference backends. + +NemotronVoiceChat has two components that can each be backed by native +PyTorch or vLLM. 
This factory returns the right backend for a given +``engine_type``: + +- ``"native_llm"`` -- wraps the PyTorch model directly (LLM component) +- ``"native_eartts"`` -- wraps the PyTorch DuplexEARTTS model (TTS component) +- ``"vllm_llm"`` -- vLLM engine for the LLM (DuplexSTT) component +- ``"vllm_eartts"`` -- vLLM engine for the TTS (EarTTS) component + +Usage: + from nemo.collections.speechlm2.inference.model_wrappers.factory import create_model + + llm = create_model(engine_type="native_llm", model=voicechat_model) + tts = create_model(engine_type="native_eartts", model=voicechat_model.tts_model) +""" + +from typing import Any + +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface + + +def create_model( + engine_type: str, + model=None, + vllm_config: dict[str, Any] | None = None, + special_token_ids: set[int] | None = None, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + **kwargs +) -> ModelInterface: + """ + Factory function to create a single inference backend for one component. + + Each call creates a backend for one specific component (LLM or TTS) on + one specific runtime (native PyTorch or vLLM). The ``engine_type`` + must be one of the four ``{backend}_{component}`` combinations. + + Note: the user-facing config uses a *combined* ``engine_type`` like + ``"native"`` or ``"vllm_llm_vllm_eartts"`` which the wrapper + translates into two ``create_model`` calls (one for LLM, one for TTS). + + Args: + engine_type: One of "native_llm", "native_eartts", "vllm_llm", "vllm_eartts" + model: The PyTorch model to wrap. Required for native backends + (NemotronVoiceChat for LLM, DuplexEARTTS for TTS). Not used + by vLLM backends, which load their own engine from ``vllm_config``. + vllm_config: Configuration dict for vLLM engines (required for "vllm*") + special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. 
+ If None (default), will auto-extract from model.tokenizer for tokens: + '' (bos), '' (eos), '' (pad). + top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 + repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 + temperature: Temperature for sampling. 1.0 = no change, 0.0 = greedy. Default: 1.0 + **kwargs: Additional arguments passed to the backend constructor + + Returns: + A ModelInterface instance ready for inference. + + Example: + >>> # Native PyTorch LLM with greedy decoding + >>> llm = create_model(engine_type="native_llm", model=voicechat_model) + >>> + >>> # Native PyTorch EarTTS + >>> tts = create_model(engine_type="native_eartts", model=voicechat_model.tts_model) + >>> + >>> # vLLM LLM engine + >>> llm = create_model(engine_type="vllm_llm", vllm_config={...}) + >>> + >>> # vLLM EarTTS engine + >>> tts = create_model(engine_type="vllm_eartts", vllm_config={...}) + """ + engine_type = engine_type.lower() + + if engine_type == "native_llm": + from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.model import PyTorchLLM + + if model is None: + raise ValueError("model must be provided for native engine") + return PyTorchLLM( + model=model, + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + + elif engine_type == "native_eartts": + from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.eartts import PyTorchEarTTS + + if model is None: + raise ValueError("model (DuplexEARTTS instance) must be provided for native EarTTS engine") + return PyTorchEarTTS(tts_model=model) + + elif engine_type == "vllm_eartts": + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.eartts import VLLMEarTTS + + if vllm_config is None: + raise ValueError("vllm_config must be provided for vLLM EARTTS engine") + return VLLMEarTTS( + **vllm_config, + model_type="eartts", + 
special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + **kwargs + ) + + elif engine_type == "vllm_llm": + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.llm import VLLMLLM + + if vllm_config is None: + raise ValueError("vllm_config must be provided for vLLM engine") + return VLLMLLM( + **vllm_config, + model_type="llm", + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + **kwargs + ) + + else: + raise ValueError( + f"Unknown engine_type: {engine_type}. " + f"Supported types: 'native_llm', 'native_eartts', 'vllm_llm', 'vllm_eartts'" + ) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py b/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py deleted file mode 100644 index 1b9399745e80..000000000000 --- a/nemo/collections/speechlm2/inference/model_wrappers/model_factory.py +++ /dev/null @@ -1,1170 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Model Interface for S2S Inference - -This module provides an abstract interface for model inference engines, -allowing seamless swapping between different implementations (e.g., native PyTorch, vLLM) -without modifying the inference code. 
- -Usage Example: - from nemo.collections.speechlm2.inference.model_wrappers.model_factory import create_model - - # Create interface (automatically wraps existing model) - model_interface = create_model( - model=your_model, - engine_type="native" # or "vllm" - ) - - # Use the interface exactly as you would use self.model() - ans = model_interface(input_embeds, cache=cache) -""" - -from abc import ABC, abstractmethod -from typing import Any -import math -import os -import torch -from dataclasses import dataclass, fields - -from nemo.utils import logging - -class ModelInterface(ABC): - """ - Base interface for model inference engines with shared sampling utilities. - - This interface defines the contract that all model implementations must follow, - ensuring consistent behavior across different engine types. It also provides - concrete implementations of sampling methods (top-p, repetition penalty) that - can be shared across all engines. - """ - - def __init__( - self, - special_token_ids: set[int] | None = None, - top_p: float = 1.0, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - ): - """ - Initialize base interface with sampling parameters. - - Args: - special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. - These tokens will use greedy decoding and won't be penalized. - top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 - repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 - temperature: Temperature for sampling. 1.0 = no change, <1.0 = sharper, >1.0 = flatter. - 0.0 = greedy (argmax). 
Default: 1.0 - """ - if not math.isfinite(temperature): - raise ValueError(f"temperature must be finite, got {temperature}") - if temperature < 0.0: - raise ValueError(f"temperature must be >= 0.0, got {temperature}") - - self.special_token_ids = special_token_ids if special_token_ids is not None else set() - self.top_p = top_p - self.repetition_penalty = repetition_penalty - self.temperature = temperature - - # Pre-built tensor for special-token filtering in repetition penalty. - # Lazily moved to the right device on first use (see _sample_text_token). - self._special_ids_tensor: torch.Tensor | None = ( - torch.tensor(sorted(self.special_token_ids), dtype=torch.long) - if self.special_token_ids else None - ) - - def _sample_text_token( - self, - logits: torch.Tensor, - generated_tokens: torch.Tensor, - current_step: int, - sampling_params: dict[str, float] | None = None, - ) -> torch.Tensor: - """ - Sample text tokens with optional top-p sampling and repetition penalty. - Special tokens (pad, eos, bos) bypass sampling - if they have highest probability, return them directly. - - Args: - logits: Logits tensor of shape (B, V) for vocabulary V. - generated_tokens: Previously generated tokens of shape (B, T). - current_step: Current decoding step (used to slice generated_tokens). - sampling_params: Optional per-request overrides for ``top_p``, - ``temperature``, and ``repetition_penalty``. Missing keys - fall back to ``self.*`` (the pipeline-level defaults). - - Returns: - Sampled token ids of shape (B,). 
- """ - top_p = sampling_params.get("top_p", self.top_p) if sampling_params else self.top_p - temperature = sampling_params.get("temperature", self.temperature) if sampling_params else self.temperature - rep_penalty = sampling_params.get("repetition_penalty", self.repetition_penalty) if sampling_params else self.repetition_penalty - - B, V = logits.shape - device = logits.device - - # First check greedy tokens (on original logits) - greedy_tokens = logits.argmax(dim=-1) # (B,) - - # If no sampling needed (all disabled), return greedy - if top_p >= 1.0 and rep_penalty == 1.0 and (temperature == 1.0 or temperature == 0.0): - return greedy_tokens - - # temperature=0 means greedy - if temperature == 0.0: - return greedy_tokens - - # For each batch, if greedy is special token, use greedy; otherwise sample - sampled_tokens = greedy_tokens.clone() - - # Ensure cached special-token tensor is on the right device (once). - if self._special_ids_tensor is not None and self._special_ids_tensor.device != device: - self._special_ids_tensor = self._special_ids_tensor.to(device) - - for b in range(B): - # If greedy token is a special token, keep it (no sampling) - if greedy_tokens[b].item() in self.special_token_ids: - continue - - # Not a special token - apply repetition penalty and sampling - batch_logits = logits[b].clone() # (V,) - - # Apply repetition penalty (vectorized, no Python loop) - if rep_penalty != 1.0 and current_step > 0: - unique_prev = generated_tokens[b, :current_step].unique() - # Exclude special tokens from penalty - if self._special_ids_tensor is not None: - ids_t = self._special_ids_tensor - if ids_t.device != unique_prev.device: - ids_t = ids_t.to(unique_prev.device) - unique_prev = unique_prev[~torch.isin(unique_prev, ids_t)] - - if unique_prev.numel() > 0: - if unique_prev.device != batch_logits.device: - unique_prev = unique_prev.to(batch_logits.device) - prev_logits = batch_logits[unique_prev] - # Positive logits are divided, negative logits are 
multiplied - # (same as the standard repetition_penalty convention) - batch_logits[unique_prev] = torch.where( - prev_logits > 0, - prev_logits / rep_penalty, - prev_logits * rep_penalty, - ) - - # Apply temperature scaling - if temperature != 1.0: - batch_logits = batch_logits / temperature - - # Fall back to greedy if logits are non-finite before top-p - # (top-p intentionally introduces -inf, so check must happen first) - if not torch.isfinite(batch_logits).all(): - logging.warning( - f"_sample_text_token: logits contain NaN or inf at step {current_step}, batch {b}: " - f"nan={batch_logits.isnan().sum().item()}, " - f"inf={batch_logits.isinf().sum().item()}, " - f"min={batch_logits[~batch_logits.isnan()].min().item() if not batch_logits.isnan().all() else 'all_nan'}, " - f"max={batch_logits[~batch_logits.isnan()].max().item() if not batch_logits.isnan().all() else 'all_nan'}" - ) - sampled_tokens[b] = greedy_tokens[b] - continue - - # Apply top-p (nucleus) sampling - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True) - sorted_probs = torch.softmax(sorted_logits, dim=-1) - cumulative_probs = torch.cumsum(sorted_probs, dim=-1) - - # Remove tokens with cumulative prob > top_p, keeping at least one - sorted_indices_to_remove = cumulative_probs > top_p - # Shift to keep the first token that exceeds threshold - sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() - sorted_indices_to_remove[0] = False - - # Set to -inf - indices_to_remove = sorted_indices[sorted_indices_to_remove] - batch_logits[indices_to_remove] = float('-inf') - - probs = torch.softmax(batch_logits, dim=-1) - sampled_tokens[b] = torch.multinomial(probs, num_samples=1).item() - - return sampled_tokens - - @abstractmethod - def __call__( - self, - input_embeds: torch.Tensor, - cache: Any | None = None, - **kwargs - ) -> dict[str, Any]: - """ - Perform model inference. 
class VllmLLMModel(ModelInterface):
    """
    vLLM-based model interface using LLMStreamingEngine.

    Wraps the LLMStreamingEngine to provide async streaming inference while
    conforming to the ModelInterface contract. Supports multiple concurrent
    requests sharing a single engine instance::

        model = VllmLLMModel(...)

        async def process_stream(embeds, stream_id):
            # Use the async engine directly
            return await model._async_inference(embeds, f"stream_{stream_id}", seq_len)

        # Run multiple streams concurrently in the same event loop
        async def main():
            results = await asyncio.gather(
                process_stream(embeds1, 1),
                process_stream(embeds2, 2),
                process_stream(embeds3, 3),
            )

        asyncio.run(main())
    """

    def __init__(
        self,
        model_path: str,
        max_model_len: int = 1024,
        gpu_memory_utilization: float = 0.8,
        trust_remote_code: bool = True,
        dtype: str = "bfloat16",
        engine_path: str | None = None,
        pretrained_llm: str | None = None,
        special_token_ids: set[int] | None = None,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        model_type: str = "llm",
        **sampling_kwargs,
    ):
        """
        Initialize vLLM model interface with LLMStreamingEngine.

        Args:
            model_path: Path to the vLLM-compatible model checkpoint
            max_model_len: Maximum sequence length
            gpu_memory_utilization: GPU memory utilization ratio (0.0-1.0)
            trust_remote_code: Whether to trust remote code in model
            dtype: Data type for embeddings (e.g., "bfloat16", "float16")
            engine_path: Optional path to pre-converted vLLM model
            pretrained_llm: Optional path to pretrained LLM for conversion
            special_token_ids: Set of special token IDs (for potential post-processing)
            top_p: Top-p sampling (currently vLLM uses greedy decoding)
            repetition_penalty: Repetition penalty (currently not used by vLLM engine)
            temperature: Temperature for sampling. Applied in _sample_text_token,
                not in the vLLM engine.
            model_type: Type of model for vLLM engine ("llm", "eartts", etc.)
            **sampling_kwargs: Additional sampling parameters passed to the vLLM
                engine. By default, vLLM uses greedy decoding (temperature=0).
        """
        # Initialize base class with sampling parameters
        super().__init__(
            special_token_ids=special_token_ids,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        import asyncio

        self.model_path = model_path
        self.pretrained_llm = pretrained_llm
        self._dtype = dtype
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Force greedy decoding in vLLM by setting temperature=0 if not specified;
        # sampling (top-p / repetition penalty) is applied afterwards on the logits.
        if 'temperature' not in sampling_kwargs:
            sampling_kwargs['temperature'] = 0.0

        if engine_path is None:
            # Convert model to vLLM format if a converted copy is not cached yet.
            dir_name = os.path.basename(os.path.normpath(model_path))
            engine_path = "/tmp/" + dir_name + f"_vllm_converted_{model_type}"
            if os.path.exists(engine_path):
                logging.info(f"Found existing vLLM converted model at {engine_path}")
            else:
                self._convert_ckpt(save_path=engine_path)

        from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import create_engine

        # Initialize the streaming engine
        self.engine = create_engine(
            engine_type=model_type,
            model_path=engine_path,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            **sampling_kwargs,
        )
        # Track request counter used to mint unique request IDs.
        self._request_counter = 0

        # Get or create event loop. All public sync entry points drive the
        # async engine through this loop via run_until_complete.
        try:
            self._loop = asyncio.get_event_loop()
        except RuntimeError:
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

        # Initialize engine immediately to avoid first-call latency
        logging.info("Initializing vLLM engine (this may take a moment)...")
        self._loop.run_until_complete(self.engine.initialize())

        if self.engine.engine.tokenizer is not None and not self.special_token_ids:
            self.special_token_ids = self._get_special_token_ids_from_vllm_tokenizer(self.engine.engine.tokenizer)

        logging.debug(f"Special token IDs: {self.special_token_ids}")
        logging.info("vLLM engine ready!")

    @staticmethod
    def _get_special_token_ids_from_vllm_tokenizer(tokenizer) -> set[int]:
        """
        Extract special token IDs from a vLLM tokenizer.

        Args:
            tokenizer: A vLLM CachedTokenizer instance.

        Returns:
            Set of special token IDs.
        """
        special_ids = set()
        # NOTE(review): these literals appear to be mangled special-token strings
        # (likely bos/eos/pad markers such as "<s>"/"</s>"/"<pad>") — empty
        # strings will not resolve to valid IDs. TODO confirm against the
        # tokenizer's actual special-token vocabulary.
        for token in ('', '', ''):
            try:
                tid = tokenizer.convert_tokens_to_ids(token)
                if isinstance(tid, int):
                    special_ids.add(tid)
            except Exception:
                pass
        return special_ids

    def _convert_ckpt(self, save_path: str):
        """Convert existing checkpoint to vLLM format and save."""
        from nemo.collections.speechlm2.inference.vllm.scripts.convert_nemotronllm_checkpoint import (
            convert_nemo_to_hf_format,
        )

        convert_nemo_to_hf_format(
            checkpoint_path=self.model_path,
            output_dir=save_path,
            pretrained_llm=self.pretrained_llm,
            dtype=self._dtype,
        )
        logging.info(f"Converted model saved to {save_path}")

    def _generate_request_id(self) -> str:
        """Generate a unique request ID."""
        self._request_counter += 1
        return f"vllm_request_{self._request_counter}"

    def __call__(
        self,
        input_embeds: torch.Tensor,
        request_id: str | None = "request_id_1",
        **kwargs,
    ) -> dict[str, Any]:
        """
        Perform inference using the vLLM streaming engine (synchronous wrapper).

        Args:
            input_embeds: Input embeddings [batch, seq_len, hidden_dim]
            request_id: Unique request identifier for this generation
            **kwargs: Forwarded to _process_inputs_to_outputs (decode_steps,
                prompt_token_ids, generated_tokens, current_step, sampling_params)

        Returns:
            Dictionary containing:
            - predicted_token: Last generated text token
            - asr_predicted_token: Last generated ASR token
            - cache: None (vLLM manages cache internally)
            - is_finished: Whether generation is complete
            - request_id: The request identifier
        """
        # Run async inference on the shared event loop
        result = self._loop.run_until_complete(
            self._async_inference(input_embeds, request_id, **kwargs)
        )
        return result

    async def _async_inference(
        self,
        inputs: torch.Tensor | list[torch.Tensor],
        request_id: str,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Async inference using the streaming engine.

        Ensures the request exists in the engine (starting or restarting it as
        needed), then runs the decode/prefill steps.

        Args:
            inputs: Input embeddings [batch, seq_len, hidden_dim]
            request_id: Unique request identifier
            **kwargs: Forwarded to _process_inputs_to_outputs

        Returns:
            Dictionary with the generated tokens and status (see __call__).
        """
        from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import StreamStatus

        if request_id not in self.engine.requests:
            await self.engine.start_generation(request_id=request_id)
        else:
            # Check if request is finished/aborted and needs a restart
            request_state = self.engine.requests[request_id]
            if request_state.status in (StreamStatus.FINISHED, StreamStatus.ABORTED):
                logging.warning(
                    f"Request {request_id} was {request_state.status.value}. "
                    f"Generated {len(request_state.generated_tokens)} tokens before stopping. "
                    "Cleaning up and restarting..."
                )
                # Try to abort cleanly first (best-effort)
                try:
                    await self.engine.abort_generation(request_id)
                except Exception:
                    pass
                # Start fresh
                await self.engine.start_generation(request_id=request_id)

        # Process embeddings to generate tokens
        return await self._process_inputs_to_outputs(inputs, request_id, **kwargs)

    async def _process_inputs_to_outputs(
        self,
        input_embeds: torch.Tensor,
        request_id: str,
        decode_steps: int = 1,
        prompt_token_ids: list | None = None,
        generated_tokens: torch.Tensor | None = None,
        current_step: int = 0,
        sampling_params: dict[str, float] | None = None,
    ) -> dict[str, Any]:
        """
        Process embeddings sequentially to generate text and ASR tokens.

        Args:
            input_embeds: Input embeddings [batch, seq_len, hidden_dim]
            request_id: Request identifier
            decode_steps: Number of decoding steps; decode_steps == 0 means prefill only
            prompt_token_ids: Optional list of prompt token IDs for prefill
            generated_tokens: Previously generated tokens [batch, num_generated].
                Required for repetition_penalty. If None, creates an empty tensor.
            current_step: Current decoding step, used for repetition penalty.
            sampling_params: Optional per-request overrides for sampling.

        Returns:
            For prefill (decode_steps == 0): bool indicating success.
            Otherwise a dict with predicted_token, asr_predicted_token, cache,
            is_finished, request_id (and function_predicted_token if available).
        """
        if decode_steps == 0:
            # Prefill only, no token generation
            input_embeds = input_embeds.flatten(0, 1)  # [seq_len, hidden_dim]
            result = await self.engine.generate_next_token(
                [input_embeds], prompt_token_ids, request_id=request_id
            )
            return result is not None

        # Process each embedding in sequence
        text_token_ids = []
        asr_token_ids = []
        result = None
        for i in range(decode_steps):
            # Extract single embedding [batch, hidden_dim]
            single_embed = input_embeds[:, i:i + 1, :].squeeze(1)

            # Generate next token
            result = await self.engine.generate_next_token([single_embed], request_id=request_id)
            if result is None:
                # No token generated (finished or error)
                break

            text_token_ids.append(result.token_id)
            asr_token_ids.append(result.custom_outputs["asr_tokens"])  # custom_outputs carries ASR tokens

            if result.is_finished:
                break

        assert len(text_token_ids) <= decode_steps, "Generated more tokens than input embeddings"

        # FIX: the original indexed text_token_ids[-1] unconditionally, raising
        # IndexError when the very first generate_next_token call returned None
        # (finished/aborted request). Return a terminal result instead.
        if not text_token_ids:
            return {
                "predicted_token": None,
                "asr_predicted_token": None,
                "cache": None,
                "is_finished": True,
                "request_id": request_id,
            }

        is_finished = len(text_token_ids) < decode_steps or (result and result.is_finished)
        text_logits = result.custom_outputs["text_logits"] if result else None

        predicted_token = text_token_ids[-1]
        eff_top_p = sampling_params.get("top_p", self.top_p) if sampling_params else self.top_p
        eff_temp = sampling_params.get("temperature", self.temperature) if sampling_params else self.temperature
        eff_rep = (
            sampling_params.get("repetition_penalty", self.repetition_penalty)
            if sampling_params
            else self.repetition_penalty
        )
        if eff_top_p < 1.0 or eff_rep != 1.0 or (eff_temp != 1.0 and eff_temp != 0.0):
            # Use provided generated_tokens or create an empty tensor
            batch_size = text_logits.shape[0]
            if generated_tokens is None:
                gen_tokens = torch.empty(batch_size, 0, device=text_logits.device, dtype=torch.long)
            else:
                gen_tokens = generated_tokens

            # Apply sampling with top-p and repetition penalty
            predicted_token = self._sample_text_token(
                logits=text_logits,
                generated_tokens=gen_tokens,
                current_step=current_step,
                sampling_params=sampling_params,
            )

        ans = {
            "predicted_token": predicted_token,
            "asr_predicted_token": asr_token_ids[-1],
            "cache": None,  # vLLM manages cache internally
            "is_finished": is_finished,
            "request_id": request_id,
        }
        if result and result.custom_outputs and "function_tokens" in result.custom_outputs:
            ans["function_predicted_token"] = result.custom_outputs["function_tokens"]
        return ans

    def to(self, device_or_dtype: torch.device | torch.dtype) -> 'VllmLLMModel':
        """
        Move model to specified device or convert to specified dtype.

        Note: vLLM manages device placement internally; this exists for
        interface compatibility only.
        """
        if isinstance(device_or_dtype, torch.device):
            self._device = device_or_dtype
        elif isinstance(device_or_dtype, torch.dtype):
            # dtype conversion not directly supported, update config
            pass
        return self

    def eval(self) -> 'VllmLLMModel':
        """Set model to evaluation mode (vLLM is always in eval mode)."""
        return self

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self._device

    def abort_request(self, request_id: str) -> bool:
        """
        Abort a specific generation request.

        Args:
            request_id: Request identifier to abort

        Returns:
            bool: True if abort was successful
        """
        return self._loop.run_until_complete(self.engine.abort_generation(request_id))

    def restart_request(self, request_id: str) -> bool:
        """
        Restart a finished or aborted generation request.

        Args:
            request_id: Request identifier to restart

        Returns:
            bool: True if restart was successful
        """
        # First abort if active
        if request_id in self.engine.requests:
            self.abort_request(request_id)

        # Start new generation
        return self._loop.run_until_complete(self.engine.start_generation(request_id=request_id))

    def get_request_status(self, request_id: str | None = None) -> dict[str, Any]:
        """
        Get status of a specific request or all requests.

        Args:
            request_id: Optional request ID. If None, returns all requests.

        Returns:
            Status dictionary
        """
        return self.engine.get_status(request_id)

    def shutdown(self):
        """Shutdown the vLLM engine and cleanup resources."""
        self._loop.run_until_complete(self.engine.shutdown())

    def __del__(self):
        """Cleanup on deletion (best-effort; loop may already be closed)."""
        try:
            self.shutdown()
        except Exception:
            pass
- - Args: - request_id: Request identifier to restart - - Returns: - bool: True if restart was successful - """ - # First abort if active - if request_id in self.engine.requests: - self.abort_request(request_id) - - # Start new generation - return self._loop.run_until_complete( - self.engine.start_generation(request_id=request_id) - ) - - def get_request_status(self, request_id: str | None = None) -> dict[str, Any]: - """ - Get status of a specific request or all requests. - - Args: - request_id: Optional request ID. If None, returns all requests. - - Returns: - Status dictionary - """ - return self.engine.get_status(request_id) - - def shutdown(self): - """Shutdown the vLLM engine and cleanup resources.""" - self._loop.run_until_complete(self.engine.shutdown()) - - def __del__(self): - """Cleanup on deletion.""" - try: - self.shutdown() - except Exception: - pass - -@dataclass -class TTSGenerationResult: - codes: torch.Tensor # Generated acoustic tokens - past_key_values: Any # Updated cache (if applicable) - - def __getitem__(self, item: str | int): - """Allows for accessing attributes by key or index.""" - if isinstance(item, str): - return getattr(self, item) - else: - # Access fields in the order they are defined in the dataclass - return getattr(self, fields(self)[item].name) - - -class VllmEARTTSModel(VllmLLMModel): - """ - vLLM-based model interface specialized for EARTTS models. - - Inherits from VllmLLMModel and sets EARTTS-specific configurations. - """ - - def __init__(self, **kwargs): - """ - Initialize vLLM EARTTS model interface. 
class VllmEARTTSModel(VllmLLMModel):
    """
    vLLM-based model interface specialized for EARTTS models.

    Inherits from VllmLLMModel and sets EARTTS-specific configurations:
    checkpoint conversion uses the EARTTS converter, and inference consumes a
    dict of TTS inputs (acoustic codes, subword ids/masks, optional speaker
    latent) instead of raw embeddings, returning a TTSGenerationResult.
    """

    def __init__(self, **kwargs):
        """
        Initialize vLLM EARTTS model interface.

        Args:
            **kwargs: Arguments passed to the VllmLLMModel constructor
        """
        super().__init__(**kwargs)
        # Hidden size of the speaker latent; learned lazily from the first
        # prefill carrying 'audio_prompt_lantent', or read from the converted
        # checkpoint's config.json in _process_inputs_to_outputs.
        self._speaker_latent_dim = None
        logging.info("VllmEARTTSModel initialized with EARTTS-specific settings.")

    def _convert_ckpt(self, save_path: str):
        """Convert EARTTS checkpoint (config.json + model.safetensors) to vLLM format."""
        from nemo.collections.speechlm2.inference.vllm.scripts.convert_eartts_checkpoint import convert
        ckpt_dir = os.path.normpath(self.model_path)
        config_file = os.path.join(ckpt_dir, "config.json")
        model_ckpt = os.path.join(ckpt_dir, "model.safetensors")
        convert(save_path, config_file, model_ckpt)

    def __call__(
        self,
        inputs: dict[str, torch.Tensor] | None = None,
        request_id: str | None = None,
        prompt_token_ids: list | None = None,
        **kwargs
    ) -> TTSGenerationResult:
        """
        Perform TTS inference using the vLLM streaming engine.

        Supports two calling conventions:
        1. model(inputs_dict, request_id="id") - pass dict as first positional arg
        2. model(**inputs_dict) - unpack dict as keyword arguments

        Args:
            inputs: Optional dict of model inputs (if None, uses **kwargs)
            request_id: Optional request identifier
            prompt_token_ids: Optional prompt token IDs forwarded to the engine
            **kwargs: Model inputs as keyword arguments (used if inputs is None):
                - code: prev_audio_tokens
                - context_hidden_state: context_hidden_state (must be None)
                - subword_ids: current_subword_id
                - subword_mask: current_subword_mask
                - past_key_values: past_key_values
                - use_cache: True
                - guidance_enabled: guidance_enabled
                - generation_config: generation_config
                - ignore_eos_flag_stop: ignore_eos_flag_stop

        Returns:
            TTSGenerationResult containing generated acoustic tokens and cache
        """
        # Handle both calling conventions
        if inputs is not None:
            # Called as model(inputs_dict, request_id="id")
            input_dict = inputs
        else:
            # Called as model(**inputs_dict)
            # Extract request_id from kwargs if present
            if request_id is None:
                request_id = kwargs.pop('request_id', None)
            input_dict = kwargs

        # Use default request_id if still None
        if request_id is None:
            request_id = 'tts_request_id_1'

        # Run async inference
        result = self._loop.run_until_complete(
            self._async_inference(input_dict, request_id, prompt_token_ids=prompt_token_ids)
        )

        return result

    async def _process_inputs_to_outputs(
        self,
        inputs: dict[str, torch.Tensor],
        request_id: str,
        prompt_token_ids: list | None = None,
    ) -> TTSGenerationResult:
        """
        Run one TTS step: assemble the engine input tensors from the inputs
        dict and return the acoustic tokens for the current step.

        Args:
            inputs: Dict with keys 'code' (prev audio tokens), 'subword_ids',
                'subword_mask', optional 'non_prompt_mask' and
                'audio_prompt_lantent' (sic — key spelling is part of the
                runtime contract with callers); 'context_hidden_state' must be None.
            request_id: Request identifier
            prompt_token_ids: Optional prompt token IDs for prefill

        Returns:
            TTSGenerationResult with codes of shape [1, 1, 31] on CUDA and
            past_key_values=None (vLLM manages the cache internally).
        """

        assert inputs["context_hidden_state"] is None, "EARTTS vllm model does not support context_hidden_state input"

        codes = inputs["code"].squeeze(0)  # T x 31
        if codes.shape[0] > 1:
            # in prefill stage, we need to shift acoustic tokens for vllm,
            # replicating the NeMo logic from here:
            # https://github.com/erastorgueva-nv/NeMo/blob/duplex-realtime-inference/nemo/collections/speechlm2/modules/ear_tts_model.py#L1357
            codes = torch.nn.functional.pad(codes[:-1], [0, 0, 1, 0])
        input_tensors = [
            codes,
            inputs["subword_ids"].squeeze(0),
            inputs["subword_mask"].squeeze(0),
        ]

        if "non_prompt_mask" in inputs:
            # Apply edge detection to match native model's BOS placement logic:
            # BOS should only be applied at the FIRST position where non_prompt_mask is True
            non_prompt_mask = inputs["non_prompt_mask"].squeeze(0)  # T
            # Compute edge: positions where mask is True AND previous position is False
            padded_prev = torch.nn.functional.pad(non_prompt_mask[:-1], [1, 0], value=False)
            bos_mask = (non_prompt_mask & (~padded_prev)).to(dtype=getattr(torch, self._dtype))
            input_tensors.append(bos_mask)

        else:
            current_subword_id = input_tensors[1]
            # Use a tiny epsilon instead of exact 0 so the vLLM model's
            # (bos_mask == 0) check is False during decoding. This prevents
            # use_audio_prompt_frozen_projection from incorrectly applying the
            # speaker-prompt projection to every decoding step. The epsilon is
            # small enough that bos_mask * bos_emb remains negligible.
            bos_mask = torch.full_like(current_subword_id, 1e-20, dtype=getattr(torch, self._dtype))
            input_tensors.append(bos_mask)

        # Pass speaker_latent: the pre-extracted speaker embedding.
        # During prefill with speaker_name: audio_prompt_lantent is [1, T, hidden_size]
        # During decode or speaker_reference: pass zeros so the model falls back
        # to computing the latent from acoustic tokens.
        if "audio_prompt_lantent" in inputs and inputs["audio_prompt_lantent"] is not None:
            speaker_latent = inputs["audio_prompt_lantent"].squeeze(0)  # T x hidden_size
            self._speaker_latent_dim = speaker_latent.shape[-1]
            input_tensors.append(speaker_latent.to(dtype=getattr(torch, self._dtype)))
        else:
            if self._speaker_latent_dim is None:
                # Read hidden_size from the converted model config
                import json as _json
                dir_name = os.path.basename(os.path.normpath(self.model_path))
                converted_config_path = os.path.join("/tmp", dir_name + "_vllm_converted_eartts", "config.json")
                if os.path.exists(converted_config_path):
                    with open(converted_config_path) as _f:
                        self._speaker_latent_dim = _json.load(_f)["hidden_size"]
                else:
                    raise RuntimeError(
                        f"Cannot determine speaker_latent_dim: converted config not found at {converted_config_path}. "
                        "Run a prefill with audio_prompt_lantent first, or ensure the converted checkpoint exists."
                    )
            num_tokens = codes.shape[0]
            speaker_latent = torch.zeros(num_tokens, self._speaker_latent_dim, dtype=getattr(torch, self._dtype))
            input_tensors.append(speaker_latent)

        result = await self.engine.generate_next_token(input_tensors, prompt_token_ids=prompt_token_ids, request_id=request_id)
        acoustic_tokens = result.custom_outputs["acoustic_tokens"]  # T x 31
        step_acoustic_tokens = acoustic_tokens[-1:]  # 1 x 31
        return TTSGenerationResult(
            codes=step_acoustic_tokens.unsqueeze(0).cuda(),  # Add batch dim back: 1 x 1 x 31
            past_key_values=None  # vLLM manages cache internally
        )
class NativeModel(ModelInterface):
    """
    Native PyTorch model interface.

    Wraps the existing DuplexS2SExternalSpeechDecoderModel to conform to the
    ModelInterface contract. Supports top-p sampling and repetition penalty
    (the only sampling controls exposed by this interface).
    """

    def __init__(
        self,
        model,
        special_token_ids: set[int] | None = None,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
    ):
        """
        Initialize with an existing model.

        Args:
            model: The DuplexS2SExternalSpeechDecoderModel instance
            special_token_ids: Set of special token IDs (pad, eos, bos) that should
                bypass sampling. These tokens will use greedy decoding and won't be
                penalized. If None, will try to extract from model.tokenizer.
                You can also manually provide:
                {tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id}
            top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0
            repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0
                Recommended value when enabling: 1.2
            temperature: Temperature for sampling. 1.0 = no change, <1.0 = sharper,
                >1.0 = flatter. 0.0 = greedy (argmax). Default: 1.0
        """
        # Default special token IDs: bos=1, eos=2, pad=12
        DEFAULT_SPECIAL_TOKEN_IDS = {1, 2, 12}

        # Try to extract special token IDs from model if not provided
        if special_token_ids is None:
            special_token_ids = self._extract_special_token_ids_from_nemo(model)
            # Fallback to default if extraction failed
            if not special_token_ids:
                special_token_ids = DEFAULT_SPECIAL_TOKEN_IDS
        # Initialize base class with sampling parameters
        super().__init__(
            special_token_ids=special_token_ids,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        self.model = model

        logging.debug(f"Special token IDs: {self.special_token_ids}")

        # Validate: if sampling is enabled, special_token_ids should be set
        sampling_active = top_p < 1.0 or repetition_penalty != 1.0 or (temperature != 1.0 and temperature != 0.0)
        if sampling_active and not self.special_token_ids:
            import warnings
            warnings.warn(
                "Sampling is enabled but special_token_ids is empty. "
                "Could not auto-extract from model.tokenizer. "
                "Please provide special_token_ids manually to ensure special tokens use greedy decoding. "
                "Otherwise, EOS tokens may be randomly sampled and generation may not stop properly!"
            )

    def __call__(
        self,
        input_embeds: torch.Tensor,
        cache: Any | None = None,
        cache_position: torch.Tensor | None = None,
        generated_tokens: torch.Tensor | None = None,
        current_step: int = 0,
        return_logits: bool = False,
        sampling_params: dict[str, float] | None = None,
        **kwargs
    ) -> dict[str, Any]:
        """
        Perform inference using the native model.

        Args:
            input_embeds: Input embeddings [batch, seq_len, hidden_dim]
            cache: Optional DynamicCache or HybridMambaAttentionDynamicCache
            cache_position: Optional position tensor for Nemotron models
            generated_tokens: Previously generated tokens [batch, num_generated].
                Required for repetition_penalty. If None, creates empty tensor.
            current_step: Current decoding step. Used for repetition penalty.
            return_logits: If True, include raw logits in the returned dict.
            sampling_params: Optional per-request overrides for sampling
                (top_p, temperature, repetition_penalty).
            **kwargs: Additional arguments passed to the model

        Returns:
            Dictionary with 'predicted_token', 'asr_predicted_token', and 'cache'
            (plus logits / function outputs when available).

        Raises:
            TypeError: If the wrapped model does not return a dict.
            KeyError: If the model output lacks 'text_logits'.
        """
        result = self.model.stt_model(input_embeds, cache=cache, cache_position=cache_position, **kwargs)

        # Ensure consistent return format
        if not isinstance(result, dict):
            raise TypeError(f"Model returned {type(result)}, expected dict")

        if 'text_logits' not in result:
            raise KeyError("Model output must contain 'text_logits' key")

        text_logits = result["text_logits"][:, -1]  # [batch, vocab_size]
        batch_size = text_logits.shape[0]

        # Use provided generated_tokens or create empty tensor
        if generated_tokens is None:
            gen_tokens = torch.empty(batch_size, 0, device=text_logits.device, dtype=torch.long)
        else:
            gen_tokens = generated_tokens

        # Apply sampling with top-p and repetition penalty
        predicted_token = self._sample_text_token(
            logits=text_logits,
            generated_tokens=gen_tokens,
            current_step=current_step,
            sampling_params=sampling_params,
        )

        # ASR tokens use greedy decoding (no sampling)
        asr_predicted_token = result["asr_logits"][:, -1].argmax(dim=-1)

        ans = {
            "predicted_token": predicted_token,
            "asr_predicted_token": asr_predicted_token,
            "cache": result.get("cache", None),
        }
        if return_logits:
            ans["text_logits"] = result["text_logits"]
            ans["asr_logits"] = result.get("asr_logits")
            if "function_logits" in result:
                ans["function_logits"] = result["function_logits"]
        if "function_logits" in result:
            ans["function_predicted_token"] = result["function_logits"][:, -1].argmax(dim=-1)
        return ans

    @staticmethod
    def _extract_special_token_ids_from_nemo(model) -> set[int]:
        """
        Extract special token IDs from NeMo model's tokenizer.

        NeMo tokenizer uses bos_token, eos_token, pad_token (not bos_token_id),
        then converts token strings to IDs using the token_to_id method.

        Args:
            model: The DuplexS2SExternalSpeechDecoderModel instance

        Returns:
            Set of special token IDs, or empty set if extraction fails
        """
        special_ids = set()
        try:
            tokenizer = model.stt_model.tokenizer

            # Get token strings (NeMo uses bos_token, not bos_token_id)
            bos_token = getattr(tokenizer, 'bos_token', None)
            eos_token = getattr(tokenizer, 'eos_token', None)
            pad_token = getattr(tokenizer, 'pad_token', None)

            # Convert token strings to IDs
            if hasattr(tokenizer, 'token_to_id'):
                for token in [bos_token, eos_token, pad_token]:
                    if token is not None:
                        tid = tokenizer.token_to_id(token)
                        if tid is not None and isinstance(tid, int):
                            special_ids.add(tid)
        except Exception:
            # FIX: dropped the unused `as e` binding; best-effort extraction,
            # return whatever was collected (possibly empty) on failure.
            pass

        return special_ids

    def to(self, device_or_dtype: torch.device | torch.dtype) -> 'NativeModel':
        """Move underlying model to device or convert dtype."""
        self.model = self.model.to(device_or_dtype)
        return self

    def eval(self) -> 'NativeModel':
        """Set underlying model to eval mode."""
        self.model.eval()
        return self

    @property
    def device(self) -> torch.device:
        """Get device of the underlying model."""
        # Try to get device from model parameters
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            # No parameters, return CPU
            return torch.device('cpu')

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the underlying model.

        This allows transparent access to model attributes like
        perception, tokenizer, etc.
        """
        # Avoid infinite recursion for special attributes
        if name in ('model', '__dict__', '__class__'):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        # Delegate to wrapped model
        return getattr(self.model, name)
- """ - # Avoid infinite recursion for special attributes - if name in ('model', '__dict__', '__class__'): - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - # Delegate to wrapped model - return getattr(self.model, name) - - -def create_model( - model=None, - engine_type: str = "native", - vllm_config: dict[str, Any] | None = None, - special_token_ids: set[int] | None = None, - top_p: float = 1.0, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - **kwargs -) -> ModelInterface: - """ - Factory function to create appropriate model interface. - - This is the main entry point for creating model interfaces. - - Args: - model: The base model to wrap (required for "native" engine, optional for "vllm") - engine_type: Type of engine ("native", "vllm") - vllm_config: Configuration dict for vLLM engines (required for "vllm") - special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. - If None (default), will auto-extract from model.tokenizer for tokens: - '' (bos), '' (eos), '' (pad). - You can manually provide: {tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id} - top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 - repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 - temperature: Temperature for sampling. 1.0 = no change, 0.0 = greedy. 
Default: 1.0 - **kwargs: Additional arguments passed to the interface constructor - - Returns: - ModelInterface instance - - Example: - >>> # Use native PyTorch model with greedy decoding (default) - >>> interface = create_model(model, engine_type="native") - >>> - >>> # Use native with top-p sampling (special_token_ids auto-extracted from model.tokenizer) - >>> # Auto-extracts IDs for: '', '', '' - >>> interface = create_model( - >>> model, - >>> engine_type="native", - >>> top_p=0.9 - >>> ) - >>> - >>> # Use native with top-p and repetition penalty (auto-extract special tokens) - >>> interface = create_model( - >>> model, - >>> engine_type="native", - >>> top_p=0.9, - >>> repetition_penalty=1.2 - >>> ) - >>> - >>> # Manually provide special_token_ids (if auto-extraction fails or you want custom tokens) - >>> special_ids = { - >>> tokenizer.pad_token_id, - >>> tokenizer.eos_token_id, - >>> tokenizer.bos_token_id - >>> } - >>> interface = create_model( - >>> model, - >>> engine_type="native", - >>> special_token_ids=special_ids, - >>> top_p=0.9, - >>> repetition_penalty=1.2 - >>> ) - >>> - >>> # Use vLLM with streaming engine - >>> vllm_cfg = { - >>> "model_path": "/path/to/vllm/checkpoint", - >>> "max_model_len": 10240, - >>> "gpu_memory_utilization": 0.8, - >>> "dtype": "bfloat16" - >>> } - >>> interface = create_model( - >>> engine_type="vllm", - >>> vllm_config=vllm_cfg - >>> ) - >>> - >>> # Perform inference - >>> result = interface(input_embeds, cache=cache) - >>> - >>> # For repetition penalty, pass generated_tokens and current_step - >>> result = interface(input_embeds, cache=cache, generated_tokens=prev_tokens, current_step=step) - """ - engine_type = engine_type.lower() - - if engine_type == "native": - if model is None: - raise ValueError("model must be provided for native engine") - return NativeModel( - model=model, - special_token_ids=special_token_ids, - top_p=top_p, - repetition_penalty=repetition_penalty, - temperature=temperature, - ) - - elif 
engine_type == "vllm_eartts": - if vllm_config is None: - raise ValueError("vllm_config must be provided for vLLM EARTTS engine") - # VllmEARTTSModel for TTS inference - return VllmEARTTSModel( - **vllm_config, - model_type="eartts", - special_token_ids=special_token_ids, - top_p=top_p, - repetition_penalty=repetition_penalty, - temperature=temperature, - **kwargs - ) - - elif engine_type.startswith("vllm"): - if vllm_config is None: - raise ValueError("vllm_config must be provided for vLLM engine") - # VllmLLMModel doesn't need the PyTorch model, only the config - return VllmLLMModel( - **vllm_config, - model_type="llm", - special_token_ids=special_token_ids, - top_p=top_p, - repetition_penalty=repetition_penalty, - temperature=temperature, - **kwargs - ) - - else: - raise ValueError( - f"Unknown engine_type: {engine_type}. " - f"Supported types: 'native', 'vllm', 'vllm_llm', 'vllm_eartts'" - ) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 12703265af26..cfd3ae4e9789 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -16,7 +16,6 @@ import gc import os import time -import types import torch import torchaudio from omegaconf import OmegaConf, DictConfig @@ -28,7 +27,7 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache -from nemo.collections.speechlm2.inference.model_wrappers.model_factory import create_model +from nemo.collections.speechlm2.inference.model_wrappers.factory import create_model from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import ( 
PerceptionCacheState, PerceptionCacheManager, @@ -158,7 +157,15 @@ def __init__(self, model_cfg: DictConfig): self.model_llm_interface = None self.tokenizer = None - # vLLM configuration + # Engine configuration. + # engine_type is a user-facing config value that selects which combination + # of backends to use for the LLM and TTS components: + # "native" -> native_llm + native_eartts + # "vllm_llm" -> vllm_llm + native_eartts + # "vllm_eartts" -> native_llm + vllm_eartts + # "vllm_llm_vllm_eartts" -> vllm_llm + vllm_eartts + # The factory (create_model) uses the specific {backend}_{component} + # names: native_llm, native_eartts, vllm_llm, vllm_eartts. self.engine_type = model_cfg.get("engine_type", "native") self.use_vllm_llm = "vllm_llm" in self.engine_type.lower() self.use_vllm_eartts = "vllm_eartts" in self.engine_type.lower() @@ -202,38 +209,25 @@ def _initialize_model(self): ) logging.info(f"NemotronVoiceChat initialized in {time.time() - start_model_init:.1f}s") - if self.use_vllm_eartts: - # Use object.__setattr__ to bypass PyTorch's module registration - # since VllmEARTTSModel is not a torch.nn.Module - del self.model.tts_model.tts_model - gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() - object.__setattr__( - self.model.tts_model, - 'tts_model', - create_model( - model=self.model_path, - engine_type="vllm_eartts", - vllm_config=self.vllm_tts_config) - ) - from nemo.collections.speechlm2.inference.vllm.vllm_patch import patched_infer_codes_one_step - self.model.tts_model.infer_codes_one_step = types.MethodType(patched_infer_codes_one_step, self.model.tts_model) - - # If using vLLM for LLM, delete native LLM BEFORE moving to device to save memory + # Delete unused native components BEFORE moving to GPU to save memory if self.use_vllm_llm: - logging.info("Deleting native LLM before GPU transfer (will use vLLM instead)...") + logging.info("Deleting native LLM (will use vLLM instead)...") if hasattr(self.model.stt_model, 'llm') and 
self.model.stt_model.llm is not None: - # Delete all submodules of LLM to free memory for name, child in list(self.model.stt_model.llm.named_children()): delattr(self.model.stt_model.llm, name) del self.model.stt_model.llm self.model.stt_model.llm = None + + if self.use_vllm_eartts: + logging.info("Deleting native TTS (will use vLLM instead)...") + del self.model.tts_model.tts_model + + if self.use_vllm_llm or self.use_vllm_eartts: + # Free memory from deleted components before GPU transfer and vLLM engine creation gc.collect() torch.cuda.empty_cache() - logging.info(" Native LLM deleted") - # Setup model + # Setup model on device self.model.to(self.device) self.model.eval() @@ -249,24 +243,6 @@ def _initialize_model(self): self.model.stt_model.function_head = self.model.stt_model.function_head.to(self.dtype) logging.info("function_head converted to %s", self.dtype) - # torch.compile for native TTS backbone - use_tts_torch_compile = bool(self.model_cfg.get("use_tts_torch_compile", False)) - if use_tts_torch_compile and not self.use_vllm_eartts and hasattr(self.model, 'tts_model'): - tts_backbone = getattr(self.model.tts_model, 'tts_model', None) - if tts_backbone is not None and hasattr(tts_backbone, 'backbone'): - logging.info("Compiling TTS backbone with torch.compile(mode='default')...") - tts_backbone.backbone = torch.compile(tts_backbone.backbone, mode="default") - logging.info(" TTS backbone compiled") - - # Inject TTS speedup flags into the TTS model config so ear_tts_model.py can read them - tts_inner = getattr(self.model.tts_model, 'tts_model', None) if hasattr(self.model, 'tts_model') else None - if tts_inner is not None and hasattr(tts_inner, 'config'): - if bool(self.model_cfg.get("use_tts_subword_cache", False)): - OmegaConf.update(tts_inner.config, "use_tts_subword_cache", True) - logging.info("TTS speedup enabled: use_tts_subword_cache") - if hasattr(tts_inner, 'embed_subword') and tts_inner.embed_subword is not None and 
hasattr(tts_inner.embed_subword, 'use_tts_subword_cache'): - tts_inner.embed_subword.use_tts_subword_cache = True - self.tokenizer = self.model.stt_model.tokenizer # Allow overrides from wrapper config into the model config (e.g. logit boosts). @@ -285,17 +261,12 @@ def _initialize_model(self): boost_values = {k: self.model.stt_model.cfg.get(k, None) for k in _BOOST_KEYS} logging.info(f"Inference logit boosts: {boost_values}") - # Wrap model with appropriate interface (Native or vLLM) + # Create LLM backend if self.use_vllm_llm: - logging.info("Wrapping model with VllmLLMModel interface...") + logging.info("Creating VLLMLLM backend...") if self.vllm_llm_config is None: raise ValueError("vllm_llm_config must be provided when engine_type contains 'vllm_llm'") - # LLM already deleted above, just ensure cleanup - gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() - # Set logit boosts as env vars BEFORE creating the vLLM engine, # so they are inherited by the forked worker process. 
The modified # nemotron_h.py reads VLLM_ASR_BOOST_ and @@ -325,26 +296,28 @@ def _initialize_model(self): os.environ[env_key] = str(float(val)) logging.info(f"Set env {env_key}={val} (from {cfg_key})") - self.model_llm_interface = create_model( - model=self.model_path, - engine_type="vllm_llm", - vllm_config=self.vllm_llm_config, - top_p=self.top_p, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - ) + self.model_llm_interface = create_model( + model=self.model_path if self.use_vllm_llm else self.model, + engine_type="vllm_llm" if self.use_vllm_llm else "native_llm", + vllm_config=self.vllm_llm_config if self.use_vllm_llm else None, + top_p=self.top_p, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + ) + logging.info(f"LLM backend: {type(self.model_llm_interface).__name__}") - logging.info("VllmLLMModel interface created") - else: - logging.info("Wrapping model with NativeModel interface...") - self.model_llm_interface = create_model( - model=self.model, - engine_type="native", - top_p=self.top_p, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - ) - logging.info("NativeModel interface created") + # Create TTS backend + self.model_eartts_interface = create_model( + model=self.model_path if self.use_vllm_eartts else self.model.tts_model, + engine_type="vllm_eartts" if self.use_vllm_eartts else "native_eartts", + vllm_config=self.vllm_tts_config if self.use_vllm_eartts else None, + ) + logging.info(f"TTS backend: {type(self.model_eartts_interface).__name__}") + + # torch.compile and subword cache (no-ops for vLLM, delegated to TTS backend) + if bool(self.model_cfg.get("use_tts_torch_compile", False)): + self.model_eartts_interface.compile() + self.model_eartts_interface.setup_subword_cache(self.model_cfg) # Get TTS info if hasattr(self.model, 'tts_model'): @@ -491,15 +464,14 @@ def _prepare_tts_initial_state(self): if self.use_vllm_eartts: self.tts_prompt_token_ids = 
init_inputs["subword_ids"].squeeze().cpu().numpy().tolist() self.tts_init_inputs = init_inputs - outputs = self.model.tts_model.tts_model( - self.tts_init_inputs, - request_id="tts_system_prompt_prefill_request", - prompt_token_ids=self.tts_prompt_token_ids - ) - # abort this request - self.model.tts_model.tts_model.abort_request("tts_system_prompt_prefill_request") - else: - outputs = self.model.tts_model.tts_model(**init_inputs) + outputs = self.model_eartts_interface.prefill_prompt( + init_inputs, + prompt_token_ids=getattr(self, 'tts_prompt_token_ids', None), + request_id="tts_warmup", + ) + if self.use_vllm_eartts: + # Abort warmup request so the engine is clean for actual streaming + self.model_eartts_interface.abort_request("tts_warmup") code = init_inputs["code"][:, -1:] @@ -832,20 +804,32 @@ def _run_tts_step( raise RuntimeError("generation_config is not initialized. Ensure TTS warmup ran successfully.") start_tts_model = time.time() - inputs = { - "current_subword_id": current_subword_id, - "prev_subword_id": prev_subword_id, - "current_subword_mask": current_subword_mask, - "prev_audio_tokens": state.tts_code, - "past_key_values": state.tts_past_key_values, - "guidance_enabled": True, - "generation_config": self.generation_config, - "ignore_eos_flag_stop": True, - } if self.use_vllm_eartts: - inputs["request_id"] = request_id - - state.tts_code, state.tts_past_key_values = self.model.tts_model.infer_codes_one_step(**inputs) + tts_inputs = { + "code": state.tts_code, + "context_hidden_state": None, + "subword_ids": current_subword_id, + "subword_mask": current_subword_mask, + "past_key_values": state.tts_past_key_values, + "use_cache": True, + "guidance_enabled": True, + "generation_config": self.generation_config, + "ignore_eos_flag_stop": True, + } + else: + tts_inputs = { + "current_subword_id": current_subword_id, + "prev_subword_id": prev_subword_id, + "current_subword_mask": current_subword_mask, + "prev_audio_tokens": state.tts_code, + 
"past_key_values": state.tts_past_key_values, + "guidance_enabled": True, + "generation_config": self.generation_config, + "ignore_eos_flag_stop": True, + } + result = self.model_eartts_interface(tts_inputs, request_id=request_id) + state.tts_code = result.codes + state.tts_past_key_values = result.past_key_values if self._profile_timing: torch.cuda.synchronize() @@ -984,8 +968,8 @@ def abort_request(self, request_id: str | None) -> bool: logging.warning(f"Failed to abort LLM request {request_id}: {exc}") # Abort EarTTS if applicable - if self.use_vllm_eartts: - abort_fn = getattr(self.model.tts_model.tts_model, "abort_request", None) + if self.use_vllm_eartts and self.model_eartts_interface is not None: + abort_fn = getattr(self.model_eartts_interface, "abort_request", None) if callable(abort_fn): try: if abort_fn(request_id): diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 03c418903ac9..5875ddbf3fd8 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -721,21 +721,18 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non engine_type = getattr(self.s2s_model, "engine_type", "native") tts_output_code = None - # Prefill TTS with speaker embedding when using vLLM EarTTS - # This initializes the vLLM TTS engine with the speaker context via prompt_token_ids + # Prefill TTS with speaker embedding via model_eartts_interface use_vllm_eartts = "vllm_eartts" in engine_type.lower() if use_vllm_eartts: + eartts = self.s2s_model.model_eartts_interface tts_init_inputs = getattr(self.s2s_model, "tts_init_inputs", None) tts_prompt_token_ids = getattr(self.s2s_model, "tts_prompt_token_ids", None) if tts_init_inputs is not None and tts_prompt_token_ids is not None: logging.info(f"Prefilling TTS speaker embedding for stream 
{stream_id}...") start_tts_prefill = time.time() with torch.no_grad(): - tts_inputs_copy = copy.deepcopy(tts_init_inputs) - tts_result = self.s2s_model.model.tts_model.tts_model( - tts_inputs_copy, - request_id=request_id, - prompt_token_ids=tts_prompt_token_ids + tts_result = eartts.prefill_prompt( + tts_init_inputs, tts_prompt_token_ids, request_id, ) # Capture the generated codes to sync context with vLLM state if hasattr(tts_result, 'codes') and tts_result.codes is not None: @@ -744,58 +741,48 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non logging.info(f"TTS speaker embedding prefilled in {time.time() - start_tts_prefill:.3f}s") else: logging.warning("TTS init inputs not available, skipping TTS prefill") - + if not system_prompt: return tts_output_code - + + # Prefill LLM with system prompt via model_llm_interface logging.info(f"Prefilling system prompt for stream {stream_id}...") start_get_prompt_embeddings = time.time() prompt_embedded, prompt_len = self.s2s_model._prepare_system_prompt_embeddings(system_prompt) logging.debug(f"Time taken to get prompt embeddings: {time.time() - start_get_prompt_embeddings:.3f}s") - + if prompt_embedded is None: logging.warning("System prompt embedding returned None, skipping prefill") return tts_output_code - - # Check if using vLLM for LLM (matches vllm_llm, vllm_llm_vllm_eartts, etc.) 
+ use_vllm_llm = "vllm_llm" in engine_type.lower() - + llm = self.s2s_model.model_llm_interface + if use_vllm_llm: - # For vLLM LLM: prefill all prompt embeddings in one shot - # (decode_steps=0 triggers a single bulk prefill in the vLLM engine) logging.info(f"Prefilling {prompt_len} prompt embeddings for vLLM LLM...") start_prefill = time.time() with torch.no_grad(): - _ = self.s2s_model.model_llm_interface( - prompt_embedded, - request_id=request_id, - decode_steps=0, - prompt_token_ids=None, - ) + llm.prefill_prompt(prompt_embedded, request_id=request_id) logging.info(f"System prompt prefilled ({prompt_len} tokens) in {time.time() - start_prefill:.3f}s") - else: context, _ = self.context_manager.get_context([stream_id]) if context.llm_cache is not None: - # Native cache mode: process prompt through LLM to update KV cache + # Native cache mode: process prompt through LLM to warm up KV cache with torch.no_grad(): cache_pos = torch.arange(prompt_len, device=self.s2s_model.device) - llm_cache = context.llm_cache - ans = self.s2s_model.model_llm_interface( + ans = llm.prefill_prompt( prompt_embedded, - cache=llm_cache, + cache=context.llm_cache, cache_position=cache_pos, - generated_tokens=None, - current_step=0 ) - context.llm_cache = ans.get("cache", llm_cache) + context.llm_cache = ans.get("cache", context.llm_cache) context.llm_cache_position_offset = prompt_len logging.info(f"System prompt processed, cache updated ({prompt_len} tokens, offset={prompt_len})") else: for t in range(prompt_len): context.input_embeds_history.append(prompt_embedded[:, t:t+1, :]) logging.info(f"Added {prompt_len} prompt embeddings to input_embeds_history") - + return tts_output_code def _request_id_for_stream(self, stream_id: int) -> str: diff --git a/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py b/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py index 431714b79fb7..3c570953acf1 100644 --- 
a/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py +++ b/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py @@ -13,15 +13,18 @@ # limitations under the License. """ -Speech Streaming Engine Wrapper -A clean wrapper for streaming speech-to-speech generation with custom embeddings. +Async vLLM engine wrapper for models that use custom input tensors. + +Wraps vLLM's ``AsyncLLM`` for streaming step-by-step inference with +``custom_inputs`` (used by both DuplexSTT/LLM and EarTTS backends). +``engine_kind`` selects EarTTS-only runtime settings (guidance scale, attention +backend env); it does not imply an inheritance relationship between TTS and LLM. """ import os import json import torch -import asyncio -from typing import Any, AsyncGenerator +from typing import Any, AsyncGenerator, Literal from dataclasses import dataclass from enum import Enum @@ -57,42 +60,61 @@ class RequestState: generation_iterator: AsyncGenerator | None = None -class LLMStreamingEngine: +class CustomInputAsyncVLLMEngine: """ - A wrapper for vLLM AsyncLLM engine that enables: - - Easy initialization with speech model configuration - - Start/stop streaming with custom embeddings - - Generate one token at a time - - Abort ongoing generation + Wrapper for vLLM ``AsyncLLM`` with custom input tensor specifications. + + KV cache is managed entirely inside the AsyncLLM engine -- callers do not + allocate, pass, or update cache objects. The engine uses PagedAttention to + manage GPU memory automatically. Per-request state (generated tokens, + generation iterator, status) is tracked in ``self.requests`` dicts. 
+ + Provides: + - Start/stop streaming with custom embedding inputs + - Generate one token at a time via ``generate_next_token`` + - Abort / restart individual requests by ``request_id`` """ def __init__( self, - model_path: str = "/ws/ckpt/converted", + model_path: str, max_model_len: int = 10240, + max_num_batched_tokens: int = 768, gpu_memory_utilization: float = 0.8, trust_remote_code: bool = True, dtype: str = "bfloat16", skip_tokenizer_init: bool = False, + engine_kind: Literal["llm", "eartts"] = "llm", **sampling_kwargs ): """ - Initialize the Speech Streaming Engine. + Initialize the async vLLM engine wrapper. Args: - model_path: Path to the speech model (default: "/ws/ckpt/converted") + model_path: Path to the vLLM-compatible model checkpoint (required) max_model_len: Maximum sequence length (default: 10240) + max_num_batched_tokens: Maximum tokens processed per forward pass. + Controls prefill chunk size and max concurrent decode streams. + Default: 768. gpu_memory_utilization: GPU memory utilization ratio (default: 0.8) trust_remote_code: Whether to trust remote code (default: True) dtype: Data type for embeddings (default: "bfloat16") - **sampling_kwargs: Additional sampling parameters (max_tokens, temperature, top_p, top_k, seed, stop, stop_token_ids, ignore_eos) + engine_kind: ``"llm"`` for DuplexSTT-style models; ``"eartts"`` applies + EarTTS-specific ``guidance_scale`` and attention/runtime env during init. + **sampling_kwargs: Additional vLLM sampling parameters. + Note: vLLM is configured for greedy decoding internally + (temperature=0, ignore_eos=True). Actual text sampling + (top-p, repetition penalty) is applied post-hoc by + ModelInterface._sample_text_token, not by vLLM. 
""" self.model_path = model_path self.max_model_len = max_model_len + self.max_num_batched_tokens = max_num_batched_tokens self.gpu_memory_utilization = gpu_memory_utilization self.trust_remote_code = trust_remote_code self.dtype = dtype self.skip_tokenizer_init = skip_tokenizer_init + self.engine_kind = engine_kind # Engine state self.engine: AsyncLLM | None = None @@ -100,21 +122,46 @@ def __init__( # Request state tracking - supports multiple concurrent requests self.requests: dict[str, RequestState] = {} - # Default sampling parameters + # vLLM sampling is disabled (skip_sampling=True) because we handle + # token selection ourselves for both the LLM and EarTTS paths. + # For LLM: text sampling (top-p, repetition penalty) is applied + # post-hoc by ModelInterface._sample_text_token on the returned logits. + # For EarTTS: acoustic tokens are produced by the model's own forward + # pass (RVQ codebook prediction), not by vLLM's sampler. default_sampling = { "max_tokens": 100000, # Set very high to prevent stopping - use abort to stop explicitly - "temperature": 0.0, - "top_p": 0.9, - "top_k": 50, - "seed": None, - "stop": [], - "stop_token_ids": [], + "skip_sampling": True, "ignore_eos": True, } default_sampling.update(sampling_kwargs) self.sampling_params = SamplingParams(**default_sampling) - logging.info(f"LLMStreamingEngine initialized for model: {model_path}") + if self.engine_kind == "eartts": + guidance_scale = self._read_guidance_scale_from_config() + self.sampling_params.guidance_scale = guidance_scale + logging.info( + f"CustomInputAsyncVLLMEngine initialized (engine_kind=eartts, " + f"guidance_scale={guidance_scale}, model={model_path})" + ) + else: + logging.info( + f"CustomInputAsyncVLLMEngine initialized (engine_kind=llm, model={model_path})" + ) + + def _read_guidance_scale_from_config(self) -> float: + """Read guidance_scale from the converted vLLM model's config.json.""" + config_path = os.path.join(self.model_path, "config.json") + if 
os.path.isfile(config_path): + with open(config_path, "r") as f: + cfg = json.load(f) + value = cfg.get("guidance_scale", None) + if value is not None: + logging.info(f"Read guidance_scale={value} from {config_path}") + return float(value) + logging.warning( + f"guidance_scale not found in {config_path}, using default 0.5." + ) + return 0.5 async def initialize(self): """Initialize the vLLM engine with custom input specifications.""" @@ -124,31 +171,45 @@ async def initialize(self): logging.info("Initializing vLLM engine...") - # Create engine arguments - - engine_args = AsyncEngineArgs( - model=self.model_path, - max_model_len=self.max_model_len, - max_num_batched_tokens=768, - gpu_memory_utilization=self.gpu_memory_utilization, - trust_remote_code=self.trust_remote_code, - mamba_ssm_cache_dtype="float32", - dtype=self.dtype, - skip_tokenizer_init=self.skip_tokenizer_init, - enable_prefix_caching=False - ) - - # please custom input/output specs in model config file - # Create engine config and add custom input specs - vllm_config = engine_args.create_engine_config() - self.custom_input_specs = vllm_config.model_config.custom_input_specs + eartts_env = self.engine_kind == "eartts" + if eartts_env: + # Force TRITON_ATTN backend for EarTTS + os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN" + # TF32 matmul precision to match TTS training ("medium"). + # torch.set_float32_matmul_precision is process-local and does NOT + # propagate to vLLM's spawned worker processes; this CUDA-level env + # var is inherited by child processes. 
+ os.environ["NVIDIA_TF32_OVERRIDE"] = "1" + _cached_get_attn_backend.cache_clear() - # Initialize the engine - self.engine = AsyncLLM.from_vllm_config(vllm_config) - - logging.info("Engine initialized with custom input specs:") - for spec in self.custom_input_specs: - logging.info(f" - {spec}") + try: + engine_args = AsyncEngineArgs( + model=self.model_path, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_batched_tokens, + gpu_memory_utilization=self.gpu_memory_utilization, + trust_remote_code=self.trust_remote_code, + mamba_ssm_cache_dtype="float32", + dtype=self.dtype, + skip_tokenizer_init=self.skip_tokenizer_init, + enable_prefix_caching=False + ) + + # Custom input/output specs are defined in the model's config file + vllm_config = engine_args.create_engine_config() + self.custom_input_specs = vllm_config.model_config.custom_input_specs + + # Initialize the engine + self.engine = AsyncLLM.from_vllm_config(vllm_config) + + logging.info("Engine initialized with custom input specs:") + for spec in self.custom_input_specs: + logging.info(f" - {spec}") + finally: + if eartts_env: + os.environ.pop("VLLM_ATTENTION_BACKEND", None) + os.environ.pop("NVIDIA_TF32_OVERRIDE", None) + _cached_get_attn_backend.cache_clear() def _get_safe_prompt_tokens(self, length: int = 10) -> list[int]: """Generate safe prompt tokens that won't cause immediate EOS.""" @@ -229,7 +290,7 @@ async def generate_next_token(self, input_tensors: list[torch.Tensor], input_dtype = spec.dtype if input_dtype is None: input_dtype = "float32" # Default dtype - if spec.dim !=None and spec.dim != input_tensors[i].shape[-1]: + if spec.dim is not None and spec.dim != input_tensors[i].shape[-1]: raise ValueError(f"Input tensor dimension mismatch for {spec.name}: expected {spec.dim}, got {input_tensors[i].shape[-1]}") custom_inputs[spec.name] = input_tensors[i].to(dtype=getattr(torch, input_dtype)).cpu() max_length = max(max_length, input_tensors[i].shape[0]) @@ -412,68 +473,21 @@ 
async def __aexit__(self, exc_type, exc_val, exc_tb): await self.shutdown() -class EARTTSStreamingEngine(LLMStreamingEngine): - """ - A specialized streaming engine for EARTTS models. - Inherits from LLMStreamingEngine and sets EARTTS-specific configurations. - """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - guidance_scale = self._read_guidance_scale_from_config() - default_sampling = { - "max_tokens": 100000, # Set very high to prevent stopping - use abort to stop explicitly - "temperature": 0.0, - "skip_sampling": True, - "ignore_eos": True, - "guidance_scale": guidance_scale, - } - self.sampling_params = SamplingParams(**default_sampling) - logging.info(f"EARTTSStreamingEngine initialized (guidance_scale={guidance_scale}).") - - def _read_guidance_scale_from_config(self) -> float: - """Read guidance_scale from the converted vLLM model's config.json.""" - config_path = os.path.join(self.model_path, "config.json") - if os.path.isfile(config_path): - with open(config_path, "r") as f: - cfg = json.load(f) - value = cfg.get("guidance_scale", None) - if value is not None: - logging.info(f"Read guidance_scale={value} from {config_path}") - return float(value) - logging.warning( - f"guidance_scale not found in {config_path}, using default 0.5. " - ) - return 0.5 - - async def initialize(self): - # Force TRITON_ATTN backend for EarTTS - os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN" - # TF32 matmul precision to match TTS training ("medium"). - # torch.set_float32_matmul_precision is process-local and does NOT - # propagate to vLLM's spawned worker processes; this CUDA-level env - # var is inherited by child processes. 
- os.environ["NVIDIA_TF32_OVERRIDE"] = "1" - _cached_get_attn_backend.cache_clear() - await super().initialize() - os.environ.pop("VLLM_ATTENTION_BACKEND", None) - os.environ.pop("NVIDIA_TF32_OVERRIDE", None) - _cached_get_attn_backend.cache_clear() - - -def create_engine(engine_type: str = "llm", **kwargs) -> LLMStreamingEngine: +def create_engine(engine_type: str = "llm", **kwargs) -> CustomInputAsyncVLLMEngine: """ - Factory function to create a streaming engine instance. + Factory function to create a CustomInputAsyncVLLMEngine instance. Args: - engine_type: Type of the engine ("eartts" or "llm", default: "llm") - **kwargs: Additional arguments for engine initialization (model_path, max_model_len, gpu_memory_utilization, trust_remote_code, dtype, and sampling parameters) + engine_type: ``"llm"`` or ``"eartts"`` (maps to ``engine_kind``). + **kwargs: Passed to the engine (model_path, max_model_len, etc.). Returns: - An instance of LLMStreamingEngine or its subclass + A configured ``CustomInputAsyncVLLMEngine``. """ if engine_type == "eartts": - return EARTTSStreamingEngine(**kwargs) + return CustomInputAsyncVLLMEngine(engine_kind="eartts", **kwargs) elif engine_type == "llm": - return LLMStreamingEngine(**kwargs) + return CustomInputAsyncVLLMEngine(engine_kind="llm", **kwargs) else: - raise ValueError(f"Unsupported engine_type: {engine_type}") \ No newline at end of file + raise ValueError(f"Unsupported engine_type: {engine_type}") + diff --git a/nemo/collections/speechlm2/inference/vllm/vllm_patch.py b/nemo/collections/speechlm2/inference/vllm/vllm_patch.py deleted file mode 100644 index 35322235e586..000000000000 --- a/nemo/collections/speechlm2/inference/vllm/vllm_patch.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -@torch.no_grad() -def patched_infer_codes_one_step( - self, - current_subword_id, - prev_subword_id, - current_subword_mask, - prev_audio_tokens, - past_key_values, - guidance_enabled=True, - generation_config=None, - ignore_eos_flag_stop=True, - request_id=None, # change signature to include request_id -): - if self.cfg.tts_config.context_hidden_size is not None: - # get context_hidden_state it is always one step behind current_subword_id - # for the first step uses the last step from warmup - context_hidden_state = self.embed_tokens(prev_subword_id) - else: - context_hidden_state = None - - # force silence as next token - if self.cfg.get('inference_force_speech_silence_on_eos', True): - silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(prev_audio_tokens.shape) - prev_audio_tokens = torch.where( - current_subword_id.unsqueeze(-1) == self.text_eos_id, - silence_codes, # silence - prev_audio_tokens, # keep original - ) - # get subword_ids - inputs = { - "code": prev_audio_tokens, - "context_hidden_state": context_hidden_state, - "subword_ids": current_subword_id, - "subword_mask": current_subword_mask, - "past_key_values": past_key_values, - "use_cache": True, - "guidance_enabled": guidance_enabled, - "generation_config": generation_config, - "ignore_eos_flag_stop": ignore_eos_flag_stop, - "request_id": request_id, # pass request_id to the model - } - outputs = self.tts_model(**inputs) - return outputs["codes"], outputs["past_key_values"] \ No newline at end of file From e0db2ca560c0714c027ae4e23f52b3074ec8d49b Mon Sep 17 
00:00:00 2001 From: Elena Rastorgueva Date: Fri, 10 Apr 2026 00:09:55 +0000 Subject: [PATCH 37/40] address CodeQL errors Signed-off-by: Elena Rastorgueva --- .../inference/model_wrappers/backend/__init__.py | 2 +- .../inference/model_wrappers/backend/vllm/base.py | 8 ++++++++ .../model_wrappers/backend/vllm/eartts.py | 2 +- .../inference/model_wrappers/backend/vllm/llm.py | 2 +- .../inference/model_wrappers/perception_cache.py | 1 - .../inference/pipelines/streaming_s2s_pipeline.py | 15 ++++++++------- .../speechlm2/models/nemotron_voicechat.py | 2 +- .../speechlm2/modules/ear_tts_model.py | 1 + 8 files changed, 21 insertions(+), 12 deletions(-) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py index 111b94eccc7b..ee2a19e2abd3 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py @@ -20,4 +20,4 @@ from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.llm import VLLMLLM from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.eartts import VLLMEarTTS except ImportError: - pass + pass # vLLM is an optional dependency diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py index 80a967857cd6..7f917eca2af0 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py @@ -219,6 +219,10 @@ async def _async_inference( try: await self.engine.abort_generation(request_id) except Exception: + # The request already finished/aborted; abort_generation + # is just releasing engine-side resources. If the engine + # already purged it, the call may raise -- harmless since + # we start a fresh generation immediately after. 
pass await self.engine.start_generation(request_id=request_id) @@ -300,4 +304,8 @@ def __del__(self): try: self.shutdown() except Exception: + # __del__ may run during interpreter shutdown when globals + # (self._loop, self.engine, asyncio) are already torn down. + # Nothing useful to do with the error; suppress to avoid + # noisy tracebacks on exit. pass diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py index 051388e7dda6..2f344b0aa9d6 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py @@ -175,7 +175,7 @@ async def _process_inputs_to_outputs( past_key_values=None # vLLM manages cache internally ) - def prefill_prompt(self, init_inputs, prompt_token_ids, request_id: str, **kwargs): + def prefill_prompt(self, init_inputs, prompt_token_ids=None, request_id=None, **kwargs): """Prefill vLLM EarTTS engine with speaker embedding context. Args: diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py index f778bb2c8458..c9d8db7c5303 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py @@ -152,7 +152,7 @@ async def _process_inputs_to_outputs( ans["function_predicted_token"] = result.custom_outputs["function_tokens"] return ans - def prefill_prompt(self, embeddings: torch.Tensor, request_id: str, **kwargs) -> bool: + def prefill_prompt(self, embeddings: torch.Tensor, request_id: str = None, **kwargs) -> bool: """Prefill vLLM LLM engine with prompt embeddings in a single bulk step. 
Args: diff --git a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py index b3250b3131d6..146521ed36ce 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/perception_cache.py @@ -21,7 +21,6 @@ """ import copy -import time from dataclasses import dataclass import torch diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index 5875ddbf3fd8..af07de0681c4 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import os import time from dataclasses import dataclass @@ -525,16 +524,16 @@ def _finalize_and_save_finished_streams( try: with open(os.path.join(txt_dir, f"{base}.txt"), "w", encoding="utf-8") as f: f.write(text_out) - except Exception: - pass + except OSError: + logging.warning(f"Failed to write text output for {base}") asr_text_out = state.output_asr_text_str if isinstance(asr_text_out, str) and asr_text_out: try: with open(os.path.join(txt_dir, f"{base}_asr.txt"), "w", encoding="utf-8") as f: f.write(asr_text_out) - except Exception: - pass + except OSError: + logging.warning(f"Failed to write ASR text output for {base}") saved_paths_by_stream[stream_id] = out_path # Keep state in _state_pool until _build_pipeline_output; @@ -660,7 +659,7 @@ def _build_pipeline_output( if gen_function_text is not None: fc_text = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=False)[0] fc_text_raw = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0] - 
logging.info(f"Function calling channel: {fc_text}") + logging.info(f"Function calling channel: {fc_text}, fc_text_raw: {fc_text_raw}") else: token_texts.append(None) token_asr_texts.append(None) @@ -732,7 +731,9 @@ def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = Non start_tts_prefill = time.time() with torch.no_grad(): tts_result = eartts.prefill_prompt( - tts_init_inputs, tts_prompt_token_ids, request_id, + tts_init_inputs, + prompt_token_ids=tts_prompt_token_ids, + request_id=request_id, ) # Capture the generated codes to sync context with vLLM state if hasattr(tts_result, 'codes') and tts_result.codes is not None: diff --git a/nemo/collections/speechlm2/models/nemotron_voicechat.py b/nemo/collections/speechlm2/models/nemotron_voicechat.py index fea99f9f39d8..26fcb2925ebd 100644 --- a/nemo/collections/speechlm2/models/nemotron_voicechat.py +++ b/nemo/collections/speechlm2/models/nemotron_voicechat.py @@ -235,7 +235,7 @@ def _from_pretrained( tts_model_cfg['pretrained_model'] = None tts_model_cfg['pretrained_codec_model'] = None except (KeyError, TypeError): - pass + logging.warning("Could not nullify pretrained TTS/codec paths in nested TTS config") # Instantiate the empty model skeleton model = cls(model_kwargs['cfg']) diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 4be39de8889c..a32ca5740811 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -999,6 +999,7 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te # Cache results for future lookups if not self.training and self.use_tts_subword_cache: + valid_ids = torch.masked_select(subword_ids, subword_mask).tolist() valid_embeds = subword_embeds[subword_mask].detach() for idx, sid in enumerate(valid_ids): self._inference_cache[sid] = valid_embeds[idx] From 65d14e6918a16972d4303ec48108d8719ec43308 Mon Sep 
17 00:00:00 2001 From: Elena Rastorgueva Date: Fri, 10 Apr 2026 01:43:02 +0000 Subject: [PATCH 38/40] Clean up debug/logging: logger pattern, keep logits on GPU, per-frame logs to debug Signed-off-by: Elena Rastorgueva --- docs/source/speechlm2/streaming_inference.rst | 14 ++++- .../inference/model_wrappers/decode_state.py | 57 +++++++++++++++++++ .../nemotron_voicechat_inference_wrapper.py | 36 +++++------- .../pipelines/streaming_s2s_pipeline.py | 10 +++- .../speechlm2/models/nemotron_voicechat.py | 10 ++-- 5 files changed, 96 insertions(+), 31 deletions(-) diff --git a/docs/source/speechlm2/streaming_inference.rst b/docs/source/speechlm2/streaming_inference.rst index 7532511476e5..6bcfdcb8ddb3 100644 --- a/docs/source/speechlm2/streaming_inference.rst +++ b/docs/source/speechlm2/streaming_inference.rst @@ -89,9 +89,19 @@ over chunks and calls a single step method: pipeline.open_session() for frames in streamer: - pipeline.generate_step(frames) + # Each call returns partial results for this chunk only + step_outputs = pipeline.generate_step(frames) + for out in step_outputs: + # out.text / out.asr_text: new tokens from this step + # out.audio: newly decoded audio for this step + print(f"[stream {out.stream_id}] agent: {out.text} user: {out.asr_text}") pipeline.close_session() - return PipelineOutput(...) + return PipelineOutput(...) # aggregated final results + +Each ``generate_step()`` call returns a list of ``GenerateStepOutput`` carrying +the partial text, ASR text, and audio produced by that single chunk. The +``PipelineOutput`` returned after ``close_session()`` carries the aggregated +results for the entire session. ``generate_step()`` is the unified entry point used by **both** the batch ``run()`` method and server deployments. 
diff --git a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py index bde02df57105..30513447579b 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py @@ -78,3 +78,60 @@ class InferenceStepResult: decoded_audio: torch.Tensor | None = None function_predicted_text_tokens: torch.Tensor | None = None debug: dict | None = None + + +class IntermediateResultLogger: + """Records per-frame debug data (logits, embeddings, indices) during inference. + + Tensors are kept on their original device until :meth:`build_debug_dict` + is called, which performs a single bulk copy to CPU. + """ + + def __init__(self): + self.text_logits: list[torch.Tensor] = [] + self.asr_logits: list[torch.Tensor] = [] + self.input_embeds: list[torch.Tensor] = [] + self.selected_frame_indices: list[int] = [] + + def log_input_embeds(self, emb: torch.Tensor): + self.input_embeds.append(emb.detach()) + + def log_text_logits(self, logits: torch.Tensor): + self.text_logits.append(logits.detach()) + + def log_asr_logits(self, logits: torch.Tensor | None): + if logits is not None: + self.asr_logits.append(logits.detach()) + + def log_selected_frame_index(self, idx: int): + self.selected_frame_indices.append(idx) + + def build_debug_dict(self, source_encoded: torch.Tensor, gen_text: torch.Tensor, gen_asr_text: torch.Tensor | None) -> dict: + return { + "source_encoded": source_encoded.detach().cpu(), + "selected_frame_indices": self.selected_frame_indices, + "input_embeds": torch.cat(self.input_embeds, dim=1).cpu() if self.input_embeds else None, + "gen_text": gen_text.detach().cpu(), + "gen_asr": gen_asr_text.detach().cpu() if gen_asr_text is not None else None, + "text_logits": torch.stack(self.text_logits, dim=1).cpu() if self.text_logits else None, + "asr_logits": torch.stack(self.asr_logits, dim=1).cpu() if 
self.asr_logits else None, + } + + +class NullIntermediateResultLogger: + """No-op stand-in for :class:`IntermediateResultLogger`.""" + + def log_input_embeds(self, emb): + pass + + def log_text_logits(self, logits): + pass + + def log_asr_logits(self, logits): + pass + + def log_selected_frame_index(self, idx): + pass + + def build_debug_dict(self, *args, **kwargs): + return None diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index cfd3ae4e9789..693bc28bbf63 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -34,6 +34,8 @@ ) from nemo.collections.speechlm2.inference.model_wrappers.decode_state import ( InferenceStepResult, + IntermediateResultLogger, + NullIntermediateResultLogger, StreamingDecodeState, ) from nemo.collections.speechlm2.parts.text_utils import _decode_tokens_with_specials @@ -552,7 +554,7 @@ def infer_one_step( if self.model.stt_model.function_head is not None: function_predicted_tokens = torch.empty((B, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) - debug_text_logits, debug_asr_logits, debug_input_embeds, selected_frame_indices = [], [], [], [] + debug_logger = IntermediateResultLogger() if return_debug else NullIntermediateResultLogger() # --- Stage 1: Perception --- source_encoded, state.perception_cache = self._run_perception( @@ -576,14 +578,13 @@ def infer_one_step( for frame_offset in range(num_frames_per_chunk): current_frame_idx = frame_idx + frame_offset current_frame_index = min(base_frame_index + frame_offset, source_encoded.shape[1] - 1) - selected_frame_indices.append(current_frame_index) + debug_logger.log_selected_frame_index(current_frame_index) frame_embedding = source_encoded[:, 
current_frame_index:current_frame_index + 1, :] input_emb = self._build_input_embedding( frame_embedding, current_frame_idx, state, has_prompt, ) - if return_debug: - debug_input_embeds.append(input_emb.detach().cpu()) + debug_logger.log_input_embeds(input_emb) ans = self._run_llm_step( input_emb, state, frame_offset, effective_request_id, @@ -591,10 +592,11 @@ def infer_one_step( sampling_params=sampling_params, ) - if return_debug and "text_logits" in ans: - debug_text_logits.append(ans["text_logits"][:, -1].detach().cpu()) - if return_debug and "asr_logits" in ans and ans["asr_logits"] is not None: - debug_asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) + if "text_logits" in ans: + debug_logger.log_text_logits(ans["text_logits"][:, -1]) + asr_logits = ans.get("asr_logits") + if asr_logits is not None: + debug_logger.log_asr_logits(asr_logits[:, -1]) state.gen_text[:, current_frame_idx] = ans["predicted_token"] predicted_tokens[:, frame_offset] = ans["predicted_token"] @@ -623,8 +625,8 @@ def infer_one_step( predicted_text_strs = self._tokens_to_strings(predicted_tokens) asr_predicted_text_strs = self._tokens_to_strings(asr_predicted_tokens) - logging.info(f'frame {frame_idx}: USER asr: {asr_predicted_text_strs}') - logging.info(f'frame {frame_idx}: AGENT txt: {predicted_text_strs}') + logging.debug(f'frame {frame_idx}: USER asr: {asr_predicted_text_strs}') + logging.debug(f'frame {frame_idx}: AGENT txt: {predicted_text_strs}') # --- Update remaining state fields --- if not use_llm_cache: @@ -637,17 +639,7 @@ def infer_one_step( time_for_one_step = time.time() - start_time_one_step logging.info(f'frame {frame_idx}: Time taken for one step: {time_for_one_step:.3f}s') - debug = None - if return_debug: - debug = { - "source_encoded": source_encoded.detach().cpu(), - "selected_frame_indices": selected_frame_indices, - "input_embeds": torch.cat(debug_input_embeds, dim=1) if debug_input_embeds else None, - "gen_text": state.gen_text.detach().cpu(), - 
"gen_asr": state.gen_asr_text.detach().cpu() if state.gen_asr_text is not None else None, - "text_logits": torch.stack(debug_text_logits, dim=1) if debug_text_logits else None, - "asr_logits": torch.stack(debug_asr_logits, dim=1) if debug_asr_logits else None, - } + debug = debug_logger.build_debug_dict(source_encoded, state.gen_text, state.gen_asr_text) return InferenceStepResult( predicted_text_tokens=predicted_tokens, @@ -863,7 +855,7 @@ def _decode_audio( if not self.decode_audio or not new_codes_for_decode: return None - logging.info(f"Decoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") + logging.debug(f"Decoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") start_time_decode = time.time() with fp32_precision(), torch.no_grad(): diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index af07de0681c4..da7a414a5604 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -309,10 +309,16 @@ def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]) - self.context_manager.reset_slots(stream_ids, eos_flags) - # Explicitly clean up bufferer and state for finished streams + # Log summary and clean up finished streams for stream_id, eos_flag in zip(stream_ids, eos_flags): if eos_flag: - logging.debug(f"Ending stream {stream_id} - cleaning up bufferer and context") + state = self.get_state(stream_id) + audio_sec = state.audio_buffer.shape[-1] / self.output_sample_rate if self.output_sample_rate > 0 else 0 + logging.info( + f"Stream {stream_id} finished: {state.final_total_frames} frames, " + f"{audio_sec:.1f}s audio, " + f"agent: {state.output_text_str!r}, user: {state.output_asr_text_str!r}" + ) self.bufferer.rm_bufferer(stream_id) self._abort_stream_request(stream_id) # Note: We keep the state in 
_state_pool until finalization to save audio diff --git a/nemo/collections/speechlm2/models/nemotron_voicechat.py b/nemo/collections/speechlm2/models/nemotron_voicechat.py index 26fcb2925ebd..760e44b6df42 100644 --- a/nemo/collections/speechlm2/models/nemotron_voicechat.py +++ b/nemo/collections/speechlm2/models/nemotron_voicechat.py @@ -646,8 +646,8 @@ def offline_inference( T = inference_state["T"] if return_logits: - _text_logits = [ans["text_logits"][:, -1].detach().cpu()] - _asr_logits = [ans["asr_logits"][:, -1].detach().cpu()] if "asr_logits" in ans else [] + _text_logits = [ans["text_logits"][:, -1].detach()] + _asr_logits = [ans["asr_logits"][:, -1].detach()] if "asr_logits" in ans else [] # if speaker_name is provided uses it, if not uses the speaker_audio provided, if speaker_audio is None load it from inference_speaker_reference if speaker_audio is None: @@ -699,9 +699,9 @@ def offline_inference( ans = self.stt_model.streaming_inference._step_inference(t, inference_state, ans) if return_logits: - _text_logits.append(ans["text_logits"][:, -1].detach().cpu()) + _text_logits.append(ans["text_logits"][:, -1].detach()) if "asr_logits" in ans: - _asr_logits.append(ans["asr_logits"][:, -1].detach().cpu()) + _asr_logits.append(ans["asr_logits"][:, -1].detach()) # do one step inference on Duplex TTS model # current subword id is always seem @@ -738,7 +738,7 @@ def offline_inference( audio_pred = torch.cat([audio_pred, audio_pred_i], dim=1) audio_pred_len += audio_pred_i_len - logging.info(f"Autoregressive inference step: {t} of {T} !") + logging.debug(f"Autoregressive inference step: {t} of {T} !") # Trim back to local length if padded if self._use_fsdp and T > inference_state["T_local"]: From c3af8fa79ac4ce838b252efe2a37fb2d6cc0b434 Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Tue, 14 Apr 2026 00:46:54 +0000 Subject: [PATCH 39/40] Add per-step progress bar, timing summary, and pad-visible logging Signed-off-by: Elena Rastorgueva --- 
.../s2s_streaming_infer.py | 17 ++++- .../inference/model_wrappers/decode_state.py | 55 +++++++++++++- .../nemotron_voicechat_inference_wrapper.py | 47 +++++------- .../pipelines/streaming_s2s_pipeline.py | 52 ++++++++++--- .../speechlm2/inference/utils/audio_data.py | 24 +++--- .../inference/utils/stepprogressbar.py | 73 +++++++++++++++++++ ...est_nemotron_voicechat_pipeline_nocrash.py | 5 ++ 7 files changed, 216 insertions(+), 57 deletions(-) create mode 100644 nemo/collections/speechlm2/inference/utils/stepprogressbar.py diff --git a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py index 8f1b27565488..63ec63d91bdc 100644 --- a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -29,8 +29,9 @@ from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.utils.stepprogressbar import StepProgressBar from nemo.collections.speechlm2.inference.utils.audio_data import ( - calculate_duration_incl_padding, + calculate_durations_incl_padding, dump_output, prepare_audio_data, ) @@ -49,16 +50,24 @@ def main(cfg: DictConfig): pipeline = S2SPipelineBuilder.build_pipeline(cfg) + progress_bar = StepProgressBar.from_audio_filepaths( + audio_filepaths, + chunk_size_in_secs=pipeline.chunk_size_in_secs, + pad_audio_to_sec=cfg.get("pad_audio_to_sec"), + pad_silence_ratio=cfg.get("pad_silence_ratio"), + pad_audio_by_sec=cfg.get("pad_audio_by_sec"), + ) + timer = SimpleTimer() timer.start() - output = pipeline.run(audio_filepaths, options=options) + output = pipeline.run(audio_filepaths, options=options, progress_bar=progress_bar) timer.stop() exec_dur = timer.total_sec() logging.info(f"Generated {len(audio_filepaths)} files in {exec_dur:.2f}s") - data_dur = 
calculate_duration_incl_padding( + data_dur = sum(calculate_durations_incl_padding( audio_filepaths, cfg.get("pad_audio_to_sec"), cfg.get("pad_silence_ratio"), cfg.get("pad_audio_by_sec"), - ) + )) rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf') logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)") diff --git a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py index 30513447579b..77f2840b9352 100644 --- a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py @@ -31,15 +31,67 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, TYPE_CHECKING import torch +from nemo.utils import logging +from nemo.utils.timers import NamedTimer + if TYPE_CHECKING: from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import PerceptionCacheState +class TimingSummary(NamedTimer): + """Accumulates per-stage wall-clock times across inference steps. + + Extends :class:`~nemo.utils.timers.NamedTimer` (``sync_cuda=True``) + with a :meth:`log_summary` that prints a compact min/mean/max table + at ``logging.info`` level once a stream finishes. + + Usage:: + + timing.start("perception") + # ... run perception ... 
+ timing.stop("perception") + + timing.log_summary(label="Stream 0", chunk_ms=240) + """ + + def __init__(self): + super().__init__(reduction="none", sync_cuda=True) + + def log_summary(self, label: str = "Timing", chunk_ms: float | None = None) -> None: + header = f"{label} timing" + if chunk_ms is not None: + header += f" (chunk={chunk_ms:.0f}ms)" + parts = [] + for name, data in self.timers.items(): + times = data.get("dt", []) + if not times: + continue + mean_ms = sum(times) / len(times) * 1000 + min_ms = min(times) * 1000 + max_ms = max(times) * 1000 + parts.append(f"{name}: mean={mean_ms:.1f}ms min={min_ms:.1f}ms max={max_ms:.1f}ms") + if parts: + logging.info(f"{header}:\n " + "\n ".join(parts)) + + +class NullTimingSummary: + """No-op stand-in for :class:`TimingSummary`.""" + + def start(self, name: str = "") -> None: + pass + + def stop(self, name: str = "") -> None: + pass + + def log_summary(self, label: str = "Timing", chunk_ms: float | None = None) -> None: + pass + + @dataclass class StreamingDecodeState: """Per-stream model-level decode state for streaming S2S inference. 
@@ -60,6 +112,7 @@ class StreamingDecodeState: perception_cache: "PerceptionCacheState" | None = None tts_codec_cache: Any = None llm_cache_position_offset: int = 0 + timing: TimingSummary | NullTimingSummary = field(default_factory=NullTimingSummary) @dataclass diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py index 693bc28bbf63..c5049fcd665f 100755 --- a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -36,7 +36,9 @@ InferenceStepResult, IntermediateResultLogger, NullIntermediateResultLogger, + NullTimingSummary, StreamingDecodeState, + TimingSummary, ) from nemo.collections.speechlm2.parts.text_utils import _decode_tokens_with_specials @@ -144,9 +146,10 @@ def __init__(self, model_cfg: DictConfig): logging.info(f"Precision (effective): float32_matmul_precision={torch.get_float32_matmul_precision()}, cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}, cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}") logging.info("=" * 70) - # Profiling: when True, insert torch.cuda.synchronize() around each - # stage for accurate per-stage wall-clock timing. Disabled by default - # to avoid unnecessary GPU stalls in production. + # Profiling: when True, a TimingSummary (extending NamedTimer with + # sync_cuda=True) is attached to each decode state, recording + # per-stage wall-clock times. Disabled by default to avoid + # unnecessary GPU stalls in production. 
self._profile_timing = bool(model_cfg.get("profile_timing", False)) # Cached TTS helpers populated during initialization/warmup @@ -510,6 +513,7 @@ def create_decode_state(self, max_len: int) -> StreamingDecodeState: perception_cache=perception_cache, tts_codec_cache=tts_codec_cache, llm_cache_position_offset=0, + timing=TimingSummary() if self._profile_timing else NullTimingSummary(), ) def infer_one_step( @@ -544,7 +548,7 @@ def infer_one_step( effective_request_id = request_id or self.request_id frame_idx = state.frame_idx - start_time_one_step = time.time() + state.timing.start("total_step") use_llm_cache = state.llm_cache is not None B = state.gen_text.shape[0] @@ -557,9 +561,11 @@ def infer_one_step( debug_logger = IntermediateResultLogger() if return_debug else NullIntermediateResultLogger() # --- Stage 1: Perception --- + state.timing.start("perception") source_encoded, state.perception_cache = self._run_perception( audio_input, frame_idx, num_frames_per_chunk, state.perception_cache, ) + state.timing.stop("perception") total_encoded_frames = source_encoded.shape[1] if self.use_perception_cache and state.perception_cache is not None and state.perception_cache.is_initialized(): # With cache: we get exactly num_frames_per_chunk output frames @@ -634,10 +640,7 @@ def infer_one_step( if use_llm_cache: state.llm_cache_position_offset += num_frames_per_chunk - if self._profile_timing: - torch.cuda.synchronize() - time_for_one_step = time.time() - start_time_one_step - logging.info(f'frame {frame_idx}: Time taken for one step: {time_for_one_step:.3f}s') + state.timing.stop("total_step") debug = debug_logger.build_debug_dict(source_encoded, state.gen_text, state.gen_asr_text) @@ -726,7 +729,7 @@ def _run_llm_step( Updates ``state.llm_cache`` in-place for cached paths. For the no-cache fallback, appends to *new_input_embeds* (list, mutated). 
""" - start_stt_model = time.time() + state.timing.start("stt_model") if use_llm_cache or self.use_vllm_llm: if self.use_vllm_llm: @@ -763,10 +766,7 @@ def _run_llm_step( sampling_params=sampling_params, ) - if self._profile_timing: - torch.cuda.synchronize() - time_stt_model = time.time() - start_stt_model - logging.info(f"Time taken for stt_model: {time_stt_model:.3f}s") + state.timing.stop("stt_model") return ans @@ -795,7 +795,7 @@ def _run_tts_step( if self.generation_config is None: raise RuntimeError("generation_config is not initialized. Ensure TTS warmup ran successfully.") - start_tts_model = time.time() + state.timing.start("tts_model") if self.use_vllm_eartts: tts_inputs = { "code": state.tts_code, @@ -823,10 +823,7 @@ def _run_tts_step( state.tts_code = result.codes state.tts_past_key_values = result.past_key_values - if self._profile_timing: - torch.cuda.synchronize() - time_tts_model = time.time() - start_tts_model - logging.info(f"Time taken for tts_model: {time_tts_model:.3f}s") + state.timing.stop("tts_model") new_code = state.tts_code.clone() @@ -857,7 +854,7 @@ def _decode_audio( logging.debug(f"Decoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})") - start_time_decode = time.time() + state.timing.start("audio_codec") with fp32_precision(), torch.no_grad(): new_codes_tensor = torch.cat(new_codes_for_decode, dim=1) if hasattr(self.model.tts_model, '_control_codes'): @@ -874,10 +871,7 @@ def _decode_audio( new_codes_tensor, new_code_len, cache=state.tts_codec_cache, ) - if self._profile_timing: - torch.cuda.synchronize() - time_audio_codec = time.time() - start_time_decode - logging.info(f"Time taken for audio_codec: {time_audio_codec:.3f}s") + state.timing.stop("audio_codec") return decoded_audio @@ -889,8 +883,6 @@ def _run_perception( perception_cache: PerceptionCacheState | None, ) -> tuple[torch.Tensor, PerceptionCacheState | None]: """Run the perception encoder and return (source_encoded, updated_cache).""" - start_perception = 
time.time() - if self.use_perception_cache and perception_cache is not None and perception_cache.is_initialized(): source_encoded, perception_cache = self.perception_cache_mgr.step( audio_input=audio_input, @@ -906,11 +898,6 @@ def _run_perception( return_encoder_emb=True, ) - if self._profile_timing: - torch.cuda.synchronize() - time_perception = time.time() - start_perception - logging.info(f"Time taken for perception: {time_perception:.3f}s") - source_encoded = source_encoded.to(self.dtype) return source_encoded, perception_cache diff --git a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py index da7a414a5604..7fe64526694a 100644 --- a/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py +++ b/nemo/collections/speechlm2/inference/pipelines/streaming_s2s_pipeline.py @@ -12,23 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math import os import time from dataclasses import dataclass +import soundfile as sf import torch import librosa from torch import Tensor -import soundfile as sf from omegaconf import DictConfig -import math from nemo.collections.asr.inference.streaming.framing.request import Frame from nemo.collections.asr.inference.utils.enums import RequestType from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import BatchedAudioBufferer -from nemo.collections.asr.inference.utils.progressbar import ProgressBar +from nemo.collections.speechlm2.inference.utils.stepprogressbar import StepProgressBar from nemo.collections.speechlm2.inference.pipelines.s2s_pipeline_interface import S2SPipelineInterface from nemo.collections.speechlm2.inference.streaming.state.s2s_state import S2SStreamingState +from nemo.collections.speechlm2.inference.model_wrappers.decode_state import NullTimingSummary from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import NemotronVoicechatInferenceWrapper from nemo.collections.speechlm2.parts.text_utils import tokens_to_str from nemo.collections.speechlm2.inference.streaming.state.s2s_context_manager import S2SContextManager @@ -295,8 +296,9 @@ def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]) - # Persist updated cache & clean finished streams self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk) - # Save full token tensors to state before the context is destroyed, - # so we can run tokens_to_str post-hoc. + # Save token tensors and timing from the decode context before it is + # destroyed by reset_slots. 
+ timing_by_stream: dict[int, object] = {} for stream_id, eos_flag in zip(stream_ids, eos_flags): if eos_flag: ctx = self.context_manager.slot_contexts[ @@ -306,6 +308,7 @@ def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]) - state = self.get_or_create_state(stream_id) state.save_token_tensors(ctx.gen_text, ctx.gen_asr_text, ctx.frame_idx, gen_function_text=ctx.gen_function_text) + timing_by_stream[stream_id] = ctx.timing self.context_manager.reset_slots(stream_ids, eos_flags) @@ -319,6 +322,27 @@ def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]) - f"{audio_sec:.1f}s audio, " f"agent: {state.output_text_str!r}, user: {state.output_asr_text_str!r}" ) + + # Compact pad-visible summary (· replaces ) + token_data = state.get_token_tensors() + if token_data is not None: + gen_text, gen_asr_text, total_frames, _ = token_data + tokenizer = self.s2s_model.tokenizer + pad_id = self.s2s_model.model.stt_model.text_pad_id + lengths = torch.tensor([total_frames], dtype=torch.long) + raw_agent = tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0] + raw_user = tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0] + compact_agent = raw_agent.replace('', '·') + compact_user = raw_user.replace('', '·') + logging.info(f"Stream {stream_id} agent (with padding): {compact_agent}") + logging.info(f"Stream {stream_id} user (with padding): {compact_user}") + + # Timing summary (no-op when profile_timing is off) + timing_by_stream.get(stream_id, NullTimingSummary()).log_summary( + label=f"Stream {stream_id}", + chunk_ms=self.chunk_size_in_secs * 1000, + ) + self.bufferer.rm_bufferer(stream_id) self._abort_stream_request(stream_id) # Note: We keep the state in _state_pool until finalization to save audio @@ -566,15 +590,19 @@ def run( self, audio_filepaths: list[str], options: list[S2SRequestOptions] | None = None, - progress_bar: ProgressBar | None = None, + 
progress_bar: StepProgressBar | None = None, ) -> PipelineOutput: """Stream all *audio_filepaths* through the pipeline and save outputs. Saves one generated ``.wav`` per input under ``self.output_dir`` and returns their paths in ``PipelineOutput.texts``. + + Args: + audio_filepaths: Paths to input audio files. + options: Per-stream request options (system prompt, sampling, etc.). + progress_bar: Optional :class:`StepProgressBar` for per-step + progress with per-stream postfix. """ - if progress_bar and not isinstance(progress_bar, ProgressBar): - raise ValueError("progress_bar must be an instance of ProgressBar.") if options is None: options = [S2SRequestOptions() for _ in audio_filepaths] @@ -590,7 +618,6 @@ def run( pad_ratio=self.pad_silence_ratio, ) streamer.set_audio_filepaths(audio_filepaths, options) - streamer.set_progress_bar(progress_bar) os.makedirs(self.output_dir, exist_ok=True) saved_paths_by_stream: dict[int, str] = {} @@ -604,6 +631,13 @@ def run( self.generate_step(frames) self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream) + if progress_bar is not None: + for f in frames: + progress_bar.step(f.stream_id) + + if progress_bar is not None: + progress_bar.finish() + output = self._build_pipeline_output(audio_filepaths, saved_paths_by_stream) self.close_session() return output diff --git a/nemo/collections/speechlm2/inference/utils/audio_data.py b/nemo/collections/speechlm2/inference/utils/audio_data.py index a054b18b8d8a..2135e97dfd17 100644 --- a/nemo/collections/speechlm2/inference/utils/audio_data.py +++ b/nemo/collections/speechlm2/inference/utils/audio_data.py @@ -83,33 +83,31 @@ def prepare_audio_data( return filepaths, options, ground_truths -def calculate_duration_incl_padding( +def calculate_durations_incl_padding( audio_filepaths: list[str], pad_audio_to_sec: float | None = None, pad_silence_ratio: float | None = None, pad_audio_by_sec: float | None = None, -) -> float: - """Calculate total duration of 
the given audio files in seconds. +) -> list[float]: + """Return per-file durations in seconds, accounting for silence padding. - Optionally accounts for silence padding appended after each file. At most one padding argument may be set; when none are set this - returns the raw audio duration. + returns the raw audio durations. """ if sum(x is not None for x in [pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec]) > 1: raise ValueError("Set at most one of: pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec") - total = 0.0 + durations = [] for fp in audio_filepaths: sound = sf.SoundFile(fp) - orig = sound.frames / sound.samplerate + dur = sound.frames / sound.samplerate if pad_audio_to_sec is not None: - total += max(orig, pad_audio_to_sec) + dur = max(dur, pad_audio_to_sec) elif pad_silence_ratio is not None: - total += orig * (1 + pad_silence_ratio) + dur *= (1 + pad_silence_ratio) elif pad_audio_by_sec is not None: - total += orig + pad_audio_by_sec - else: - total += orig - return total + dur += pad_audio_by_sec + durations.append(dur) + return durations def dump_output( diff --git a/nemo/collections/speechlm2/inference/utils/stepprogressbar.py b/nemo/collections/speechlm2/inference/utils/stepprogressbar.py new file mode 100644 index 000000000000..a21b25fc8039 --- /dev/null +++ b/nemo/collections/speechlm2/inference/utils/stepprogressbar.py @@ -0,0 +1,73 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-inference-step progress bar for S2S streaming pipelines.""" + +from __future__ import annotations + +import math + +from tqdm import tqdm + +from nemo.collections.speechlm2.inference.utils.audio_data import calculate_durations_incl_padding + + +class StepProgressBar: + """Tracks per-step inference progress across one or more streams. + + Each call to :meth:`step` advances the bar by one and updates the + per-stream postfix (e.g. ``stream 2: 45/127``). + + Create via :meth:`from_audio_filepaths`. + """ + + def __init__(self, total_steps: int, steps_per_stream: dict[int, int] | None = None): + self._bar = tqdm(total=total_steps, desc="Inference", unit="step", dynamic_ncols=True) + self._steps_per_stream = steps_per_stream or {} + self._stream_progress: dict[int, int] = {} + + def step(self, stream_id: int) -> None: + """Record one inference step for *stream_id* and advance the bar.""" + self._stream_progress[stream_id] = self._stream_progress.get(stream_id, 0) + 1 + stream_total = self._steps_per_stream.get(stream_id) + if stream_total is not None: + self._bar.set_postfix_str( + f"stream {stream_id}: {self._stream_progress[stream_id]}/{stream_total}", + refresh=False, + ) + self._bar.update(1) + + def finish(self) -> None: + """Close the underlying tqdm bar.""" + self._bar.close() + + @classmethod + def from_audio_filepaths( + cls, + audio_filepaths: list[str], + chunk_size_in_secs: float, + pad_audio_to_sec: float | None = None, + pad_silence_ratio: float | None = None, + pad_audio_by_sec: float | None = None, + ) -> StepProgressBar: + durations = calculate_durations_incl_padding( + audio_filepaths, pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec, + ) + steps_per_stream = { + idx: math.ceil(dur / chunk_size_in_secs) for idx, dur in enumerate(durations) + } + return cls( + total_steps=sum(steps_per_stream.values()), + 
steps_per_stream=steps_per_stream, + ) diff --git a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py index 18ab6b4861b0..26fe0ba0f115 100644 --- a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py +++ b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py @@ -130,6 +130,11 @@ def _build_no_crash_pipeline( {}, id="deterministic", ), + pytest.param( + {"profile_timing": True}, + {}, + id="profile_timing", + ), ] # --------------------------------------------------------------------------- From 9c3326d5071694006efbb65f2db650df447c6b5d Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva Date: Tue, 14 Apr 2026 19:17:08 +0000 Subject: [PATCH 40/40] refactor pipeline test builders to accept overrides as positional args; test padding Signed-off-by: Elena Rastorgueva --- ...est_nemotron_voicechat_pipeline_nocrash.py | 142 ++++++++++-------- ...test_nemotron_voicechat_pipeline_parity.py | 78 ++++++---- 2 files changed, 123 insertions(+), 97 deletions(-) diff --git a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py index 26fe0ba0f115..8d9a7d1db068 100644 --- a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py +++ b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py @@ -35,6 +35,7 @@ from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline +from nemo.collections.speechlm2.inference.utils.stepprogressbar import StepProgressBar _CONF_YAML = os.path.join( os.path.dirname(__file__), @@ -42,24 +43,10 @@ ) _MOCK_SYSTEM_PROMPT = "This is a mock prompt for the test" -# --------------------------------------------------------------------------- -# Helper -# 
--------------------------------------------------------------------------- - - -def _build_no_crash_pipeline( - model_path: str, - audio_path: str, - output_dir: str, - *, - s2s_overrides: dict[str, Any] | None = None, - streaming_overrides: dict[str, Any] | None = None, -) -> StreamingS2SPipeline: - """Build a :class:`StreamingS2SPipeline` with custom overrides for no-crash testing.""" - cfg = OmegaConf.load(_CONF_YAML) - - s2s_cfg: dict[str, Any] = { - "model_path": model_path, +# Safe defaults so tests run quickly with the tiny model. +# Individual tests override specific keys via OmegaConf.merge. +_TEST_DEFAULTS = { + "s2s": { "engine_type": "native", "compute_dtype": "float32", "deterministic": False, @@ -71,68 +58,82 @@ def _build_no_crash_pipeline( "top_p": 1.0, "repetition_penalty": 1.0, "temperature": 1.0, - } - streaming_cfg: dict[str, Any] = { + }, + "streaming": { "chunk_size_in_secs": 0.08, "buffer_size_in_secs": 71 * 0.08, - } - - if s2s_overrides: - s2s_cfg.update(s2s_overrides) - if streaming_overrides: - streaming_cfg.update(streaming_overrides) - - overrides = { - "audio_file": audio_path, - "output_dir": output_dir, - "s2s": s2s_cfg, - "streaming": streaming_cfg, - } - cfg = OmegaConf.merge(cfg, OmegaConf.create(overrides)) + }, +} + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_no_crash_pipeline( + model_path: str, + audio_path: str, + output_dir: str, + *overrides: dict[str, Any], +) -> StreamingS2SPipeline: + """Build a :class:`StreamingS2SPipeline` for no-crash testing. + + Loads the YAML base config, applies ``_TEST_DEFAULTS``, then merges + each dict in *overrides* on top (in order). Nested dicts are + recursively merged by OmegaConf, so ``{"s2s": {"top_p": 0.9}}`` + overrides only that key while keeping all other s2s defaults. 
+ """ + cfg = OmegaConf.load(_CONF_YAML) + cfg = OmegaConf.merge( + cfg, + _TEST_DEFAULTS, + {"audio_file": audio_path, "output_dir": output_dir, "s2s": {"model_path": model_path}}, + ) + for overrides in overrides: + if overrides: + cfg = OmegaConf.merge(cfg, overrides) return S2SPipelineBuilder.build_pipeline(cfg) # --------------------------------------------------------------------------- -# Parametrized configs +# Parametrized configs — each entry is a single overrides dict # --------------------------------------------------------------------------- # Text-only configs (decode_audio=False): minimal STT-path smoke checks. -# Most config variations are folded into the audio tests below. _TEXT_CONFIGS = [ - pytest.param({}, {}, id="baseline"), + pytest.param({}, id="baseline"), pytest.param( - {"use_llm_cache": True, "use_perception_cache": True}, - {}, + {"s2s": {"use_llm_cache": True, "use_perception_cache": True}}, id="both_caches", ), + pytest.param({"pad_audio_by_sec": 2}, id="pad_by_sec"), ] # Audio configs (decode_audio=True): exercises the full STT + TTS pipeline. 
_AUDIO_CONFIGS = [ - pytest.param({}, {}, id="baseline"), + pytest.param({}, id="baseline"), pytest.param( - {"use_llm_cache": True, "use_perception_cache": True, "system_prompt": _MOCK_SYSTEM_PROMPT}, - {"chunk_size_in_secs": 0.24}, - id="both_caches_prompt_multiframe", + {"s2s": {"use_llm_cache": True, "use_perception_cache": True, "system_prompt": _MOCK_SYSTEM_PROMPT}, + "streaming": {"chunk_size_in_secs": 0.24}, + "pad_audio_to_sec": 5}, + id="both_caches_prompt_multiframe_pad_to_sec", ), pytest.param( - {"use_llm_cache": True, "top_p": 0.9, "temperature": 0.7, "repetition_penalty": 1.1}, - {}, - id="sampling", + {"s2s": {"use_llm_cache": True, "top_p": 0.9, "temperature": 0.7, "repetition_penalty": 1.1}, + "pad_silence_ratio": 0.5}, + id="sampling_pad_silence_ratio", ), pytest.param( - {"use_tts_subword_cache": True, "use_tts_torch_compile": True}, - {}, - id="tts_optimizations", + {"s2s": {"use_tts_subword_cache": True, "use_tts_torch_compile": True}, + "pad_audio_by_sec": 2}, + id="tts_optimizations_pad_by_sec", ), pytest.param( - {"deterministic": True, "temperature": 0.0}, - {}, + {"s2s": {"deterministic": True, "temperature": 0.0}}, id="deterministic", ), pytest.param( - {"profile_timing": True}, - {}, + {"s2s": {"profile_timing": True}}, id="profile_timing", ), ] @@ -143,33 +144,42 @@ def _build_no_crash_pipeline( @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -@pytest.mark.parametrize("s2s_overrides,streaming_overrides", _TEXT_CONFIGS) -def test_pipeline_no_crash_tiny_model(tiny_model_artifacts, s2s_overrides, streaming_overrides): +@pytest.mark.parametrize("overrides", _TEXT_CONFIGS) +def test_pipeline_no_crash_tiny_model(tiny_model_artifacts, overrides): """Run the streaming pipeline with various configs and verify it doesn't crash.""" model_dir, audio_path, _ = tiny_model_artifacts output_dir = tempfile.mkdtemp(prefix="no-crash-text-") - pipeline = _build_no_crash_pipeline( - model_dir, audio_path, output_dir, - 
s2s_overrides=s2s_overrides, streaming_overrides=streaming_overrides, + pipeline = _build_no_crash_pipeline(model_dir, audio_path, output_dir, overrides) + progress_bar = StepProgressBar.from_audio_filepaths( + [audio_path], + chunk_size_in_secs=pipeline.chunk_size_in_secs, + pad_audio_to_sec=pipeline.pad_audio_to_sec, + pad_silence_ratio=pipeline.pad_silence_ratio, + pad_audio_by_sec=pipeline.pad_audio_by_sec, ) - result = pipeline.run([audio_path]) + result = pipeline.run([audio_path], progress_bar=progress_bar) assert result is not None @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -@pytest.mark.parametrize("s2s_overrides,streaming_overrides", _AUDIO_CONFIGS) -def test_pipeline_no_crash_tiny_model_decode_audio(tiny_model_artifacts, s2s_overrides, streaming_overrides): +@pytest.mark.parametrize("overrides", _AUDIO_CONFIGS) +def test_pipeline_no_crash_tiny_model_decode_audio(tiny_model_artifacts, overrides): """Run the streaming pipeline with decode_audio=True and verify it doesn't crash.""" model_dir, audio_path, speaker_ref_path = tiny_model_artifacts output_dir = tempfile.mkdtemp(prefix="no-crash-audio-") - audio_overrides = {"decode_audio": True, "speaker_reference": speaker_ref_path} - audio_overrides.update(s2s_overrides) - pipeline = _build_no_crash_pipeline( model_dir, audio_path, output_dir, - s2s_overrides=audio_overrides, streaming_overrides=streaming_overrides, + {"s2s": {"decode_audio": True, "speaker_reference": speaker_ref_path}}, + overrides, + ) + progress_bar = StepProgressBar.from_audio_filepaths( + [audio_path], + chunk_size_in_secs=pipeline.chunk_size_in_secs, + pad_audio_to_sec=pipeline.pad_audio_to_sec, + pad_silence_ratio=pipeline.pad_silence_ratio, + pad_audio_by_sec=pipeline.pad_audio_by_sec, ) - result = pipeline.run([audio_path]) + result = pipeline.run([audio_path], progress_bar=progress_bar) assert result is not None diff --git a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py 
b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py index b4f61941c4d7..217af74f5a01 100644 --- a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py +++ b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py @@ -226,18 +226,37 @@ def assert_parity( assert not failures, "Parity failed:\n " + "\n ".join(failures) +# Parity requires deterministic, float32, no caches, greedy decoding. +_PARITY_DEFAULTS = { + "s2s": { + "engine_type": "native", + "compute_dtype": "float32", + "deterministic": True, + "decode_audio": False, + "use_perception_cache": False, + "use_perception_cudagraph": False, + "use_llm_cache": False, + "top_p": 1.0, + "repetition_penalty": 1.0, + "temperature": 1.0, + }, +} + + def _build_parity_pipeline( model_path: str, audio_path: str, output_dir: str, - *, - speaker_name: str | None = None, - system_prompt: str | None = None, + *overrides: dict[str, Any], ) -> StreamingS2SPipeline: """Build a :class:`StreamingS2SPipeline` configured for strict parity testing. - Loads ``s2s_streaming.yaml`` as the base config and applies - parity-specific overrides (deterministic, float32, no caches, greedy). + Loads ``s2s_streaming.yaml`` as the base config, applies + ``_PARITY_DEFAULTS``, then merges each dict in *overrides* + on top (in order). + + The chunk size is set to cover the full audio in one step so that + offline and incremental paths see identical input. 
""" import librosa @@ -246,31 +265,22 @@ def _build_parity_pipeline( chunk_secs = total_frames * FRAME_SIZE_SAMPLES / SAMPLE_RATE cfg = OmegaConf.load(_CONF_YAML) - overrides = { - "audio_file": audio_path, - "output_dir": output_dir, - "s2s": { - "model_path": model_path, - "engine_type": "native", - "compute_dtype": "float32", - "deterministic": True, - "decode_audio": False, - "use_perception_cache": False, - "use_perception_cudagraph": False, - "use_llm_cache": False, - "system_prompt": system_prompt, - "top_p": 1.0, - "repetition_penalty": 1.0, - "temperature": 1.0, + cfg = OmegaConf.merge( + cfg, + _PARITY_DEFAULTS, + { + "audio_file": audio_path, + "output_dir": output_dir, + "s2s": {"model_path": model_path}, + "streaming": { + "chunk_size_in_secs": chunk_secs, + "buffer_size_in_secs": max(71 * 0.08, chunk_secs), + }, }, - "streaming": { - "chunk_size_in_secs": chunk_secs, - "buffer_size_in_secs": max(71 * 0.08, chunk_secs), - }, - } - if speaker_name: - overrides["s2s"]["speaker_name"] = speaker_name - cfg = OmegaConf.merge(cfg, OmegaConf.create(overrides)) + ) + for overrides in overrides: + if overrides: + cfg = OmegaConf.merge(cfg, overrides) return S2SPipelineBuilder.build_pipeline(cfg) @@ -289,7 +299,10 @@ def test_parity_tiny_model(tiny_model_artifacts): model_dir, audio_path, _ = tiny_model_artifacts output_dir = tempfile.mkdtemp(prefix="parity-tiny-") - pipeline = _build_parity_pipeline(model_dir, audio_path, output_dir, system_prompt=_MOCK_SYSTEM_PROMPT) + pipeline = _build_parity_pipeline( + model_dir, audio_path, output_dir, + {"s2s": {"system_prompt": _MOCK_SYSTEM_PROMPT}}, + ) report = run_parity_check(pipeline, audio_path, system_prompt=_MOCK_SYSTEM_PROMPT) assert_parity(report, strict=True, atol=0.0) @@ -319,9 +332,12 @@ def test_parity_real_checkpoint(): audio = os.environ.get("PARITY_AUDIO_PATH") or _FORCE_ALIGN_AUDIO speaker = os.environ.get("PARITY_SPEAKER_NAME") + parity_overrides: dict[str, Any] = {"s2s": {"system_prompt": 
_MOCK_SYSTEM_PROMPT}} + if speaker: + parity_overrides["s2s"]["speaker_name"] = speaker pipeline = _build_parity_pipeline( ckpt, audio, tempfile.mkdtemp(prefix="parity-"), - speaker_name=speaker, system_prompt=_MOCK_SYSTEM_PROMPT, + parity_overrides, ) report = run_parity_check(pipeline, audio, system_prompt=_MOCK_SYSTEM_PROMPT) assert_parity(report, strict=True, atol=0.0)