diff --git a/docs/source/speechlm2/intro.rst b/docs/source/speechlm2/intro.rst index 94f0426d575b..ff4158002963 100644 --- a/docs/source/speechlm2/intro.rst +++ b/docs/source/speechlm2/intro.rst @@ -246,7 +246,35 @@ You can evaluate and run full-duplex inference using the `NemotronVoiceChat` pip print(f"Agent response: {generated_text}") # generated_speech can now be saved or played (sampled at model.target_sample_rate) - + +NemotronVoiceChat Streaming Inference +************************************* + +For real-time, chunk-by-chunk inference (as opposed to the offline mode shown +above), use the Streaming S2S Pipeline: + +.. code-block:: python + + from nemo.collections.speechlm2.inference import S2SPipelineBuilder + + pipeline = S2SPipelineBuilder.build_pipeline(cfg) + output = pipeline.run(audio_filepaths, options=options) + +Or from the command line: + +.. code-block:: bash + + python examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py \ + audio_file=/path/to/audio \ + s2s.model_path=/path/to/checkpoint \ + s2s.speaker_name="" \ + s2s.engine_type=native \ + s2s.system_prompt="You are a helpful assistant." \ + streaming.chunk_size_in_secs=0.24 \ + streaming.buffer_size_in_secs=1.68 + +See :doc:`streaming_inference` for full details on configuration, architecture, +and server integration. Training a Model ---------------- @@ -341,3 +369,4 @@ For more information, see additional sections in the SpeechLM2 docs: datasets configs training_and_scaling + streaming_inference diff --git a/docs/source/speechlm2/streaming_inference.rst b/docs/source/speechlm2/streaming_inference.rst new file mode 100644 index 000000000000..6bcfdcb8ddb3 --- /dev/null +++ b/docs/source/speechlm2/streaming_inference.rst @@ -0,0 +1,498 @@ +Streaming Inference +=================== + +The speechlm2 collection provides a streaming inference pipeline for +NemotronVoiceChat that processes audio in real time, chunk by chunk, and +produces both text and speech output incrementally. 
The pipeline follows the +same methodology as the NeMo ASR Inference Pipelines (see +``nemo.collections.asr.inference``). + +Overview +-------- + +The streaming inference stack has four layers: + +.. code-block:: text + + Entry Script s2s_streaming_infer.py (Hydra) + │ + ▼ + Pipeline StreamingS2SPipeline + │ - audio buffering + │ - state management + │ - file I/O + ▼ + Model Wrapper NemotronVoicechatInferenceWrapper + │ - infer_one_step() + │ - perception + │ - model_llm_interface (PyTorchLLM or VLLMLLM) + │ - model_eartts_interface (PyTorchEarTTS or VLLMEarTTS) + │ - codec decode + ▼ + Model NemotronVoiceChat + - DuplexSTTModel + DuplexEARTTS + +Quick Start +----------- + +Batch Inference from a Script +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The simplest way to run streaming inference is with the provided Hydra script: + +.. code-block:: bash + + python examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py \ + --config-path=examples/speechlm2/nemo_inference_pipelines/conf \ + --config-name=s2s_streaming \ + audio_file=/path/to/audio_or_directory_or_manifest.json \ + output_dir=./generated \ + s2s.model_path=/path/to/checkpoint \ + s2s.speaker_name="" \ + s2s.engine_type=native \ + s2s.system_prompt="You are a helpful assistant." \ + streaming.chunk_size_in_secs=0.24 \ + streaming.buffer_size_in_secs=1.68 + +This will: + +1. Load the NemotronVoiceChat checkpoint. +2. Stream each audio file through the pipeline in chunks. +3. Save generated ``.wav``, stereo (input+output), and ``.txt`` files under + ``output_dir``. + +Programmatic Usage +^^^^^^^^^^^^^^^^^^ + +.. 
code-block:: python + + from nemo.collections.speechlm2.inference import S2SPipelineBuilder + + pipeline = S2SPipelineBuilder.build_pipeline(cfg) + output = pipeline.run(audio_filepaths, options=options) + + # output.texts -- generated agent text per file + # output.asr_texts -- recognized user text per file + # output.audio_filepaths -- paths to generated .wav files + + +Architecture +------------ + +The Core Loop +^^^^^^^^^^^^^ + +Like the ASR pipeline's ``BasePipeline.run()``, the S2S pipeline iterates +over chunks and calls a single step method: + +.. code-block:: python + + pipeline.open_session() + for frames in streamer: + # Each call returns partial results for this chunk only + step_outputs = pipeline.generate_step(frames) + for out in step_outputs: + # out.text / out.asr_text: new tokens from this step + # out.audio: newly decoded audio for this step + print(f"[stream {out.stream_id}] agent: {out.text} user: {out.asr_text}") + pipeline.close_session() + return PipelineOutput(...) # aggregated final results + +Each ``generate_step()`` call returns a list of ``GenerateStepOutput`` carrying +the partial text, ASR text, and audio produced by that single chunk. The +``PipelineOutput`` returned after ``close_session()`` carries the aggregated +results for the entire session. + +``generate_step()`` is the unified entry point used by **both** the batch +``run()`` method and server deployments. + + +What Happens Inside One Step +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: text + + generate_step(frames) + │ + ├─ for each frame where is_first=True: + │ │ + │ └─ _init_state(stream_id, options) + │ 1. augment_with_defaults() ← fill None fields from YAML + │ 2. create_state(options) ← pipeline-level state + │ 3. create context_manager ← KV caches, decode slots + │ 4. prefill system prompt ← populate LLM KV cache + │ + ├─ any frames with audio? + │ │ + │ NO → return empty outputs (server prefill-only request) + │ │ + │ YES ↓ + │ + └─ generate_step_for_frames() + 1. 
audio buffering + 2. perception encoder + 3. per-frame LLM loop + 4. per-frame TTS + 5. codec decode + 6. state updates + output accumulation + 7. return list[GenerateStepOutput] + +Each call to ``generate_step(frames)`` performs: + +1. **Stream init on** ``is_first`` -- If a frame has ``is_first=True``, the + private ``_init_state()`` method runs: per-stream options are merged with + pipeline defaults (via ``S2SRequestOptions.augment_with_defaults()``), + a fresh ``S2SStreamingState`` is created, the context manager is + allocated, and the LLM KV cache is prefilled with the system prompt and + TTS speaker embedding. This mirrors ASR's ``init_state()`` called inside + ``transcribe_step()``. If the frame carries no audio (zero-length + samples), the method returns after init — this is the recommended + pattern for latency-sensitive server deployments (see + :ref:`init-and-latency` below). + +2. **Audio buffering** -- ``BatchedAudioBufferer`` (reused from ASR + infrastructure) maintains a sliding window of ``buffer_size_in_secs``. + +3. **Model inference** via ``infer_one_step(audio_buffer, state)``: + + a. **Perception** -- The audio buffer is encoded by the streaming + FastConformer encoder into frame embeddings. + b. **Per-frame LLM loop** -- For each of the ``num_frames_per_chunk`` + frames, the pipeline builds an input embedding (user audio + + previous-step text/ASR tokens), runs it through the LLM, and obtains + predicted text and ASR tokens. + c. **Per-frame TTS** -- Each predicted text token is fed into the EarTTS + model to produce audio codec codes. + d. **Codec decode** -- The accumulated codes are decoded into a waveform. + +4. **State updates** -- The context manager advances ``frame_idx`` and + updates the subword mask. + +5. **Output accumulation** -- Decoded audio and text are appended to the + per-stream ``S2SStreamingState``. 
+ + +Two Kinds of State +^^^^^^^^^^^^^^^^^^ + +The pipeline maintains two separate state objects per stream: + +**StreamingDecodeState** (model level) + Lives in ``S2SContextManager`` slots. Contains LLM KV cache, TTS KV + cache, perception cache, codec cache, token workspaces (``gen_text``, + ``gen_asr_text``), and ``frame_idx``. Created by the wrapper, mutated + in-place by ``infer_one_step()``, destroyed at end-of-stream. + +**S2SStreamingState** (pipeline level) + Lives in the pipeline's ``_state_pool``. Accumulates generated audio + chunks, text strings, and word timings across steps. Kept alive until + ``close_session()`` so the final ``PipelineOutput`` can be assembled. + + +Inference Backends +^^^^^^^^^^^^^^^^^^ + +NemotronVoiceChat has two inference components that each need a backend: + +- **LLM** (DuplexSTT backbone) -- takes audio embeddings from the perception + encoder and predicts text tokens, ASR tokens, and optional function-call + tokens at each frame. +- **TTS** (EarTTS) -- takes the predicted text token and produces audio codec + codes (RVQ acoustic tokens). + +Each component can run on **native PyTorch** or **vLLM**, selected by the +``engine_type`` config value: + +.. list-table:: + :header-rows: 1 + :widths: 35 30 30 + + * - ``engine_type`` + - LLM backend + - TTS backend + * - ``native`` + - ``PyTorchLLM`` + - ``PyTorchEarTTS`` + * - ``vllm_llm`` + - ``VLLMLLM`` + - ``PyTorchEarTTS`` + * - ``vllm_eartts`` + - ``PyTorchLLM`` + - ``VLLMEarTTS`` + * - ``vllm_llm_vllm_eartts`` + - ``VLLMLLM`` + - ``VLLMEarTTS`` + +All four backend classes implement the same ``ModelInterface`` ABC (defined in +``inference.model_wrappers.backend.interface``), so the inference wrapper +(``NemotronVoicechatInferenceWrapper``) can treat them uniformly via two +attributes: + +- ``model_llm_interface`` -- the LLM backend +- ``model_eartts_interface`` -- the TTS backend + +The backend classes live under ``inference.model_wrappers.backend/``: + +.. 
code-block:: text + + backend/ + interface.py # ModelInterface ABC + shared sampling + pytorch/ + model.py # PyTorchLLM (wraps DuplexSTT forward pass) + eartts.py # PyTorchEarTTS (wraps DuplexEARTTS.infer_codes_one_step) + vllm/ + base.py # VLLMModelBase (engine lifecycle, async loop) + llm.py # VLLMLLM (DuplexSTT via vLLM) + eartts.py # VLLMEarTTS (EarTTS via vLLM) + factory.py # create_model() dispatches to the right class + +``ModelInterface`` provides shared utilities used by the LLM backends: +top-p (nucleus) sampling, repetition penalty, and temperature scaling. +These are applied **post-hoc** on the returned logits -- vLLM internally +runs with ``skip_sampling=True`` and greedy decoding. + +Each backend also exposes lifecycle methods that the wrapper calls uniformly: + +- ``prefill_prompt(embeddings, ...)`` -- Warm up KV cache (native) or + prefill the vLLM engine with system-prompt embeddings before streaming. +- ``compile()`` -- Apply ``torch.compile`` to the TTS backbone (native + only; no-op for vLLM). +- ``setup_subword_cache(cfg)`` -- Enable the TTS subword embedding cache + (native only; no-op for vLLM). + +The ``factory.create_model()`` function is the single entry point that +dispatches to the correct class based on a per-component ``engine_type`` +string (``native_llm``, ``native_eartts``, ``vllm_llm``, ``vllm_eartts``). + +vLLM Integration Details +"""""""""""""""""""""""" + +When ``engine_type`` includes ``vllm``, the pipeline loads vLLM engines +**in-process** alongside the native PyTorch components -- there is no +disaggregated multi-server setup. Each vLLM component runs as an +``AsyncLLM`` engine in the same Python process, sharing GPU memory with the +native perception encoder and codec decoder. + +The vLLM engines manage their own KV caches via PagedAttention. 
Both +``VLLMLLM`` and ``VLLMEarTTS`` inherit from ``VLLMModelBase``, which wraps a +``CustomInputAsyncVLLMEngine`` (defined in ``inference.vllm.streaming_llm_engine``) +and provides: + +- An internal ``asyncio`` event loop for blocking synchronous calls +- Request lifecycle management (start, abort, restart) +- Automatic checkpoint conversion to vLLM format on first use + +``CustomInputAsyncVLLMEngine`` is a thin wrapper around vLLM's ``AsyncLLM`` +that adds support for custom input tensor specifications (multi-tensor +inputs like audio embeddings, subword IDs, speaker latents). The +``engine_kind`` parameter (``"llm"`` or ``"eartts"``) selects EarTTS-specific +runtime settings (TRITON attention backend, TF32 precision, guidance scale) +without introducing inheritance between TTS and LLM engine classes. + +This requires a custom vLLM fork with NemotronVoiceChat model support: + +.. code-block:: bash + + pip install git+https://github.com/vklimkov-nvidia/vllm@vklimkov/voicechat + + +Configuration +------------- + +The streaming inference configuration is defined in +``examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml``. + +Key configuration groups: + +S2S Model Settings (``s2s``) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. list-table:: + :header-rows: 1 + :widths: 30 15 55 + + * - Parameter + - Default + - Description + * - ``model_path`` + - (required) + - Path to the NemotronVoiceChat HuggingFace checkpoint. + * - ``engine_type`` + - (required) + - ``native``, ``vllm_llm``, ``vllm_eartts``, or + ``vllm_llm_vllm_eartts``. + * - ``speaker_name`` + - ``null`` + - Registered speaker name (must match a speaker in the checkpoint). + * - ``system_prompt`` + - (required) + - Text injected into the LLM KV cache before audio streaming begins. + * - ``compute_dtype`` + - ``bfloat16`` + - Precision for LLM/embedding layers. + * - ``use_perception_cache`` + - ``true`` + - Cache-aware streaming for the perception encoder. 
+ * - ``use_llm_cache`` + - ``true`` + - Use KV cache for incremental LLM decoding. + * - ``top_p`` + - ``0.5`` + - Top-p sampling threshold. + * - ``temperature`` + - ``0.3`` + - Sampling temperature. + * - ``repetition_penalty`` + - ``1.1`` + - Repetition penalty applied to previously generated tokens. + * - ``deterministic`` + - ``false`` + - Force deterministic mode (native engine only). + * - ``profile_timing`` + - ``false`` + - Insert ``torch.cuda.synchronize()`` around each stage for accurate + per-stage timing. Disabled by default to avoid GPU stalls. + +Streaming Settings (``streaming``) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. list-table:: + :header-rows: 1 + :widths: 30 15 55 + + * - Parameter + - Default + - Description + * - ``chunk_size_in_secs`` + - (required) + - Audio processed per inference step. Must be a multiple of 0.08 s. + * - ``buffer_size_in_secs`` + - (required) + - Sliding-window size passed to the perception encoder. + * - ``batch_size`` + - ``1`` + - Number of concurrent streams (currently only 1 is supported). + * - ``max_len`` + - ``8192`` + - Maximum number of frames per stream. + +Padding Settings (top-level) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +At most one of these may be set: + +.. list-table:: + :header-rows: 1 + :widths: 30 15 55 + + * - Parameter + - Default + - Description + * - ``pad_audio_to_sec`` + - ``null`` + - Pad each input to a fixed duration. + * - ``pad_silence_ratio`` + - ``null`` + - Append silence equal to this fraction of the original duration. + * - ``pad_audio_by_sec`` + - ``null`` + - Append a fixed number of extra seconds of silence. + + +Server Integration +------------------ + +The same ``generate_step()`` method used by ``run()`` can be called directly +from a custom server. It returns a list of ``GenerateStepOutput`` objects +(one per input frame) carrying the **incremental** audio and text produced +by this step — no need to diff against accumulated state: + +.. 
code-block:: python + + from nemo.collections.speechlm2.inference import GenerateStepOutput + + # 1. Init stream (empty audio so prefill completes before recording) + init_frame = Frame( + samples=torch.empty(0), + stream_id=stream_id, + is_first=True, is_last=False, + options=S2SRequestOptions(system_prompt=prompt, top_p=0.9), + ) + pipeline.generate_step([init_frame]) + # -> client can now start recording + + # 2. Stream audio chunks and consume incremental outputs + for i, chunk in enumerate(audio_source): + frame = Frame( + samples=chunk, + stream_id=stream_id, + is_first=False, is_last=(i == last), + ) + outputs = pipeline.generate_step([frame]) + for out in outputs: + send_to_client(out.audio, out.text, out.asr_text) + +Per-stream options (``system_prompt``, ``top_p``, ``temperature``, +``repetition_penalty``) are attached to the ``is_first`` frame via +``S2SRequestOptions``. Any field left as ``None`` falls back to the +pipeline-level YAML default through ``augment_with_defaults()``. + +.. _init-and-latency: + +Init and Latency +^^^^^^^^^^^^^^^^ + +When ``generate_step`` sees ``is_first``, it always runs stream +initialization (context creation, KV-cache prefill). If the frame also +carries audio, inference runs immediately after init in the same call. + +For **latency-sensitive** server deployments (real-time voice chat), +prefill can take hundreds of milliseconds or even multiple seconds. +Clients should send ``is_first`` with **empty audio**, wait for the +response confirming init is done, and only then start recording the +user's microphone. This prevents audio from queuing up during the +expensive prefill phase. + +For **batch/offline** usage (CLI ``run()``), there is no real-time +constraint. The first frame carries both ``is_first`` and real audio, +so init and first-chunk processing happen in one call with no extra +round-trip. 
+ +The pipeline makes no distinction between these cases — it initializes +on ``is_first`` and processes whatever audio is present. The latency +trade-off is entirely the caller's choice. + + +Batch Size +---------- + +The pipeline currently supports ``batch_size=1`` (one stream at a time). + + +File Layout +----------- + +.. code-block:: text + + nemo/collections/speechlm2/inference/ + ├── __init__.py # Public exports + ├── factory/ + │ └── s2s_pipeline_builder.py # S2SPipelineBuilder + ├── pipelines/ + │ ├── s2s_pipeline_interface.py # Base: _state_pool, sessions + │ └── streaming_s2s_pipeline.py # StreamingS2SPipeline + ├── model_wrappers/ + │ ├── decode_state.py # StreamingDecodeState, InferenceStepResult + │ ├── nemotron_voicechat_inference_wrapper.py + │ ├── model_factory.py # Native / vLLM model interfaces + │ └── perception_cache.py # Perception cache + CUDA graphs + ├── streaming/ + │ ├── framing/ + │ │ └── s2s_request_options.py # S2SRequestOptions + │ └── state/ + │ ├── s2s_state.py # S2SStreamingState + │ └── s2s_context_manager.py # Slot-based decode-state lifecycle + ├── utils/ + │ ├── pipeline_utils.py # PipelineOutput, text helpers + │ └── audio_data.py # Manifest / folder loading + └── vllm/ # Optional vLLM engine backend diff --git a/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml new file mode 100644 index 000000000000..ef1177642a93 --- /dev/null +++ b/examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml @@ -0,0 +1,109 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Runtime args (overridden via CLI) +# audio_file: "/path/to/your/audio_or_directory" +audio_file: ??? +output_dir: ./generated + +# Input audio padding (set at most one; null = disabled) +pad_audio_to_sec: null # Pad each audio to this fixed duration (seconds) +pad_silence_ratio: null # Append silence = ratio * original duration (e.g. 0.2 = 20%) +pad_audio_by_sec: null # Append this fixed number of extra seconds of silence + +pipeline_type: s2s_streaming + +# S2S model block +s2s: + model_path: ??? + speaker_reference: null + speaker_name: null + engine_type: ??? # Engine type: 'native' or 'vllm_llm' or 'vllm_eartts' or 'vllm_llm_vllm_eartts' + vllm_llm_config: + model_path: ${s2s.model_path} # Inherits from s2s.model_path + max_model_len: 8192 # Maximum sequence length for vLLM + max_num_batched_tokens: 768 # Max tokens per forward pass (prefill chunk size) + gpu_memory_utilization: 0.35 # GPU memory utilization (0.0-1.0) + dtype: bfloat16 # Data type for vLLM inference + engine_path: null # Path to vLLM engine (null = auto-convert if needed) + pretrained_llm: ${s2s.model_path} + + vllm_tts_config: + model_path: ${s2s.vllm_llm_config.model_path} # Inherits from s2s.model_path + max_model_len: ${s2s.vllm_llm_config.max_model_len} + max_num_batched_tokens: ${s2s.vllm_llm_config.max_num_batched_tokens} + gpu_memory_utilization: ${s2s.vllm_llm_config.gpu_memory_utilization} + dtype: float32 # EarTTS requires float32 for proper audio quality (bfloat16 causes hallucinations) + engine_path: null + pretrained_llm: null + skip_tokenizer_init: true + + device: 
cuda + # ======================== + # Device Configuration + # ======================== + device_id: 0 # GPU device ID + compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, + # 'float16' for older GPUs + # 'float32' + # ======================== + # Inference settings + # ======================== + use_perception_cache: true # Enable cache-aware streaming for perception encoder + use_perception_cudagraph: true # Enable CUDA graph-accelerated perception encoder + use_llm_cache: true # Use KV cache for the STT LLM (DynamicCache or HybridMambaAttentionDynamicCache) + + # TTS speedup flags (default to false; enable to speed up native inference) + use_tts_torch_compile: false # Compile TTS backbone with torch.compile (mode='default') + use_tts_subword_cache: false # Cache CharAwareSubwordEncoder embeddings (skip backbone for repeated tokens) + + # sampling parameters. if set all to 1.0, it will be greedy decoding. + top_p: 0.5 + repetition_penalty: 1.1 + temperature: 0.3 + force_turn_taking: true + force_turn_taking_threshold: 40 + force_turn_taking_pad_window: 25 + + # Inference logit boosts (applied to model logits at inference time) + inference_user_pad_boost: 0.8 # Boost ASR pad logit + inference_user_bos_boost: null # Boost ASR BOS logit + inference_user_eos_boost: null # Boost ASR EOS logit + + system_prompt: ??? + tts_system_prompt: null # TTS system prompt - conditions TTS generation style + # Requires a checkpoint trained with individual TTS prompts + + # ======================== + # Precision & determinism + # ======================== + # These defaults match what was used during model training. + matmul_precision: medium # Matrix multiplication precision: highest, high, medium + allow_tf32: true # Allow TF32 for cuDNN and CUDA matmul (Ampere+ GPUs). + # Set to false for stricter float32 precision. + # Deterministic inference (native engine only). 
Ensures identical results across + # runs by disabling FlashAttention and forcing deterministic CUDA algorithms. + # Trade-offs: slower inference, might produce worse results than non-deterministic mode, + # since non-deterministic mode was used in training. + deterministic: false + +streaming: + input_sample_rate: 16000 # Audio sample rate in Hz + output_sample_rate: 22050 # Audio sample rate in Hz + batch_size: 1 # Number of audio frames per batch + att_context_size: [70,0] # Attention context size: [70,13],[70,6],[70,2],[70,0] + chunk_size_in_secs: ??? # Needs to be multiple of 80ms + buffer_size_in_secs: ??? # Audio buffer size in seconds (larger = more context, better quality) + request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming + max_len: 8192 diff --git a/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py new file mode 100644 index 000000000000..63ec63d91bdc --- /dev/null +++ b/examples/speechlm2/nemo_inference_pipelines/s2s_streaming_infer.py @@ -0,0 +1,95 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +S2S Streaming Inference Client + +Usage: + python s2s_streaming_infer.py \ + audio_file=/path/to/audio_or_directory \ + s2s.model_path=/path/to/eartts_ckpt \ + s2s.speaker_reference=/path/to/speaker.wav \ + streaming.chunk_size_in_secs=0.08 \ + streaming.buffer_size_in_secs=5.6 +""" + +import hydra +from omegaconf import DictConfig + +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.utils.stepprogressbar import StepProgressBar +from nemo.collections.speechlm2.inference.utils.audio_data import ( + calculate_durations_incl_padding, + dump_output, + prepare_audio_data, +) +from nemo.collections.speechlm2.inference.utils.pipeline_utils import clean_pred_text +from nemo.utils import logging +from nemo.utils.timers import SimpleTimer + + +@hydra.main(config_path="./conf", config_name="s2s_streaming", version_base=None) +def main(cfg: DictConfig): + default_system_prompt = cfg.get("s2s", {}).get("system_prompt", None) + audio_filepaths, options, ground_truths = prepare_audio_data( + cfg.audio_file, default_system_prompt=default_system_prompt, sort_by_duration=False, + ) + logging.info(f"Found {len(audio_filepaths)} audio files to generate") + + pipeline = S2SPipelineBuilder.build_pipeline(cfg) + + progress_bar = StepProgressBar.from_audio_filepaths( + audio_filepaths, + chunk_size_in_secs=pipeline.chunk_size_in_secs, + pad_audio_to_sec=cfg.get("pad_audio_to_sec"), + pad_silence_ratio=cfg.get("pad_silence_ratio"), + pad_audio_by_sec=cfg.get("pad_audio_by_sec"), + ) + + timer = SimpleTimer() + timer.start() + output = pipeline.run(audio_filepaths, options=options, progress_bar=progress_bar) + timer.stop() + exec_dur = timer.total_sec() + logging.info(f"Generated {len(audio_filepaths)} files in {exec_dur:.2f}s") + + data_dur = sum(calculate_durations_incl_padding( + audio_filepaths, cfg.get("pad_audio_to_sec"), 
cfg.get("pad_silence_ratio"), cfg.get("pad_audio_by_sec"), + )) + rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf') + logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)") + + # Compute WER when ground-truth texts are available (micro-average, + # matching the offline eval in speechlm2.parts.metrics.asr_cer_wer) + asr_texts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths) + all_refs, all_hyps = [], [] + for gt, asr_text in zip(ground_truths, asr_texts): + if gt and asr_text: + cleaned_gt = clean_pred_text(gt) + cleaned_pred = clean_pred_text(asr_text) + if cleaned_gt.strip() and cleaned_pred.strip(): + all_refs.append(cleaned_gt) + all_hyps.append(cleaned_pred) + if all_refs: + wer = word_error_rate(hypotheses=all_hyps, references=all_refs) + logging.info(f"WER: {wer:.4f} ({wer * 100:.2f}%), n={len(all_refs)}") + + output_dir = cfg.get("output_dir", "./generated") + dump_output(audio_filepaths, output, output_dir, options, ground_truths) + logging.info(f"Transcriptions written to {output_dir}/output_processed.json and {output_dir}/output_raw.json") + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/speechlm2/inference/__init__.py b/nemo/collections/speechlm2/inference/__init__.py new file mode 100644 index 000000000000..575d7a95e8bc --- /dev/null +++ b/nemo/collections/speechlm2/inference/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.model_wrappers.decode_state import ( + InferenceStepResult, + StreamingDecodeState, +) +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import GenerateStepOutput, StreamingS2SPipeline +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions +from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput diff --git a/nemo/collections/speechlm2/inference/factory/__init__.py b/nemo/collections/speechlm2/inference/factory/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/factory/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/factory/s2s_pipeline_builder.py b/nemo/collections/speechlm2/inference/factory/s2s_pipeline_builder.py new file mode 100644 index 000000000000..c943ba8b14d8 --- /dev/null +++ b/nemo/collections/speechlm2/inference/factory/s2s_pipeline_builder.py @@ -0,0 +1,47 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf.dictconfig import DictConfig + +from nemo.utils import logging as logger +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline +from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import NemotronVoicechatInferenceWrapper + + +class S2SPipelineBuilder: + """Factory that builds a streaming S2S pipeline.""" + + @classmethod + def build_pipeline( + cls, + cfg: DictConfig + ) -> StreamingS2SPipeline: + """ + Build the streaming S2S pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns StreamingS2SPipeline object + """ + s2s_model = NemotronVoicechatInferenceWrapper(model_cfg=cfg.s2s) + + logger.info(f"S2S model `{cfg.s2s.model_path}` loaded") + + pipeline = StreamingS2SPipeline( + cfg, + s2s_model, + ) + logger.info(f"`{type(pipeline).__name__}` pipeline loaded") + return pipeline + diff --git a/nemo/collections/speechlm2/inference/model_wrappers/__init__.py b/nemo/collections/speechlm2/inference/model_wrappers/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py new file mode 100644 index 000000000000..ee2a19e2abd3 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface +from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.model import PyTorchLLM +from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.eartts import PyTorchEarTTS, TTSGenerationResult + +try: + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.llm import VLLMLLM + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.eartts import VLLMEarTTS +except ImportError: + pass # vLLM is an optional dependency diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/interface.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/interface.py new file mode 100644 index 000000000000..ba69eb34740a --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/interface.py @@ -0,0 +1,255 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base interface for S2S model inference backends. + +NemotronVoiceChat has two main inference components that can each run on +either native PyTorch or vLLM: + +- **LLM** (inside DuplexSTT): takes audio embeddings, produces text and ASR + tokens. Wrapped as ``model_llm_interface`` in the inference wrapper. +- **TTS** (EarTTS): takes text tokens, produces audio codec codes. + Wrapped as ``model_eartts_interface`` in the inference wrapper. 
+ +This module defines the abstract ``ModelInterface`` that both backends +implement, with shared sampling utilities (top-p, repetition penalty, +temperature). +""" + +from abc import ABC, abstractmethod +from typing import Any +import math +import torch + +from nemo.utils import logging + + +class ModelInterface(ABC): + """ + Abstract base class for LLM and TTS inference backends. + + Concrete implementations wrap either the LLM component (DuplexSTT backbone + that predicts text/ASR tokens from audio embeddings) or the TTS component + (EarTTS that generates audio codec codes from text tokens). Each component + can run on native PyTorch or vLLM. + + Provides shared sampling utilities (top-p, repetition penalty, temperature) + and lifecycle methods (compile, prefill, subword cache) so the inference + wrapper can treat all backends uniformly. + """ + + def __init__( + self, + special_token_ids: set[int] | None = None, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + ): + """ + Initialize base interface with sampling parameters. + + Args: + special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. + These tokens will use greedy decoding and won't be penalized. + top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 + repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 + temperature: Temperature for sampling. 1.0 = no change, <1.0 = sharper, >1.0 = flatter. + 0.0 = greedy (argmax). 
Default: 1.0 + """ + if not math.isfinite(temperature): + raise ValueError(f"temperature must be finite, got {temperature}") + if temperature < 0.0: + raise ValueError(f"temperature must be >= 0.0, got {temperature}") + + self.special_token_ids = special_token_ids if special_token_ids is not None else set() + self.top_p = top_p + self.repetition_penalty = repetition_penalty + self.temperature = temperature + + # Pre-built tensor for special-token filtering in repetition penalty. + # Lazily moved to the right device on first use (see _sample_text_token). + self._special_ids_tensor: torch.Tensor | None = ( + torch.tensor(sorted(self.special_token_ids), dtype=torch.long) + if self.special_token_ids else None + ) + + def _sample_text_token( + self, + logits: torch.Tensor, + generated_tokens: torch.Tensor, + current_step: int, + sampling_params: dict[str, float] | None = None, + ) -> torch.Tensor: + """ + Sample text tokens with optional top-p sampling and repetition penalty. + Special tokens (pad, eos, bos) bypass sampling - if they have highest probability, return them directly. + + Args: + logits: Logits tensor of shape (B, V) for vocabulary V. + generated_tokens: Previously generated tokens of shape (B, T). + current_step: Current decoding step (used to slice generated_tokens). + sampling_params: Optional per-request overrides for ``top_p``, + ``temperature``, and ``repetition_penalty``. Missing keys + fall back to ``self.*`` (the pipeline-level defaults). + + Returns: + Sampled token ids of shape (B,). 
+ """ + top_p = sampling_params.get("top_p", self.top_p) if sampling_params else self.top_p + temperature = sampling_params.get("temperature", self.temperature) if sampling_params else self.temperature + rep_penalty = sampling_params.get("repetition_penalty", self.repetition_penalty) if sampling_params else self.repetition_penalty + + B, V = logits.shape + device = logits.device + + # First check greedy tokens (on original logits) + greedy_tokens = logits.argmax(dim=-1) # (B,) + + # If no sampling needed (all disabled), return greedy + if top_p >= 1.0 and rep_penalty == 1.0 and (temperature == 1.0 or temperature == 0.0): + return greedy_tokens + + # temperature=0 means greedy + if temperature == 0.0: + return greedy_tokens + + # For each batch, if greedy is special token, use greedy; otherwise sample + sampled_tokens = greedy_tokens.clone() + + # Ensure cached special-token tensor is on the right device (once). + if self._special_ids_tensor is not None and self._special_ids_tensor.device != device: + self._special_ids_tensor = self._special_ids_tensor.to(device) + + for b in range(B): + # If greedy token is a special token, keep it (no sampling) + if greedy_tokens[b].item() in self.special_token_ids: + continue + + # Not a special token - apply repetition penalty and sampling + batch_logits = logits[b].clone() # (V,) + + # Apply repetition penalty (vectorized, no Python loop) + if rep_penalty != 1.0 and current_step > 0: + unique_prev = generated_tokens[b, :current_step].unique() + # Exclude special tokens from penalty + if self._special_ids_tensor is not None: + ids_t = self._special_ids_tensor + if ids_t.device != unique_prev.device: + ids_t = ids_t.to(unique_prev.device) + unique_prev = unique_prev[~torch.isin(unique_prev, ids_t)] + + if unique_prev.numel() > 0: + if unique_prev.device != batch_logits.device: + unique_prev = unique_prev.to(batch_logits.device) + prev_logits = batch_logits[unique_prev] + # Positive logits are divided, negative logits are 
multiplied + # (same as the standard repetition_penalty convention) + batch_logits[unique_prev] = torch.where( + prev_logits > 0, + prev_logits / rep_penalty, + prev_logits * rep_penalty, + ) + + # Apply temperature scaling + if temperature != 1.0: + batch_logits = batch_logits / temperature + + # Fall back to greedy if logits are non-finite before top-p + # (top-p intentionally introduces -inf, so check must happen first) + if not torch.isfinite(batch_logits).all(): + logging.warning( + f"_sample_text_token: logits contain NaN or inf at step {current_step}, batch {b}: " + f"nan={batch_logits.isnan().sum().item()}, " + f"inf={batch_logits.isinf().sum().item()}, " + f"min={batch_logits[~batch_logits.isnan()].min().item() if not batch_logits.isnan().all() else 'all_nan'}, " + f"max={batch_logits[~batch_logits.isnan()].max().item() if not batch_logits.isnan().all() else 'all_nan'}" + ) + sampled_tokens[b] = greedy_tokens[b] + continue + + # Apply top-p (nucleus) sampling + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + + # Remove tokens with cumulative prob > top_p, keeping at least one + sorted_indices_to_remove = cumulative_probs > top_p + # Shift to keep the first token that exceeds threshold + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() + sorted_indices_to_remove[0] = False + + # Set to -inf + indices_to_remove = sorted_indices[sorted_indices_to_remove] + batch_logits[indices_to_remove] = float('-inf') + + probs = torch.softmax(batch_logits, dim=-1) + sampled_tokens[b] = torch.multinomial(probs, num_samples=1).item() + + return sampled_tokens + + @abstractmethod + def __call__( + self, + input_embeds: torch.Tensor, + cache: Any | None = None, + **kwargs + ) -> dict[str, Any]: + """ + Perform model inference. 
+ + Args: + input_embeds: Input embeddings tensor of shape [batch, seq_len, hidden_dim] + cache: Optional cache object (e.g., DynamicCache for transformers) + **kwargs: Additional model-specific arguments + + Returns: + Dictionary containing: + - 'text_logits': Logits for text generation [batch, seq_len, vocab_size] + - 'cache': Updated cache object (if cache was provided) + - Additional model-specific outputs + """ + pass + + @abstractmethod + def to(self, device_or_dtype: torch.device | torch.dtype) -> 'ModelInterface': + """Move model to specified device or convert to specified dtype.""" + pass + + @abstractmethod + def eval(self) -> 'ModelInterface': + """Set model to evaluation mode.""" + pass + + @property + @abstractmethod + def device(self) -> torch.device: + """Get the device of the model.""" + pass + + def compile(self, **kwargs) -> None: + """Apply torch.compile optimizations. No-op by default; override in subclasses.""" + pass + + def setup_subword_cache(self, cfg) -> None: + """Enable TTS subword embedding cache. No-op by default; override in subclasses.""" + pass + + def prefill_prompt(self, embeddings, **kwargs): + """Prefill the model with prompt embeddings before streaming begins. + + Override in subclasses to implement engine-specific prefill logic. + """ + raise NotImplementedError(f"{type(self).__name__} does not implement prefill_prompt") diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/__init__.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/eartts.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/eartts.py new file mode 100644 index 000000000000..fe4c4f6df4f9 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/eartts.py @@ -0,0 +1,134 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Native PyTorch backend for the TTS (EarTTS) component of NemotronVoiceChat. + +Wraps the DuplexEARTTS model for direct PyTorch inference, providing the +same ``ModelInterface`` API as ``VLLMEarTTS`` so the inference wrapper can +treat both backends uniformly. + +Used as ``model_eartts_interface`` in the inference wrapper when +``engine_type="native_eartts"``. 
+""" + +from typing import Any +from dataclasses import dataclass, fields +import torch + +from nemo.utils import logging +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface + + +@dataclass +class TTSGenerationResult: + """Result from a single TTS generation step (shared by PyTorch and vLLM backends).""" + codes: torch.Tensor # Generated acoustic tokens + past_key_values: Any # Updated cache (if applicable) + + def __getitem__(self, item: str | int): + """Allows for accessing attributes by key or index.""" + if isinstance(item, str): + return getattr(self, item) + else: + return getattr(self, fields(self)[item].name) + + +class PyTorchEarTTS(ModelInterface): + """ + Native PyTorch backend for the TTS (EarTTS) component. + + Wraps ``DuplexEARTTS.infer_codes_one_step()`` for per-frame audio codec + generation, conforming to the ModelInterface contract. + """ + + def __init__(self, tts_model): + """ + Args: + tts_model: A ``DuplexEARTTS`` instance (``NemotronVoiceChat.tts_model``). + """ + super().__init__() + self.tts_model = tts_model + + def __call__(self, inputs: dict, **kwargs) -> TTSGenerationResult: + """ + Run one TTS code-generation step via ``infer_codes_one_step``. + + Args: + inputs: Keyword arguments for ``DuplexEARTTS.infer_codes_one_step`` + (current_subword_id, prev_subword_id, current_subword_mask, + prev_audio_tokens, past_key_values, guidance_enabled, etc.) + + Returns: + TTSGenerationResult with generated codes and updated cache. + """ + codes, cache = self.tts_model.infer_codes_one_step(**inputs) + return TTSGenerationResult(codes=codes, past_key_values=cache) + + def prefill_prompt(self, init_inputs, prompt_token_ids=None, request_id=None, **kwargs): + """Prefill TTS with speaker embedding / warmup inputs. + + For native PyTorch, this calls the inner ``tts_model`` directly + (the actual EarTTS nn.Module, not the DuplexEARTTS wrapper). 
+ + Args: + init_inputs: Dict of initial TTS inputs from ``get_init_inputs()``. + prompt_token_ids: Unused for native (vLLM-only parameter). + request_id: Unused for native (vLLM-only parameter). + + Returns: + Model outputs (with ``past_key_values`` and ``codes``). + """ + return self.tts_model.tts_model(**init_inputs) + + def compile(self, **kwargs) -> None: + """Apply torch.compile to the TTS backbone if available.""" + tts_backbone = getattr(self.tts_model, 'tts_model', None) + if tts_backbone is not None and hasattr(tts_backbone, 'backbone'): + mode = kwargs.get('mode', 'default') + logging.info(f"Compiling TTS backbone with torch.compile(mode='{mode}')...") + tts_backbone.backbone = torch.compile(tts_backbone.backbone, mode=mode) + logging.info(" TTS backbone compiled") + + def setup_subword_cache(self, cfg) -> None: + """Enable TTS subword embedding cache from config flags.""" + from omegaconf import OmegaConf + + tts_inner = getattr(self.tts_model, 'tts_model', None) + if tts_inner is None or not hasattr(tts_inner, 'config'): + return + if bool(cfg.get("use_tts_subword_cache", False)): + OmegaConf.update(tts_inner.config, "use_tts_subword_cache", True) + logging.info("TTS speedup enabled: use_tts_subword_cache") + embed_subword = getattr(tts_inner, 'embed_subword', None) + if embed_subword is not None and hasattr(embed_subword, 'use_tts_subword_cache'): + embed_subword.use_tts_subword_cache = True + + def to(self, device_or_dtype: torch.device | torch.dtype) -> 'PyTorchEarTTS': + """Move underlying TTS model to device or convert dtype.""" + self.tts_model = self.tts_model.to(device_or_dtype) + return self + + def eval(self) -> 'PyTorchEarTTS': + """Set underlying TTS model to eval mode.""" + self.tts_model.eval() + return self + + @property + def device(self) -> torch.device: + """Get device of the underlying TTS model.""" + try: + return next(self.tts_model.parameters()).device + except StopIteration: + return torch.device('cpu') diff --git 
a/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/model.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/model.py new file mode 100644 index 000000000000..d45a1683cb43 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/pytorch/model.py @@ -0,0 +1,235 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Native PyTorch backend for the LLM component of NemotronVoiceChat. + +Wraps the DuplexSTT model (which contains the Nemotron LLM backbone) for +direct PyTorch inference with top-p sampling and repetition penalty support. + +Used as ``model_llm_interface`` in the inference wrapper when +``engine_type="native_llm"``. +""" + +from typing import Any +import torch + +from nemo.utils import logging +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface + + +class PyTorchLLM(ModelInterface): + """ + Native PyTorch backend for the LLM (DuplexSTT) component. + + Wraps the DuplexSTT model's forward pass (``stt_model()``) to produce + text/ASR token predictions, conforming to the ModelInterface contract. + Supports top-p sampling and repetition penalty. + """ + + def __init__( + self, + model, + special_token_ids: set[int] | None = None, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + ): + """ + Initialize with an existing model. 
+ + Args: + model: The DuplexS2SExternalSpeechDecoderModel instance + special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. + These tokens will use greedy decoding and won't be penalized. + If None, will try to extract from model.tokenizer for tokens: + '' (bos), '' (eos), '' (pad). + You can also manually provide: {tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id} + top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 + repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 + Recommended value when enabling: 1.2 + temperature: Temperature for sampling. 1.0 = no change, <1.0 = sharper, >1.0 = flatter. + 0.0 = greedy (argmax). Default: 1.0 + """ + DEFAULT_SPECIAL_TOKEN_IDS = {1, 2, 12} + + if special_token_ids is None: + special_token_ids = self._extract_special_token_ids_from_nemo(model) + if not special_token_ids: + special_token_ids = DEFAULT_SPECIAL_TOKEN_IDS + + super().__init__( + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + + self.model = model + + logging.debug(f"Special token IDs: {self.special_token_ids}") + + sampling_active = top_p < 1.0 or repetition_penalty != 1.0 or (temperature != 1.0 and temperature != 0.0) + if sampling_active and not self.special_token_ids: + import warnings + warnings.warn( + "Sampling is enabled but special_token_ids is empty. " + "Could not auto-extract from model.tokenizer. " + "Please provide special_token_ids manually to ensure special tokens use greedy decoding. " + "Otherwise, EOS tokens may be randomly sampled and generation may not stop properly!" 
+ ) + + def __call__( + self, + input_embeds: torch.Tensor, + cache: Any | None = None, + cache_position: torch.Tensor | None = None, + generated_tokens: torch.Tensor | None = None, + current_step: int = 0, + return_logits: bool = False, + sampling_params: dict[str, float] | None = None, + **kwargs + ) -> dict[str, Any]: + """ + Perform inference using the native model. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + cache: Optional DynamicCache or HybridMambaAttentionDynamicCache + cache_position: Optional position tensor for Nemotron models + generated_tokens: Previously generated tokens [batch, num_generated]. + Required for repetition_penalty. If None, creates empty tensor. + current_step: Current decoding step. Used for repetition penalty. + sampling_params: Optional per-request overrides for sampling + (top_p, temperature, repetition_penalty). + **kwargs: Additional arguments passed to the model + + Returns: + Dictionary with 'predicted_token', 'asr_predicted_token', and 'cache' + """ + result = self.model.stt_model(input_embeds, cache=cache, cache_position=cache_position, **kwargs) + + if not isinstance(result, dict): + raise TypeError(f"Model returned {type(result)}, expected dict") + + if 'text_logits' not in result: + raise KeyError("Model output must contain 'text_logits' key") + + text_logits = result["text_logits"][:, -1] # [batch, vocab_size] + batch_size = text_logits.shape[0] + + if generated_tokens is None: + gen_tokens = torch.empty(batch_size, 0, device=text_logits.device, dtype=torch.long) + else: + gen_tokens = generated_tokens + + predicted_token = self._sample_text_token( + logits=text_logits, + generated_tokens=gen_tokens, + current_step=current_step, + sampling_params=sampling_params, + ) + + # ASR tokens use greedy decoding (no sampling) + asr_predicted_token = result["asr_logits"][:, -1].argmax(dim=-1) + + ans = { + "predicted_token": predicted_token, + "asr_predicted_token": asr_predicted_token, + "cache": 
result.get("cache", None), + } + if return_logits: + ans["text_logits"] = result["text_logits"] + ans["asr_logits"] = result.get("asr_logits") + if "function_logits" in result: + ans["function_logits"] = result["function_logits"] + if "function_logits" in result: + ans["function_predicted_token"] = result["function_logits"][:, -1].argmax(dim=-1) + return ans + + @staticmethod + def _extract_special_token_ids_from_nemo(model) -> set[int]: + """ + Extract special token IDs from NeMo model's tokenizer. + + NeMo tokenizer uses bos_token, eos_token, pad_token (not bos_token_id). + Then converts token strings to IDs using token_to_id method. + + Args: + model: The DuplexS2SExternalSpeechDecoderModel instance + + Returns: + Set of special token IDs, or empty set if extraction fails + """ + special_ids = set() + try: + tokenizer = model.stt_model.tokenizer + except AttributeError: + logging.debug("Cannot extract special token IDs: model has no stt_model.tokenizer") + return special_ids + + for attr in ('bos_token', 'eos_token', 'pad_token'): + token = getattr(tokenizer, attr, None) + if token is not None and hasattr(tokenizer, 'token_to_id'): + tid = tokenizer.token_to_id(token) + if tid is not None and isinstance(tid, int): + special_ids.add(tid) + + return special_ids + + def to(self, device_or_dtype: torch.device | torch.dtype) -> 'PyTorchLLM': + """Move underlying model to device or convert dtype.""" + self.model = self.model.to(device_or_dtype) + return self + + def eval(self) -> 'PyTorchLLM': + """Set underlying model to eval mode.""" + self.model.eval() + return self + + @property + def device(self) -> torch.device: + """Get device of the underlying model.""" + try: + return next(self.model.parameters()).device + except StopIteration: + return torch.device('cpu') + + def prefill_prompt(self, embeddings, cache=None, cache_position=None, **kwargs): + """Prefill the native LLM with prompt embeddings to warm up the KV cache. 
+ + Args: + embeddings: Prompt embeddings [batch, seq_len, hidden_dim]. + cache: KV cache object to update in-place. + cache_position: Position tensor for the prompt tokens. + + Returns: + Dictionary with updated 'cache'. + """ + result = self.model.stt_model(embeddings, cache=cache, cache_position=cache_position, **kwargs) + if not isinstance(result, dict): + raise TypeError(f"Model returned {type(result)}, expected dict") + return {"cache": result.get("cache", cache)} + + def __getattr__(self, name: str): + """ + Delegate attribute access to the underlying model. + + This allows transparent access to model attributes like + perception, tokenizer, etc. + """ + if name in ('model', '__dict__', '__class__'): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + return getattr(self.model, name) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/__init__.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py new file mode 100644 index 000000000000..7f917eca2af0 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/base.py @@ -0,0 +1,311 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base class for vLLM-backed inference backends. + +Provides shared vLLM engine lifecycle management (init, abort, restart, +shutdown) used by both the LLM backend (``VLLMLLM`` -- DuplexSTT text +token prediction) and the TTS backend (``VLLMEarTTS`` -- EarTTS audio +codec generation). +""" + +from abc import abstractmethod +from typing import Any +import os +import torch + +from nemo.utils import logging +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface + + +class VLLMModelBase(ModelInterface): + """ + Base class for vLLM-backed model interfaces. + + Wraps a CustomInputAsyncVLLMEngine to provide streaming inference while + conforming to the ModelInterface contract. Supports two usage modes: + + 1. **Blocking component** (default): Call the model synchronously via + ``__call__``. The async engine runs on an internal event loop. + This is the mode used by the S2S streaming inference pipeline + (``StreamingS2SPipeline`` / ``NemotronVoicechatInferenceWrapper``). + + 2. 
    **Async standalone server**: Use ``_async_inference()`` directly
    from your own async event loop for concurrent multi-stream serving
    (e.g., a WebSocket server handling multiple streams concurrently).

    Subclasses must implement:
    - ``_convert_ckpt(save_path)``: checkpoint conversion to vLLM format
    - ``__call__(...)``: synchronous inference entry point
    - ``_process_inputs_to_outputs(...)``: async core inference logic

    Requires vLLM from https://github.com/vklimkov-nvidia/vllm (branch vklimkov/voicechat).
    """

    def __init__(
        self,
        model_path: str,
        max_model_len: int = 1024,
        max_num_batched_tokens: int = 768,
        gpu_memory_utilization: float = 0.8,
        trust_remote_code: bool = True,
        dtype: str = "bfloat16",
        engine_path: str | None = None,
        pretrained_llm: str | None = None,
        special_token_ids: set[int] | None = None,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        model_type: str = "llm",
        **sampling_kwargs
    ):
        """
        Initialize vLLM model with engine creation, event loop setup, and warm-up.

        Args:
            model_path: Path to the vLLM-compatible model checkpoint
            max_model_len: Maximum sequence length
            max_num_batched_tokens: Maximum tokens per vLLM forward pass.
                Controls prefill chunk size and max concurrent decode streams.
            gpu_memory_utilization: GPU memory utilization ratio (0.0-1.0)
            trust_remote_code: Whether to trust remote code in model
            dtype: Data type for embeddings (e.g., "bfloat16", "float16")
            engine_path: Optional path to pre-converted vLLM model
            pretrained_llm: Optional path to pretrained LLM for conversion
            special_token_ids: Set of special token IDs (pad, eos, bos) that bypass
                sampling and always use greedy decoding.
            top_p: Top-p (nucleus) sampling threshold. Default: 1.0 (disabled).
            repetition_penalty: Penalty for repeated tokens. Default: 1.0 (disabled).
            temperature: Sampling temperature. Default: 1.0 (no scaling).
            model_type: Type of model for vLLM engine ("llm", "eartts", etc.)
            **sampling_kwargs: Additional vLLM sampling parameters.

        Note:
            vLLM internally runs greedy decoding (temperature=0, ignore_eos=True).
            Text sampling (top_p, repetition_penalty, temperature) is applied
            post-hoc by ``ModelInterface._sample_text_token`` on the logits
            returned by vLLM, not by vLLM's own sampler.
        """
        # Sampling knobs (top_p / repetition_penalty / temperature) live in the
        # base ModelInterface; they are applied post-hoc, outside vLLM.
        super().__init__(
            special_token_ids=special_token_ids,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        # Imported locally so importing this module does not require asyncio
        # setup or the custom vLLM fork until a backend is actually built.
        import asyncio
        from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import create_engine

        self.model_path = model_path
        self.pretrained_llm = pretrained_llm
        self._dtype = dtype
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Force greedy decoding in vLLM by setting temperature=0 if not specified
        # (see Note above: real sampling happens post-hoc on returned logits).
        if 'temperature' not in sampling_kwargs:
            sampling_kwargs['temperature'] = 0.0

        if engine_path is None:
            # Derive a per-checkpoint cache path under /tmp so repeated runs
            # reuse the converted engine instead of re-converting every time.
            dir_name = os.path.basename(os.path.normpath(model_path))
            engine_path = "/tmp/" + dir_name + f"_vllm_converted_{model_type}"
            # NOTE(review): the exists/convert check is assumed to apply only to
            # this auto-derived path (an explicit engine_path is documented as
            # already converted) -- confirm nesting against the original source.
            if os.path.exists(engine_path):
                logging.info(f"Found existing vLLM converted model at {engine_path}")
            else:
                self._convert_ckpt(save_path=engine_path)

        self.engine = create_engine(
            engine_type=model_type,
            model_path=engine_path,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            **sampling_kwargs
        )
        # Monotonic counter backing _generate_request_id().
        self._request_counter = 0

        # Reuse the current thread's event loop if one exists; otherwise create
        # one so the synchronous entry points can drive the async engine.
        try:
            self._loop = asyncio.get_event_loop()
        except RuntimeError:
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

        logging.info("Initializing vLLM engine (this may take a moment)...")
        self._loop.run_until_complete(self.engine.initialize())

        # If the caller did not pin special token IDs, try to recover them from
        # the engine's own tokenizer so sampling can bypass them.
        if self.engine.engine.tokenizer is not None and not self.special_token_ids:
            self.special_token_ids = self._get_special_token_ids_from_vllm_tokenizer(self.engine.engine.tokenizer)

        logging.debug(f"Special token IDs: {self.special_token_ids}")
        logging.info("vLLM engine ready!")

    @abstractmethod
    def _convert_ckpt(self, save_path: str):
        """Convert existing checkpoint to vLLM format and save."""
        pass

    @staticmethod
    def _get_special_token_ids_from_vllm_tokenizer(tokenizer) -> set[int]:
        """
        Extract special token IDs from a vLLM tokenizer.
        Looks for: '' (bos), '' (eos), '' (pad).

        Args:
            tokenizer: A vLLM CachedTokenizer instance.

        Returns:
            Set of special token IDs.
        """
        special_ids = set()
        # NOTE(review): these three literals are identical empty strings, which
        # makes the loop query the same token three times. They look like
        # markup-stripped bos/eos/pad markers -- confirm the intended token
        # strings against the original checkpoint's tokenizer config.
        for token in ('', '', ''):
            try:
                tid = tokenizer.convert_tokens_to_ids(token)
                # Some tokenizers return None/str sentinels for unknown tokens;
                # only accept genuine integer IDs.
                if isinstance(tid, int):
                    special_ids.add(tid)
            except (KeyError, AttributeError):
                logging.debug(f"Token '{token}' not found in vLLM tokenizer, skipping")
        return special_ids

    def _generate_request_id(self) -> str:
        """Generate a unique request ID."""
        self._request_counter += 1
        return f"vllm_request_{self._request_counter}"

    async def _async_inference(
        self,
        inputs: torch.Tensor | list[torch.Tensor] | dict,
        request_id: str,
        **kwargs
    ) -> dict[str, Any]:
        """
        Async inference using the streaming engine.

        Checks request status (starting or restarting as needed) and
        delegates to the subclass ``_process_inputs_to_outputs``.

        Args:
            inputs: Model inputs (tensor for LLM, dict for EarTTS)
            request_id: Unique request identifier
            **kwargs: Passed through to ``_process_inputs_to_outputs``

        Returns:
            Dictionary with model-specific outputs
        """
        from nemo.collections.speechlm2.inference.vllm.streaming_llm_engine import StreamStatus

        if request_id not in self.engine.requests:
            # First time we see this request: open a fresh generation stream.
            await self.engine.start_generation(request_id=request_id)
        else:
            request_state = self.engine.requests[request_id]
            if request_state.status in (StreamStatus.FINISHED, StreamStatus.ABORTED):
                logging.warning(
                    f"Request {request_id} was {request_state.status.value}. "
                    f"Generated {len(request_state.generated_tokens)} tokens before stopping. "
                    "Cleaning up and restarting..."
                )
                try:
                    await self.engine.abort_generation(request_id)
                except Exception:
                    # The request already finished/aborted; abort_generation
                    # is just releasing engine-side resources. If the engine
                    # already purged it, the call may raise -- harmless since
                    # we start a fresh generation immediately after.
                    pass
                await self.engine.start_generation(request_id=request_id)

        return await self._process_inputs_to_outputs(inputs, request_id, **kwargs)

    @abstractmethod
    async def _process_inputs_to_outputs(self, inputs, request_id: str, **kwargs) -> dict[str, Any]:
        """Process model inputs and return outputs. Subclasses must implement."""
        pass

    def to(self, device_or_dtype: torch.device | torch.dtype) -> 'VLLMModelBase':
        """
        Move model to specified device or convert to specified dtype.

        Note: vLLM manages device placement internally, this is for compatibility.
        """
        # Only the bookkeeping device is recorded; dtype requests are ignored
        # because vLLM fixes the dtype at engine creation time.
        if isinstance(device_or_dtype, torch.device):
            self._device = device_or_dtype
        return self

    def eval(self) -> 'VLLMModelBase':
        """Set model to evaluation mode (vLLM is always in eval mode)."""
        return self

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self._device

    def abort_request(self, request_id: str) -> bool:
        """
        Abort a specific generation request.

        Args:
            request_id: Request identifier to abort

        Returns:
            bool: True if abort was successful
        """
        return self._loop.run_until_complete(
            self.engine.abort_generation(request_id)
        )

    def restart_request(self, request_id: str) -> bool:
        """
        Restart a finished or aborted generation request.

        Args:
            request_id: Request identifier to restart

        Returns:
            bool: True if restart was successful
        """
        # Abort first so the engine releases any state tied to the old stream.
        if request_id in self.engine.requests:
            self.abort_request(request_id)

        return self._loop.run_until_complete(
            self.engine.start_generation(request_id=request_id)
        )

    def get_request_status(self, request_id: str | None = None) -> dict[str, Any]:
        """
        Get status of a specific request or all requests.

        Args:
            request_id: Optional request ID. If None, returns all requests.

        Returns:
            Status dictionary
        """
        return self.engine.get_status(request_id)

    def shutdown(self):
        """Shutdown the vLLM engine and cleanup resources."""
        self._loop.run_until_complete(self.engine.shutdown())

    def __del__(self):
        """Cleanup on deletion."""
        try:
            self.shutdown()
        except Exception:
            # __del__ may run during interpreter shutdown when globals
            # (self._loop, self.engine, asyncio) are already torn down.
            # Nothing useful to do with the error; suppress to avoid
            # noisy tracebacks on exit.
+ pass diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py new file mode 100644 index 000000000000..2f344b0aa9d6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/eartts.py @@ -0,0 +1,191 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +vLLM backend for the TTS (EarTTS) component of NemotronVoiceChat. + +EarTTS generates audio codec codes from text tokens: given the current +subword ID, mask, previous audio codes, and speaker latent, it predicts +the next frame of RVQ acoustic tokens. This module wraps it in a +CustomInputAsyncVLLMEngine for accelerated inference. + +Used as ``model_eartts_interface`` in the inference wrapper. +""" + +from typing import Any +import os +import torch + +from nemo.utils import logging +from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.base import VLLMModelBase +from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.eartts import TTSGenerationResult + + +class VLLMEarTTS(VLLMModelBase): + """ + vLLM backend for the TTS (EarTTS) component. + + Accepts dictionary inputs with codes, subword IDs, masks, and speaker + latents, and returns ``TTSGenerationResult`` with generated acoustic tokens. 
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._speaker_latent_dim = None + logging.info("VLLMEarTTS initialized with EARTTS-specific settings.") + + def _convert_ckpt(self, save_path: str): + """Convert EARTTS checkpoint to vLLM format.""" + from nemo.collections.speechlm2.inference.vllm.scripts.convert_eartts_checkpoint import convert + ckpt_dir = os.path.normpath(self.model_path) + config_file = os.path.join(ckpt_dir, "config.json") + model_ckpt = os.path.join(ckpt_dir, "model.safetensors") + convert(save_path, config_file, model_ckpt) + + def __call__( + self, + inputs: dict[str, torch.Tensor] | None = None, + request_id: str | None = None, + prompt_token_ids: list | None = None, + **kwargs + ) -> TTSGenerationResult: + """ + Perform TTS inference using vLLM streaming engine. + + Supports two calling conventions: + 1. model(inputs_dict, request_id="id") - pass dict as first positional arg + 2. model(**inputs_dict) - unpack dict as keyword arguments + + Args: + inputs: Optional dict of model inputs (if None, uses **kwargs) + request_id: Optional request identifier + prompt_token_ids: Optional list of prompt token IDs for prefill + **kwargs: Model inputs as keyword arguments (used if inputs is None) + + Returns: + TTSGenerationResult containing generated acoustic tokens and cache + """ + if inputs is not None: + input_dict = inputs + else: + if request_id is None: + request_id = kwargs.pop('request_id', None) + input_dict = kwargs + + if request_id is None: + request_id = 'tts_request_id_1' + + result = self._loop.run_until_complete( + self._async_inference(input_dict, request_id, prompt_token_ids=prompt_token_ids) + ) + + return result + + async def _process_inputs_to_outputs( + self, + inputs: dict[str, torch.Tensor], + request_id: str, + prompt_token_ids: list | None = None, + ) -> dict[str, Any]: + """ + Process tensor inputs to generate acoustic tokens via vLLM engine. 
+ + Args: + inputs: Dictionary with code, context_hidden_state, subword_ids, + subword_mask, and optional non_prompt_mask / audio_prompt_lantent. + request_id: Request identifier. + prompt_token_ids: Optional prompt token IDs for prefill. + + Returns: + TTSGenerationResult with generated acoustic tokens. + """ + + assert inputs["context_hidden_state"] is None, "EARTTS vllm model does not support context_hidden_state input" + + codes = inputs["code"].squeeze(0) # T x 31 + if codes.shape[0] > 1: + # In prefill stage, shift acoustic tokens for vLLM to replicate + # the NeMo logic for teacher-forced input construction. + codes = torch.nn.functional.pad(codes[:-1], [0, 0, 1, 0]) + input_tensors = [ + codes, + inputs["subword_ids"].squeeze(0), + inputs["subword_mask"].squeeze(0), + ] + + if "non_prompt_mask" in inputs: + # Apply edge detection to match native model's BOS placement logic: + # BOS should only be applied at the FIRST position where non_prompt_mask is True + non_prompt_mask = inputs["non_prompt_mask"].squeeze(0) # T + padded_prev = torch.nn.functional.pad(non_prompt_mask[:-1], [1, 0], value=False) + bos_mask = (non_prompt_mask & (~padded_prev)).to(dtype=getattr(torch, self._dtype)) + input_tensors.append(bos_mask) + + else: + current_subword_id = input_tensors[1] + # Use a tiny epsilon instead of exact 0 so the vLLM model's + # (bos_mask == 0) check is False during decoding. This prevents + # use_audio_prompt_frozen_projection from incorrectly applying the + # speaker-prompt projection to every decoding step. The epsilon is + # small enough that bos_mask * bos_emb remains negligible. + bos_mask = torch.full_like(current_subword_id, 1e-20, dtype=getattr(torch, self._dtype)) + input_tensors.append(bos_mask) + + # Pass speaker_latent: the pre-extracted speaker embedding. 
+ # During prefill with speaker_name: audio_prompt_lantent is [1, T, hidden_size] + # During decode or speaker_reference: pass zeros so the model falls back + # to computing the latent from acoustic tokens. + if "audio_prompt_lantent" in inputs and inputs["audio_prompt_lantent"] is not None: + speaker_latent = inputs["audio_prompt_lantent"].squeeze(0) # T x hidden_size + self._speaker_latent_dim = speaker_latent.shape[-1] + input_tensors.append(speaker_latent.to(dtype=getattr(torch, self._dtype))) + else: + if self._speaker_latent_dim is None: + import json as _json + dir_name = os.path.basename(os.path.normpath(self.model_path)) + converted_config_path = os.path.join("/tmp", dir_name + "_vllm_converted_eartts", "config.json") + if os.path.exists(converted_config_path): + with open(converted_config_path) as _f: + self._speaker_latent_dim = _json.load(_f)["hidden_size"] + else: + raise RuntimeError( + f"Cannot determine speaker_latent_dim: converted config not found at {converted_config_path}. " + "Run a prefill with audio_prompt_lantent first, or ensure the converted checkpoint exists." + ) + num_tokens = codes.shape[0] + speaker_latent = torch.zeros(num_tokens, self._speaker_latent_dim, dtype=getattr(torch, self._dtype)) + input_tensors.append(speaker_latent) + + result = await self.engine.generate_next_token(input_tensors, prompt_token_ids=prompt_token_ids, request_id=request_id) + acoustic_tokens = result.custom_outputs["acoustic_tokens"] # T x 31 + step_acoustic_tokens = acoustic_tokens[-1:] # 1 x 31 + return TTSGenerationResult( + codes=step_acoustic_tokens.unsqueeze(0).cuda(), # Add batch dim back: 1 x 1 x 31 + past_key_values=None # vLLM manages cache internally + ) + + def prefill_prompt(self, init_inputs, prompt_token_ids=None, request_id=None, **kwargs): + """Prefill vLLM EarTTS engine with speaker embedding context. + + Args: + init_inputs: Dict of initial TTS inputs (codes, subword_ids, etc.). 
+ prompt_token_ids: List of prompt token IDs for the speaker embedding. + request_id: Unique request identifier. + + Returns: + TTSGenerationResult with codes from the prefill step, or None. + """ + import copy + inputs_copy = copy.deepcopy(init_inputs) + return self(inputs_copy, request_id=request_id, prompt_token_ids=prompt_token_ids) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py new file mode 100644 index 000000000000..c9d8db7c5303 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/backend/vllm/llm.py @@ -0,0 +1,167 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +vLLM backend for the LLM component of NemotronVoiceChat. + +The LLM lives inside DuplexSTT: it takes audio frame embeddings (from the +perception encoder) and produces text tokens, ASR tokens, and optional +function-call tokens at each step. This module wraps it in a +CustomInputAsyncVLLMEngine for accelerated inference. + +Used as ``model_llm_interface`` in the inference wrapper. +""" + +from typing import Any +import torch + +from nemo.utils import logging +from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.base import VLLMModelBase + + +class VLLMLLM(VLLMModelBase): + """ + vLLM backend for the LLM (DuplexSTT) component. 
+ + Accepts audio frame embeddings, runs them through the vLLM streaming + engine one step at a time, and returns text/ASR token predictions with + optional post-hoc sampling (top-p, repetition penalty). + """ + + def _convert_ckpt(self, save_path: str): + """Convert existing DuplexSTT checkpoint to vLLM-compatible HF format.""" + from nemo.collections.speechlm2.inference.vllm.scripts.convert_nemotronllm_checkpoint import convert_nemo_to_hf_format + + convert_nemo_to_hf_format( + checkpoint_path=self.model_path, + output_dir=save_path, + pretrained_llm=self.pretrained_llm, + dtype=self._dtype + ) + logging.info(f"Converted model saved to {save_path}") + + def __call__( + self, + input_embeds: torch.Tensor, + request_id: str | None = "request_id_1", + **kwargs + ) -> dict[str, Any]: + """ + Perform inference using vLLM streaming engine. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + request_id: Unique request identifier for this generation + **kwargs: Additional arguments (decode_steps, generated_tokens, etc.) + + Returns: + Dictionary containing predicted_token, asr_predicted_token, cache, + is_finished, and request_id. + """ + result = self._loop.run_until_complete( + self._async_inference(input_embeds, request_id, **kwargs) + ) + return result + + async def _process_inputs_to_outputs( + self, + input_embeds: torch.Tensor, + request_id: str, + decode_steps: int = 1, + prompt_token_ids: list | None = None, + generated_tokens: torch.Tensor | None = None, + current_step: int = 0, + sampling_params: dict[str, float] | None = None, + ) -> dict[str, Any]: + """ + Process embeddings sequentially to generate text and ASR tokens. + + Args: + input_embeds: Input embeddings [batch, seq_len, hidden_dim] + request_id: Request identifier + decode_steps: Number of decoding steps to perform; 0 means prefill only + prompt_token_ids: Optional list of prompt token IDs for prefill + generated_tokens: Previously generated tokens [batch, num_generated]. 
+ Required for repetition_penalty. If None, creates empty tensor. + current_step: Current decoding step. Used for repetition penalty. + sampling_params: Optional per-request overrides for sampling. + """ + + if decode_steps == 0: + input_embeds = input_embeds.flatten(0, 1) # [seq_len, hidden_dim] + result = await self.engine.generate_next_token([input_embeds], + prompt_token_ids, + request_id=request_id) + return True if result is not None else False + + text_token_ids = [] + asr_token_ids = [] + result = None + for i in range(decode_steps): + single_embed = input_embeds[:, i:i+1, :].squeeze(1) # [batch, hidden_dim] + + result = await self.engine.generate_next_token([single_embed], request_id=request_id) + if result is None: + break + + text_token_ids.append(result.token_id) + asr_token_ids.append(result.custom_outputs["asr_tokens"]) + + if result.is_finished: + break + + assert len(text_token_ids) <= decode_steps, "Generated more tokens than input embeddings" + is_finished = False + if text_token_ids: + is_finished = len(text_token_ids) < decode_steps or (result and result.is_finished) + + text_logits = result.custom_outputs["text_logits"] if result else None + + batch_size = text_logits.shape[0] + if generated_tokens is None: + gen_tokens = torch.empty(batch_size, 0, device=text_logits.device, dtype=torch.long) + else: + gen_tokens = generated_tokens + + predicted_token = self._sample_text_token( + logits=text_logits, + generated_tokens=gen_tokens, + current_step=current_step, + sampling_params=sampling_params, + ) + + ans = { + "predicted_token": predicted_token, + "asr_predicted_token": asr_token_ids[-1], + "cache": None, # vLLM manages cache internally + "is_finished": is_finished, + "request_id": request_id + } + if result and result.custom_outputs and "function_tokens" in result.custom_outputs: + ans["function_predicted_token"] = result.custom_outputs["function_tokens"] + return ans + + def prefill_prompt(self, embeddings: torch.Tensor, request_id: str = 
None, **kwargs) -> bool: + """Prefill vLLM LLM engine with prompt embeddings in a single bulk step. + + Args: + embeddings: Prompt embeddings [batch, seq_len, hidden_dim]. + request_id: Unique request identifier. + + Returns: + True if prefill succeeded. + """ + return self._loop.run_until_complete( + self._async_inference(embeddings, request_id, decode_steps=0) + ) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py new file mode 100644 index 000000000000..77f2840b9352 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/decode_state.py @@ -0,0 +1,190 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State and result types for streaming S2S inference. + +These dataclasses define the *model wrapper's* interface contract: + +* :class:`StreamingDecodeState` — mutable per-stream decode state + (KV caches, token workspaces, perception/codec caches). Created by + the wrapper, mutated in-place by ``infer_one_step``, and held between + steps by the pipeline's context manager. + +* :class:`InferenceStepResult` — immutable per-step outputs returned + by ``infer_one_step`` (predicted tokens, text strings, audio). + +Defined here (in ``model_wrappers/``) because the wrapper is the +component that knows what state it needs. The context manager and +pipeline import from here. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, TYPE_CHECKING + +import torch + +from nemo.utils import logging +from nemo.utils.timers import NamedTimer + +if TYPE_CHECKING: + from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import PerceptionCacheState + + +class TimingSummary(NamedTimer): + """Accumulates per-stage wall-clock times across inference steps. + + Extends :class:`~nemo.utils.timers.NamedTimer` (``sync_cuda=True``) + with a :meth:`log_summary` that prints a compact min/mean/max table + at ``logging.info`` level once a stream finishes. + + Usage:: + + timing.start("perception") + # ... run perception ... + timing.stop("perception") + + timing.log_summary(label="Stream 0", chunk_ms=240) + """ + + def __init__(self): + super().__init__(reduction="none", sync_cuda=True) + + def log_summary(self, label: str = "Timing", chunk_ms: float | None = None) -> None: + header = f"{label} timing" + if chunk_ms is not None: + header += f" (chunk={chunk_ms:.0f}ms)" + parts = [] + for name, data in self.timers.items(): + times = data.get("dt", []) + if not times: + continue + mean_ms = sum(times) / len(times) * 1000 + min_ms = min(times) * 1000 + max_ms = max(times) * 1000 + parts.append(f"{name}: mean={mean_ms:.1f}ms min={min_ms:.1f}ms max={max_ms:.1f}ms") + if parts: + logging.info(f"{header}:\n " + "\n ".join(parts)) + + +class NullTimingSummary: + """No-op stand-in for :class:`TimingSummary`.""" + + def start(self, name: str = "") -> None: + pass + + def stop(self, name: str = "") -> None: + pass + + def log_summary(self, label: str = "Timing", chunk_ms: float | None = None) -> None: + pass + + +@dataclass +class StreamingDecodeState: + """Per-stream model-level decode state for streaming S2S inference. + + Holds KV caches, token workspaces, perception cache, and codec state + that persist across inference steps within a single stream. 
+ """ + + frame_idx: int + gen_text: torch.Tensor + gen_asr_text: torch.Tensor + gen_function_text: torch.Tensor | None + input_embeds_history: list[torch.Tensor] + llm_cache: Any # DynamicCache or HybridMambaAttentionDynamicCache + tts_past_key_values: Any + tts_code: torch.Tensor | None + subword_mask: torch.Tensor | None + perception_cache: "PerceptionCacheState" | None = None + tts_codec_cache: Any = None + llm_cache_position_offset: int = 0 + timing: TimingSummary | NullTimingSummary = field(default_factory=NullTimingSummary) + + +@dataclass +class InferenceStepResult: + """Output from a single ``infer_one_step`` call. + + State mutations (caches, token workspaces, frame_idx) are applied + in-place on :class:`StreamingDecodeState`. This dataclass carries + only the per-step *outputs* needed by the pipeline. + """ + + predicted_text_tokens: torch.Tensor + asr_predicted_text_tokens: torch.Tensor + predicted_text_strs: list[str] + asr_predicted_text_strs: list[str] + decoded_audio: torch.Tensor | None = None + function_predicted_text_tokens: torch.Tensor | None = None + debug: dict | None = None + + +class IntermediateResultLogger: + """Records per-frame debug data (logits, embeddings, indices) during inference. + + Tensors are kept on their original device until :meth:`build_debug_dict` + is called, which performs a single bulk copy to CPU. 
+ """ + + def __init__(self): + self.text_logits: list[torch.Tensor] = [] + self.asr_logits: list[torch.Tensor] = [] + self.input_embeds: list[torch.Tensor] = [] + self.selected_frame_indices: list[int] = [] + + def log_input_embeds(self, emb: torch.Tensor): + self.input_embeds.append(emb.detach()) + + def log_text_logits(self, logits: torch.Tensor): + self.text_logits.append(logits.detach()) + + def log_asr_logits(self, logits: torch.Tensor | None): + if logits is not None: + self.asr_logits.append(logits.detach()) + + def log_selected_frame_index(self, idx: int): + self.selected_frame_indices.append(idx) + + def build_debug_dict(self, source_encoded: torch.Tensor, gen_text: torch.Tensor, gen_asr_text: torch.Tensor | None) -> dict: + return { + "source_encoded": source_encoded.detach().cpu(), + "selected_frame_indices": self.selected_frame_indices, + "input_embeds": torch.cat(self.input_embeds, dim=1).cpu() if self.input_embeds else None, + "gen_text": gen_text.detach().cpu(), + "gen_asr": gen_asr_text.detach().cpu() if gen_asr_text is not None else None, + "text_logits": torch.stack(self.text_logits, dim=1).cpu() if self.text_logits else None, + "asr_logits": torch.stack(self.asr_logits, dim=1).cpu() if self.asr_logits else None, + } + + +class NullIntermediateResultLogger: + """No-op stand-in for :class:`IntermediateResultLogger`.""" + + def log_input_embeds(self, emb): + pass + + def log_text_logits(self, logits): + pass + + def log_asr_logits(self, logits): + pass + + def log_selected_frame_index(self, idx): + pass + + def build_debug_dict(self, *args, **kwargs): + return None diff --git a/nemo/collections/speechlm2/inference/model_wrappers/factory.py b/nemo/collections/speechlm2/inference/model_wrappers/factory.py new file mode 100644 index 000000000000..b67d9aa523d0 --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/factory.py @@ -0,0 +1,146 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Factory for creating LLM and TTS inference backends. + +NemotronVoiceChat has two components that can each be backed by native +PyTorch or vLLM. This factory returns the right backend for a given +``engine_type``: + +- ``"native_llm"`` -- wraps the PyTorch model directly (LLM component) +- ``"native_eartts"`` -- wraps the PyTorch DuplexEARTTS model (TTS component) +- ``"vllm_llm"`` -- vLLM engine for the LLM (DuplexSTT) component +- ``"vllm_eartts"`` -- vLLM engine for the TTS (EarTTS) component + +Usage: + from nemo.collections.speechlm2.inference.model_wrappers.factory import create_model + + llm = create_model(engine_type="native_llm", model=voicechat_model) + tts = create_model(engine_type="native_eartts", model=voicechat_model.tts_model) +""" + +from typing import Any + +from nemo.collections.speechlm2.inference.model_wrappers.backend.interface import ModelInterface + + +def create_model( + engine_type: str, + model=None, + vllm_config: dict[str, Any] | None = None, + special_token_ids: set[int] | None = None, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + **kwargs +) -> ModelInterface: + """ + Factory function to create a single inference backend for one component. + + Each call creates a backend for one specific component (LLM or TTS) on + one specific runtime (native PyTorch or vLLM). 
The ``engine_type`` + must be one of the four ``{backend}_{component}`` combinations. + + Note: the user-facing config uses a *combined* ``engine_type`` like + ``"native"`` or ``"vllm_llm_vllm_eartts"`` which the wrapper + translates into two ``create_model`` calls (one for LLM, one for TTS). + + Args: + engine_type: One of "native_llm", "native_eartts", "vllm_llm", "vllm_eartts" + model: The PyTorch model to wrap. Required for native backends + (NemotronVoiceChat for LLM, DuplexEARTTS for TTS). Not used + by vLLM backends, which load their own engine from ``vllm_config``. + vllm_config: Configuration dict for vLLM engines (required for "vllm*") + special_token_ids: Set of special token IDs (pad, eos, bos) that should bypass sampling. + If None (default), will auto-extract from model.tokenizer for tokens: + '' (bos), '' (eos), '' (pad). + top_p: Top-p (nucleus) sampling threshold. 1.0 disables it (greedy). Default: 1.0 + repetition_penalty: Penalty for repeated tokens. 1.0 disables it. Default: 1.0 + temperature: Temperature for sampling. 1.0 = no change, 0.0 = greedy. Default: 1.0 + **kwargs: Additional arguments passed to the backend constructor + + Returns: + A ModelInterface instance ready for inference. 
+ + Example: + >>> # Native PyTorch LLM with greedy decoding + >>> llm = create_model(engine_type="native_llm", model=voicechat_model) + >>> + >>> # Native PyTorch EarTTS + >>> tts = create_model(engine_type="native_eartts", model=voicechat_model.tts_model) + >>> + >>> # vLLM LLM engine + >>> llm = create_model(engine_type="vllm_llm", vllm_config={...}) + >>> + >>> # vLLM EarTTS engine + >>> tts = create_model(engine_type="vllm_eartts", vllm_config={...}) + """ + engine_type = engine_type.lower() + + if engine_type == "native_llm": + from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.model import PyTorchLLM + + if model is None: + raise ValueError("model must be provided for native engine") + return PyTorchLLM( + model=model, + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + + elif engine_type == "native_eartts": + from nemo.collections.speechlm2.inference.model_wrappers.backend.pytorch.eartts import PyTorchEarTTS + + if model is None: + raise ValueError("model (DuplexEARTTS instance) must be provided for native EarTTS engine") + return PyTorchEarTTS(tts_model=model) + + elif engine_type == "vllm_eartts": + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.eartts import VLLMEarTTS + + if vllm_config is None: + raise ValueError("vllm_config must be provided for vLLM EARTTS engine") + return VLLMEarTTS( + **vllm_config, + model_type="eartts", + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + **kwargs + ) + + elif engine_type == "vllm_llm": + from nemo.collections.speechlm2.inference.model_wrappers.backend.vllm.llm import VLLMLLM + + if vllm_config is None: + raise ValueError("vllm_config must be provided for vLLM engine") + return VLLMLLM( + **vllm_config, + model_type="llm", + special_token_ids=special_token_ids, + top_p=top_p, + repetition_penalty=repetition_penalty, + 
temperature=temperature, + **kwargs + ) + + else: + raise ValueError( + f"Unknown engine_type: {engine_type}. " + f"Supported types: 'native_llm', 'native_eartts', 'vllm_llm', 'vllm_eartts'" + ) diff --git a/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py new file mode 100755 index 000000000000..c5049fcd665f --- /dev/null +++ b/nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py @@ -0,0 +1,1005 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import gc +import os +import time +import torch +import torchaudio +from omegaconf import OmegaConf, DictConfig + +from nemo.utils import logging, str_to_dtype +from transformers import DynamicCache + +from nemo.collections.speechlm2.models.nemotron_voicechat import NemotronVoiceChat +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.modules.ear_tts_vae_codec import CausalConv1dCache +from nemo.collections.speechlm2.inference.model_wrappers.factory import create_model +from nemo.collections.speechlm2.inference.model_wrappers.perception_cache import ( + PerceptionCacheState, + PerceptionCacheManager, +) +from nemo.collections.speechlm2.inference.model_wrappers.decode_state import ( + InferenceStepResult, + IntermediateResultLogger, + NullIntermediateResultLogger, + NullTimingSummary, + StreamingDecodeState, + TimingSummary, +) +from nemo.collections.speechlm2.parts.text_utils import _decode_tokens_with_specials + + +# --- Configuration --- +DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# --- Streaming Parameters --- +SAMPLE_RATE = 16000 +FRAME_SIZE_SEC = 0.08 # 80ms per frame +FRAME_SIZE_SAMPLES = int(SAMPLE_RATE * FRAME_SIZE_SEC) # 1280 samples + +TTS_SAMPLE_RATE = 22050 + + +class NemotronVoicechatInferenceWrapper: + """ + Inference wrapper for NemotronVoiceChat models. + Uses a sliding window buffer and processes audio frame by frame. + """ + + def __init__(self, model_cfg: DictConfig): + """ + Initialize the model for realtime streaming inference. + + Args: + model_cfg (DictConfig): Configuration describing the model paths and inference parameters. 
+ """ + if model_cfg is None: + raise ValueError("model_cfg must be provided") + if not isinstance(model_cfg, DictConfig): + model_cfg = OmegaConf.create(model_cfg) + + # Precision settings (applied here so they take effect before model loading) + allow_tf32 = bool(model_cfg.get("allow_tf32", True)) + torch.backends.cudnn.allow_tf32 = allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + matmul_precision = str(model_cfg.get("matmul_precision", "medium")) + torch.set_float32_matmul_precision(matmul_precision) + + # Deterministic mode: guarantees identical *text* outputs (from the STT/LLM + # heads) across runs for the same inputs, even when sampling is enabled + # (top_p < 1, temperature != 1, repetition_penalty != 1). This works + # because we fix the global PyTorch RNG seeds and force all CUDA ops + # to use deterministic algorithm implementations. + # Not compatible with vLLM engines (raises an error below). + self._deterministic = bool(model_cfg.get("deterministic", False)) + if self._deterministic: + engine_type = model_cfg.get("engine_type", "native") + if "vllm" in engine_type.lower(): + raise ValueError( + "`deterministic` is not compatible with vLLM engines because vLLM uses custom " + "CUDA kernels (PagedAttention, FlashAttention) that do not support deterministic mode. " + f"Got engine_type='{engine_type}'. Use engine_type='native' for deterministic inference." 
+ ) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.use_deterministic_algorithms(True, warn_only=False) + else: + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.use_deterministic_algorithms(False) + + self.model_cfg = model_cfg + + self.model_path = model_cfg.get("model_path") + if not self.model_path: + raise ValueError("`model_cfg.model_path` must be provided.") + + self.decode_audio = bool(model_cfg.get("decode_audio", True)) + + self.speaker_reference = model_cfg.get("speaker_reference") + self.speaker_name = model_cfg.get("speaker_name", None) + if self.decode_audio and not self.speaker_reference and not self.speaker_name: + raise ValueError("`model_cfg.speaker_reference` or `model_cfg.speaker_name` must be provided when decode_audio is enabled.") + + self.tts_system_prompt = model_cfg.get("tts_system_prompt", None) + logging.info(f"TTS system prompt: {self.tts_system_prompt}") + + self.dtype = str_to_dtype(model_cfg.get("compute_dtype", "bfloat16")) + + device = model_cfg.get("device") + device_id = model_cfg.get("device_id") + if device is None: + self.device = DEFAULT_DEVICE + else: + device_str = str(device) + if device_id is not None and device_str.startswith("cuda") and ":" not in device_str: + device_str = f"{device_str}:{device_id}" + self.device = torch.device(device_str) + + logging.info("=" * 70) + logging.info("INITIALIZING REALTIME STREAMING INFERENCE") + logging.info("=" * 70) + logging.info(f"Frame size: {FRAME_SIZE_SEC}s ({FRAME_SIZE_SAMPLES} samples @ {SAMPLE_RATE}Hz)") + logging.info(f"Device: {self.device}") + logging.info(f"Compute dtype: {self.dtype}") + logging.info(f"Decode audio: {self.decode_audio}") + logging.info(f"Engine type: {model_cfg.get('engine_type', 'native')}") + logging.info(f"Sampling - top_p: 
{model_cfg.get('top_p', 1.0)}, repetition_penalty: {model_cfg.get('repetition_penalty', 1.0)}, temperature: {model_cfg.get('temperature', 1.0)}") + logging.info(f"Precision (configured): matmul_precision={matmul_precision}, allow_tf32={allow_tf32}, deterministic={self._deterministic}") + logging.info(f"Precision (effective): float32_matmul_precision={torch.get_float32_matmul_precision()}, cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}, cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}") + logging.info("=" * 70) + + # Profiling: when True, a TimingSummary (extending NamedTimer with + # sync_cuda=True) is attached to each decode state, recording + # per-stage wall-clock times. Disabled by default to avoid + # unnecessary GPU stalls in production. + self._profile_timing = bool(model_cfg.get("profile_timing", False)) + + # Cached TTS helpers populated during initialization/warmup + self.first_context_subword_id = None + self.generation_config = None + self.first_tts_code_input = None + self.first_tts_past_key_values_input = None + + self.model = None + self.model_llm_interface = None + self.tokenizer = None + + # Engine configuration. + # engine_type is a user-facing config value that selects which combination + # of backends to use for the LLM and TTS components: + # "native" -> native_llm + native_eartts + # "vllm_llm" -> vllm_llm + native_eartts + # "vllm_eartts" -> native_llm + vllm_eartts + # "vllm_llm_vllm_eartts" -> vllm_llm + vllm_eartts + # The factory (create_model) uses the specific {backend}_{component} + # names: native_llm, native_eartts, vllm_llm, vllm_eartts. 
+ self.engine_type = model_cfg.get("engine_type", "native") + self.use_vllm_llm = "vllm_llm" in self.engine_type.lower() + self.use_vllm_eartts = "vllm_eartts" in self.engine_type.lower() + self.vllm_llm_config = model_cfg.get("vllm_llm_config", None) + self.vllm_tts_config = model_cfg.get("vllm_tts_config", None) + self.request_id = "streaming_request_0" # For vLLM streaming + + # Sampling parameters + self.top_p = float(model_cfg.get("top_p", 1.0)) + self.repetition_penalty = float(model_cfg.get("repetition_penalty", 1.0)) + self.temperature = float(model_cfg.get("temperature", 1.0)) + + # LLM KV cache: when enabled, uses DynamicCache (standard) or + # HybridMambaAttentionDynamicCache (Nemotron) for incremental decoding. + # When disabled, falls back to full-history reprocessing each step. + self.use_llm_cache = bool(model_cfg.get("use_llm_cache", True)) + + # Perception cache configuration + self.use_perception_cache = bool(model_cfg.get("use_perception_cache", False)) + use_perception_cudagraph = bool(model_cfg.get("use_perception_cudagraph", False)) + if use_perception_cudagraph and not self.use_perception_cache: + raise ValueError( + "use_perception_cudagraph requires use_perception_cache to be enabled. " + "Please also set use_perception_cache=True." 
+ ) + self.perception_cache_mgr: PerceptionCacheManager | None = None + self._use_perception_cudagraph = use_perception_cudagraph + + self._initialize_model() + + logging.info("NemotronVoicechatInferenceWrapper initialized successfully.") + + def _initialize_model(self): + """Initialize the NemotronVoiceChat model from an HF checkpoint.""" + logging.info("Initializing model structure...") + start_model_init = time.time() + + self.model = NemotronVoiceChat.from_pretrained( + self.model_path, + local_files_only=True, + ) + logging.info(f"NemotronVoiceChat initialized in {time.time() - start_model_init:.1f}s") + + # Delete unused native components BEFORE moving to GPU to save memory + if self.use_vllm_llm: + logging.info("Deleting native LLM (will use vLLM instead)...") + if hasattr(self.model.stt_model, 'llm') and self.model.stt_model.llm is not None: + for name, child in list(self.model.stt_model.llm.named_children()): + delattr(self.model.stt_model.llm, name) + del self.model.stt_model.llm + self.model.stt_model.llm = None + + if self.use_vllm_eartts: + logging.info("Deleting native TTS (will use vLLM instead)...") + del self.model.tts_model.tts_model + + if self.use_vllm_llm or self.use_vllm_eartts: + # Free memory from deleted components before GPU transfer and vLLM engine creation + gc.collect() + torch.cuda.empty_cache() + + # Setup model on device + self.model.to(self.device) + self.model.eval() + + # Convert some S2S components to the configured dtype + logging.info(f"Converting some S2S components to {self.dtype} (keeping perception & TTS in float32)...") + if self.model.stt_model.llm is not None: + self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype) + self.model.stt_model.lm_head = self.model.stt_model.lm_head.to(self.dtype) + self.model.stt_model.embed_tokens = self.model.stt_model.embed_tokens.to(self.dtype) + self.model.stt_model.asr_head = self.model.stt_model.asr_head.to(self.dtype) + self.model.stt_model.embed_asr_tokens = 
self.model.stt_model.embed_asr_tokens.to(self.dtype) + if self.model.stt_model.function_head is not None: + self.model.stt_model.function_head = self.model.stt_model.function_head.to(self.dtype) + logging.info("function_head converted to %s", self.dtype) + + self.tokenizer = self.model.stt_model.tokenizer + + # Allow overrides from wrapper config into the model config (e.g. logit boosts). + _BOOST_KEYS = ( + "inference_pad_boost", + "inference_bos_boost", + "inference_eos_boost", + "inference_user_pad_boost", + "inference_user_bos_boost", + "inference_user_eos_boost", + ) + for key in _BOOST_KEYS: + val = self.model_cfg.get(key, None) + if val is not None: + OmegaConf.update(self.model.stt_model.cfg, key, val) + boost_values = {k: self.model.stt_model.cfg.get(k, None) for k in _BOOST_KEYS} + logging.info(f"Inference logit boosts: {boost_values}") + + # Create LLM backend + if self.use_vllm_llm: + logging.info("Creating VLLMLLM backend...") + if self.vllm_llm_config is None: + raise ValueError("vllm_llm_config must be provided when engine_type contains 'vllm_llm'") + + # Set logit boosts as env vars BEFORE creating the vLLM engine, + # so they are inherited by the forked worker process. The modified + # nemotron_h.py reads VLLM_ASR_BOOST_ and + # VLLM_TEXT_BOOST_ in __init__. 
+ stt = self.model.stt_model + asr_boost_map = { + "inference_user_pad_boost": stt.text_pad_id, + "inference_user_bos_boost": stt.text_bos_id, + "inference_user_eos_boost": stt.text_eos_id, + } + for cfg_key, token_id in asr_boost_map.items(): + val = self.model_cfg.get(cfg_key, None) + if val is not None and float(val) != 0.0: + env_key = f"VLLM_ASR_BOOST_{token_id}" + os.environ[env_key] = str(float(val)) + logging.info(f"Set env {env_key}={val} (from {cfg_key})") + + text_boost_map = { + "inference_pad_boost": stt.text_pad_id, + "inference_bos_boost": stt.text_bos_id, + "inference_eos_boost": stt.text_eos_id, + } + for cfg_key, token_id in text_boost_map.items(): + val = self.model_cfg.get(cfg_key, None) + if val is not None and float(val) != 0.0: + env_key = f"VLLM_TEXT_BOOST_{token_id}" + os.environ[env_key] = str(float(val)) + logging.info(f"Set env {env_key}={val} (from {cfg_key})") + + self.model_llm_interface = create_model( + model=self.model_path if self.use_vllm_llm else self.model, + engine_type="vllm_llm" if self.use_vllm_llm else "native_llm", + vllm_config=self.vllm_llm_config if self.use_vllm_llm else None, + top_p=self.top_p, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + ) + logging.info(f"LLM backend: {type(self.model_llm_interface).__name__}") + + # Create TTS backend + self.model_eartts_interface = create_model( + model=self.model_path if self.use_vllm_eartts else self.model.tts_model, + engine_type="vllm_eartts" if self.use_vllm_eartts else "native_eartts", + vllm_config=self.vllm_tts_config if self.use_vllm_eartts else None, + ) + logging.info(f"TTS backend: {type(self.model_eartts_interface).__name__}") + + # torch.compile and subword cache (no-ops for vLLM, delegated to TTS backend) + if bool(self.model_cfg.get("use_tts_torch_compile", False)): + self.model_eartts_interface.compile() + self.model_eartts_interface.setup_subword_cache(self.model_cfg) + + # Get TTS info + if hasattr(self.model, 'tts_model'): + 
self.target_fps = self.model.tts_model.target_fps + self.target_sample_rate = self.model.tts_model.target_sample_rate + logging.info(f"TTS model initialized: target_fps={self.target_fps}, sample_rate={self.target_sample_rate}") + if self.decode_audio: + self._prepare_tts_initial_state() + else: + logging.warning("Warning: TTS model not found in the model") + + # Setup perception cache if enabled + if self.use_perception_cache: + self.perception_cache_mgr = PerceptionCacheManager( + model=self.model, + device=self.device, + dtype=self.dtype, + use_cudagraph=self._use_perception_cudagraph, + ) + if not self.perception_cache_mgr.setup(): + self.use_perception_cache = False + self.perception_cache_mgr = None + + def _get_bos_embedding(self): + """Get beginning of sequence embedding.""" + text_bos = torch.full((1,), fill_value=self.model.stt_model.text_pad_id, device=self.device) + input_embeds = self.model.stt_model.embed_tokens(text_bos) + return input_embeds.to(dtype=self.dtype) + + def _get_asr_bos_embedding(self) -> torch.Tensor: + """Get ASR BOS embedding for AR decoding.""" + text_bos = torch.full((1,), fill_value=self.model.stt_model.text_pad_id, device=self.device) + input_embeds = self.model.stt_model.embed_asr_tokens(text_bos) + return input_embeds.to(dtype=self.dtype) + + def _prepare_system_prompt_embeddings( + self, + system_prompt: str, + ) -> tuple[torch.Tensor | None, int]: + if not system_prompt or not system_prompt.strip(): + return None, 0 + + prompt_token_ids = self._build_prompt_token_ids(system_prompt) + prompt_tokens = torch.tensor(prompt_token_ids, dtype=torch.long, device=self.device).unsqueeze(0) + prompt_embedded = self.model.stt_model.embed_tokens(prompt_tokens).to(dtype=self.dtype) + prompt_len = prompt_tokens.shape[1] + + pad_id = self.model.stt_model.text_pad_id + pad_token = torch.full((1,), fill_value=pad_id, device=self.device, dtype=torch.long) + pad_emb = self.model.stt_model.embed_tokens(pad_token).to(dtype=self.dtype) + pad_asr_emb 
= self.model.stt_model.embed_asr_tokens(pad_token).to(dtype=self.dtype) + + has_fc = self.model.stt_model.function_head is not None + if prompt_len > 1: + prompt_embedded[:, 1:, :] += pad_emb + prompt_embedded[:, 1:, :] += pad_asr_emb + if has_fc: + prompt_embedded[:, 1:, :] += pad_emb + + bos_emb = self._get_bos_embedding() + asr_bos_emb = self._get_asr_bos_embedding() + prompt_embedded[:, 0, :] += bos_emb.squeeze(0) + prompt_embedded[:, 0, :] += asr_bos_emb.squeeze(0) + if has_fc: + prompt_embedded[:, 0, :] += pad_emb.squeeze(0) + + return prompt_embedded, prompt_len + + def _clone_cache(self, cache): + """Deep clone cache structures to ensure complete isolation between streams.""" + if cache is None: + return None + if isinstance(cache, torch.Tensor): + return cache.detach().clone() + if isinstance(cache, (list, tuple)): + return type(cache)(self._clone_cache(x) for x in cache) + if isinstance(cache, dict): + return {k: self._clone_cache(v) for k, v in cache.items()} + if hasattr(cache, '__dict__'): + return copy.deepcopy(cache) + return cache + + def _build_prompt_token_ids(self, system_prompt: str | None) -> list[int]: + if not system_prompt or not system_prompt.strip(): + return [] + return [self.tokenizer.bos_id] + self.tokenizer.text_to_ids(system_prompt) + [self.tokenizer.eos_id] + + def _init_token_buffers(self, max_len: int): + stt_model = self.model.stt_model + gen_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) + gen_asr_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) + gen_function_text = None + if getattr(stt_model, "function_head", None) is not None: + gen_function_text = torch.full((1, max_len), stt_model.text_pad_id, device=self.device, dtype=torch.long) + return gen_text, gen_asr_text, gen_function_text + + def _create_llm_cache(self): + if not self.use_llm_cache or self.use_vllm_llm: + return None + pretrained_llm = 
str(self.model.stt_model.cfg.get("pretrained_llm", "")) + if "Nemotron" in pretrained_llm: + return self.model.stt_model._create_nemotron_cache(batch_size=1) + return DynamicCache() + + def _create_codec_state(self, max_len: int): + if not self.decode_audio or not hasattr(self.model, "tts_model"): + return None, None + + codec_cache = CausalConv1dCache() + subword_mask = torch.ones((1, max_len), device=self.device, dtype=torch.bool) + return subword_mask, codec_cache + + def _prepare_tts_initial_state(self): + if not self.decode_audio: + return + if not hasattr(self.model, 'tts_model'): + return + + logging.info("Preparing TTS warmup state...") + + if self.speaker_name is not None: + logging.info(f"Using registered speaker name: {self.speaker_name}") + speaker_audio = None + speaker_audio_lens = None + else: + with fp32_precision(): + speaker_audio, speaker_sr = torchaudio.load(self.speaker_reference) + speaker_audio = resample(speaker_audio, speaker_sr, self.model.tts_model.target_sample_rate) + speaker_audio = speaker_audio.to(self.device) + speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long() + + self.model.tts_model.set_init_inputs( + speaker_audio=speaker_audio, + speaker_audio_lens=speaker_audio_lens, + system_prompt=self.tts_system_prompt, + speaker_name=self.speaker_name, + ) + init_inputs = self.model.tts_model.get_init_inputs(B=1) + + self.generation_config = self.model.tts_model._get_generation_config(guidance_enabled=True) + init_inputs.update({"use_cache": True, "past_key_values": None, "guidance_enabled": True}) + + with torch.no_grad(): + if self.use_vllm_eartts: + self.tts_prompt_token_ids = init_inputs["subword_ids"].squeeze().cpu().numpy().tolist() + self.tts_init_inputs = init_inputs + outputs = self.model_eartts_interface.prefill_prompt( + init_inputs, + prompt_token_ids=getattr(self, 'tts_prompt_token_ids', None), + request_id="tts_warmup", + ) + if self.use_vllm_eartts: + # Abort warmup request so the engine is 
clean for actual streaming + self.model_eartts_interface.abort_request("tts_warmup") + + code = init_inputs["code"][:, -1:] + + self.first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) + self.first_tts_code_input = code.detach().clone() + self.first_tts_past_key_values_input = self._clone_cache(outputs.past_key_values) + + logging.info("TTS warmup state prepared") + + def create_decode_state(self, max_len: int) -> StreamingDecodeState: + gen_text, gen_asr_text, gen_function_text = self._init_token_buffers(max_len) + llm_cache = self._create_llm_cache() + subword_mask, tts_codec_cache = self._create_codec_state(max_len) + perception_cache = None + if self.use_perception_cache and self.perception_cache_mgr is not None: + perception_cache = self.perception_cache_mgr.get_initial_state(batch_size=1) + + tts_past_key_values = None + tts_code = None + if self.decode_audio and self.first_tts_code_input is not None: + tts_past_key_values = self._clone_cache(self.first_tts_past_key_values_input) + tts_code = self.first_tts_code_input.detach().clone() + + return StreamingDecodeState( + frame_idx=0, + gen_text=gen_text, + gen_asr_text=gen_asr_text, + gen_function_text=gen_function_text, + input_embeds_history=[], + llm_cache=llm_cache, + tts_past_key_values=tts_past_key_values, + tts_code=tts_code, + subword_mask=subword_mask, + perception_cache=perception_cache, + tts_codec_cache=tts_codec_cache, + llm_cache_position_offset=0, + timing=TimingSummary() if self._profile_timing else NullTimingSummary(), + ) + + def infer_one_step( + self, + audio_input: torch.Tensor, + num_frames_per_chunk: int, + state: StreamingDecodeState, + *, + request_id: str | None = None, + has_prompt: bool = False, + return_debug: bool = False, + sampling_params: dict[str, float] | None = None, + ) -> InferenceStepResult: + """Run one streaming inference step: perception -> LLM -> TTS -> audio decode. + + All mutable decode state (caches, gen_text, gen_asr_text, code, etc.) 
is + updated **in-place** on *state*. The returned :class:`InferenceStepResult` + carries only per-step outputs needed by the pipeline. + + Args: + audio_input (torch.Tensor): Raw audio tensor for this chunk, shape ``(1, samples)``. + num_frames_per_chunk (int): Number of 80 ms frames in this chunk. + state (StreamingDecodeState): Mutable decode state (KV caches, token workspaces, etc.). + request_id (str | None): Unique ID for this stream (used by vLLM engines). + has_prompt (bool): Whether the LLM KV cache already contains a prefilled + system prompt. Affects the first-frame embedding (PAD vs BOS). + return_debug (bool): If True, attach per-step debug info to the result. + sampling_params (dict[str, float] | None): Optional per-stream sampling overrides + (``top_p``, ``temperature``, ``repetition_penalty``). + Keys that are absent fall back to the pipeline-level defaults. + """ + effective_request_id = request_id or self.request_id + frame_idx = state.frame_idx + + state.timing.start("total_step") + use_llm_cache = state.llm_cache is not None + B = state.gen_text.shape[0] + + predicted_tokens = torch.empty((B, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) + asr_predicted_tokens = torch.empty((B, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) + function_predicted_tokens = None + if self.model.stt_model.function_head is not None: + function_predicted_tokens = torch.empty((B, num_frames_per_chunk), dtype=state.gen_text.dtype, device=state.gen_text.device) + + debug_logger = IntermediateResultLogger() if return_debug else NullIntermediateResultLogger() + + # --- Stage 1: Perception --- + state.timing.start("perception") + source_encoded, state.perception_cache = self._run_perception( + audio_input, frame_idx, num_frames_per_chunk, state.perception_cache, + ) + state.timing.stop("perception") + total_encoded_frames = source_encoded.shape[1] + if self.use_perception_cache and state.perception_cache is 
not None and state.perception_cache.is_initialized(): + # With cache: we get exactly num_frames_per_chunk output frames + base_frame_index = 0 + else: + # Without cache: use the second-to-last encoded frame as the + # "newest" because the model expects 10ms / 80ms / 80ms ... framing + # but we always feed 80ms chunks, so the final frame contains + # silence padding. + newest = total_encoded_frames - 2 + base_frame_index = max(newest - (num_frames_per_chunk - 1), 0) + + # --- Stage 2: Per-frame generation loop --- + new_input_embeds = [] + new_codes_for_decode = [] + for frame_offset in range(num_frames_per_chunk): + current_frame_idx = frame_idx + frame_offset + current_frame_index = min(base_frame_index + frame_offset, source_encoded.shape[1] - 1) + debug_logger.log_selected_frame_index(current_frame_index) + frame_embedding = source_encoded[:, current_frame_index:current_frame_index + 1, :] + + input_emb = self._build_input_embedding( + frame_embedding, current_frame_idx, state, has_prompt, + ) + debug_logger.log_input_embeds(input_emb) + + ans = self._run_llm_step( + input_emb, state, frame_offset, effective_request_id, + current_frame_idx, use_llm_cache, return_debug, new_input_embeds, + sampling_params=sampling_params, + ) + + if "text_logits" in ans: + debug_logger.log_text_logits(ans["text_logits"][:, -1]) + asr_logits = ans.get("asr_logits") + if asr_logits is not None: + debug_logger.log_asr_logits(asr_logits[:, -1]) + + state.gen_text[:, current_frame_idx] = ans["predicted_token"] + predicted_tokens[:, frame_offset] = ans["predicted_token"] + state.gen_asr_text[:, current_frame_idx] = ans["asr_predicted_token"] + asr_predicted_tokens[:, frame_offset] = ans["asr_predicted_token"] + + if "function_predicted_token" in ans: + if function_predicted_tokens is not None: + function_predicted_tokens[:, frame_offset] = ans["function_predicted_token"] + if state.gen_function_text is not None: + state.gen_function_text[:, current_frame_idx] = 
ans["function_predicted_token"] + + self._maybe_apply_forced_turn_taking(current_frame_idx, state.gen_text, state.gen_asr_text) + predicted_tokens[:, frame_offset] = state.gen_text[:, current_frame_idx] + + if self.decode_audio: + new_code = self._run_tts_step( + state, current_frame_idx, effective_request_id, + ) + new_codes_for_decode.append(new_code) + + # --- Stage 3: Audio decode --- + decoded_audio_new = self._decode_audio(new_codes_for_decode, state, frame_idx, num_frames_per_chunk) + + # --- Stage 4: Token -> string conversion --- + predicted_text_strs = self._tokens_to_strings(predicted_tokens) + asr_predicted_text_strs = self._tokens_to_strings(asr_predicted_tokens) + + logging.debug(f'frame {frame_idx}: USER asr: {asr_predicted_text_strs}') + logging.debug(f'frame {frame_idx}: AGENT txt: {predicted_text_strs}') + + # --- Update remaining state fields --- + if not use_llm_cache: + state.input_embeds_history = state.input_embeds_history + new_input_embeds + if use_llm_cache: + state.llm_cache_position_offset += num_frames_per_chunk + + state.timing.stop("total_step") + + debug = debug_logger.build_debug_dict(source_encoded, state.gen_text, state.gen_asr_text) + + return InferenceStepResult( + predicted_text_tokens=predicted_tokens, + asr_predicted_text_tokens=asr_predicted_tokens, + predicted_text_strs=predicted_text_strs, + asr_predicted_text_strs=asr_predicted_text_strs, + decoded_audio=decoded_audio_new, + function_predicted_text_tokens=function_predicted_tokens, + debug=debug, + ) + + # ------------------------------------------------------------------ + # infer_one_step sub-stages + # ------------------------------------------------------------------ + + def _build_input_embedding( + self, + frame_embedding: torch.Tensor, + current_frame_idx: int, + state: StreamingDecodeState, + has_prompt: bool, + ) -> torch.Tensor: + """Compose the LLM input embedding for a single frame. 
+ + Combines the perception embedding (user channel) with the text / + ASR / function-call channel embeddings from the previous step. + At frame 0 this is either BOS (no prompt) or pad (after prompt). + + IMPORTANT: The arithmetic order here must match offline_inference + exactly (floating-point addition is not associative). For t > 0 + the text and ASR embeddings are summed first, then added to the + perception embedding in a single ``+=``. For t == 0 the sequential + ``+=`` pattern matches the offline path. + """ + stt = self.model.stt_model + emb = frame_embedding.clone() + emb *= stt.cfg.get("duplex_user_channel_weight", 1.0) + + has_fc = state.gen_function_text is not None + + if current_frame_idx == 0 and not has_prompt: + emb += self._get_bos_embedding() * stt.cfg.get("duplex_text_channel_weight", 1.0) + emb += self._get_asr_bos_embedding() * stt.cfg.get("duplex_asr_text_weight", 1.0) + if has_fc: + pad_token = torch.full((1,), fill_value=stt.text_pad_id, device=self.device, dtype=torch.long) + emb += stt.embed_tokens(pad_token).to(dtype=self.dtype) + + elif current_frame_idx == 0 and has_prompt: + pad_token = torch.full((1,), fill_value=stt.text_pad_id, device=self.device, dtype=torch.long) + emb += stt.embed_tokens(pad_token).to(dtype=self.dtype) + emb += stt.embed_asr_tokens(pad_token).to(dtype=self.dtype) + if has_fc: + emb += stt.embed_tokens(pad_token).to(dtype=self.dtype) + + else: + # Sum text + ASR first, then add once (must match offline operation order) + prev = current_frame_idx - 1 + last_token_emb = stt.embed_tokens( + state.gen_text[:, prev] + ) * stt.cfg.get("duplex_text_channel_weight", 1.0) + last_asr_token_emb = stt.embed_asr_tokens( + state.gen_asr_text[:, prev] + ) * stt.cfg.get("duplex_asr_text_weight", 1.0) + emb += last_token_emb + last_asr_token_emb + if has_fc: + emb += stt.embed_tokens(state.gen_function_text[:, prev]).to(dtype=self.dtype) + + return emb + + def _run_llm_step( + self, + input_emb: torch.Tensor, + state: 
StreamingDecodeState, + frame_offset: int, + request_id: str, + current_frame_idx: int, + use_llm_cache: bool, + return_debug: bool, + new_input_embeds: list, + sampling_params: dict[str, float] | None = None, + ) -> dict: + """Run one LLM forward pass (native cache, vLLM, or full-history). + + Updates ``state.llm_cache`` in-place for cached paths. For the + no-cache fallback, appends to *new_input_embeds* (list, mutated). + """ + state.timing.start("stt_model") + + if use_llm_cache or self.use_vllm_llm: + if self.use_vllm_llm: + ans = self.model_llm_interface( + input_emb, + request_id=request_id, + generated_tokens=state.gen_text, + current_step=current_frame_idx, + sampling_params=sampling_params, + ) + else: + cache_pos = torch.tensor( + [state.llm_cache_position_offset + frame_offset], device=self.device, + ) + ans = self.model_llm_interface( + input_emb, + cache=state.llm_cache, + cache_position=cache_pos, + generated_tokens=state.gen_text, + current_step=current_frame_idx, + return_logits=return_debug, + sampling_params=sampling_params, + ) + state.llm_cache = ans["cache"] + else: + new_input_embeds.append(input_emb) + full_input_embeds = torch.cat(state.input_embeds_history + new_input_embeds, dim=1) + ans = self.model_llm_interface( + full_input_embeds, + cache=None, + generated_tokens=state.gen_text, + current_step=current_frame_idx, + return_logits=return_debug, + sampling_params=sampling_params, + ) + + state.timing.stop("stt_model") + + return ans + + def _run_tts_step( + self, + state: StreamingDecodeState, + current_frame_idx: int, + request_id: str, + ) -> torch.Tensor: + """Run one TTS code-generation step. + + Mutates ``state.tts_code`` and ``state.tts_past_key_values`` in-place. + Returns the new code tensor (cloned) for later batch decoding. 
+ """ + current_subword_id = state.gen_text[:, current_frame_idx].unsqueeze(-1) + + if current_frame_idx == 0: + if self.first_context_subword_id is None: + raise RuntimeError("first_context_subword_id is not initialized. Ensure TTS warmup ran successfully.") + prev_subword_id = self.first_context_subword_id + else: + prev_subword_id = state.gen_text[:, current_frame_idx - 1].unsqueeze(-1) + + current_subword_mask = state.subword_mask[:, current_frame_idx].unsqueeze(-1) + + if self.generation_config is None: + raise RuntimeError("generation_config is not initialized. Ensure TTS warmup ran successfully.") + + state.timing.start("tts_model") + if self.use_vllm_eartts: + tts_inputs = { + "code": state.tts_code, + "context_hidden_state": None, + "subword_ids": current_subword_id, + "subword_mask": current_subword_mask, + "past_key_values": state.tts_past_key_values, + "use_cache": True, + "guidance_enabled": True, + "generation_config": self.generation_config, + "ignore_eos_flag_stop": True, + } + else: + tts_inputs = { + "current_subword_id": current_subword_id, + "prev_subword_id": prev_subword_id, + "current_subword_mask": current_subword_mask, + "prev_audio_tokens": state.tts_code, + "past_key_values": state.tts_past_key_values, + "guidance_enabled": True, + "generation_config": self.generation_config, + "ignore_eos_flag_stop": True, + } + result = self.model_eartts_interface(tts_inputs, request_id=request_id) + state.tts_code = result.codes + state.tts_past_key_values = result.past_key_values + + state.timing.stop("tts_model") + + new_code = state.tts_code.clone() + + if self.model.cfg.get('inference_force_speech_silence_on_eos', None): + silence_codes = self.model.tts_model.codec_silence_tokens.view(1, 1, -1).expand(state.tts_code.shape) + state.tts_code = torch.where( + current_subword_id.unsqueeze(-1) == self.model.tts_model.text_eos_id, + silence_codes, + state.tts_code, + ) + + return new_code + + def _decode_audio( + self, + new_codes_for_decode: 
    def _decode_audio(
        self,
        new_codes_for_decode: list[torch.Tensor],
        state: StreamingDecodeState,
        frame_idx: int,
        num_frames_per_chunk: int,
    ) -> torch.Tensor | None:
        """Decode accumulated TTS codes into a waveform.

        Returns the decoded audio tensor or *None* when ``decode_audio``
        is disabled or no codes were produced.

        Args:
            new_codes_for_decode: Code tensors collected since the last decode;
                concatenated along dim=1 before decoding.
            state: Per-stream decode state; ``tts_codec_cache`` is handed to the
                codec so decoding continues seamlessly across chunks.
            frame_idx: Current frame index (debug logging only).
            num_frames_per_chunk: Frames per chunk (debug logging only).
        """
        if not self.decode_audio or not new_codes_for_decode:
            return None

        logging.debug(f"Decoding audio for {frame_idx}-th frame ({num_frames_per_chunk=})")

        state.timing.start("audio_codec")
        # Decode runs under fp32_precision(): presumably the audio codec is
        # numerically sensitive to reduced precision — confirm before changing.
        with fp32_precision(), torch.no_grad():
            new_codes_tensor = torch.cat(new_codes_for_decode, dim=1)
            if hasattr(self.model.tts_model, '_control_codes'):
                # Local import — presumably to avoid a circular dependency with
                # the model module.
                from nemo.collections.speechlm2.models.duplex_ear_tts import replace_control_speech_codes
                # Control codes are not decodable audio; swap them out (for
                # silence tokens where available) before hitting the codec.
                new_codes_tensor = replace_control_speech_codes(
                    new_codes_tensor,
                    self.model.tts_model._control_codes,
                    getattr(self.model.tts_model, 'codec_silence_tokens', None),
                )
            new_code_len = torch.tensor(
                [new_codes_tensor.shape[1]], dtype=torch.long, device=self.device,
            )
            decoded_audio, _ = self.model.tts_model.audio_codec.decode(
                new_codes_tensor, new_code_len, cache=state.tts_codec_cache,
            )

        state.timing.stop("audio_codec")

        return decoded_audio
input_signal_length=buffer_len, + return_encoder_emb=True, + ) + + source_encoded = source_encoded.to(self.dtype) + return source_encoded, perception_cache + + def _tokens_to_strings(self, token_ids: torch.Tensor) -> list[str]: + """Convert a [B, T] tensor of token IDs to a list of strings. + + Uses ``_decode_tokens_with_specials`` so byte-level BPE is decoded + properly (e.g. ``âĢĻ`` -> ``'``) via HF ``convert_tokens_to_string``. + + Leading spaces are preserved in the output: in byte-level BPE, + word-initial tokens carry a space prefix that ``convert_tokens_to_string`` + keeps intact. So callers can concatenate successive chunk strings to + recover properly spaced text. A leading space means "new word"; no + leading space means the token continues the previous word. For + example, three chunks producing ``"Hi"``, ``" how can"``, + ``" I help"`` concatenate to ``"Hi how can I help"`` (not + ``"Hihow canI help"``). + + NOTE: multi-byte UTF-8 characters whose BPE tokens span two frames + will show as replacement chars (U+FFFD) because each frame is decoded + independently. + """ + result = [] + for tok_ids_b in token_ids: + toks = self.tokenizer.ids_to_tokens(tok_ids_b.tolist()) + result.append(_decode_tokens_with_specials(toks, self.tokenizer, keep_pad=False)) + return result + + def abort_request(self, request_id: str | None) -> bool: + """ + Abort an in-flight vLLM streaming request if the backend supports it. 
+ """ + if not request_id: + return False + + success = False + + # Abort LLM if applicable + if self.use_vllm_llm: + abort_fn = getattr(self.model_llm_interface, "abort_request", None) + if callable(abort_fn): + try: + if abort_fn(request_id): + success = True + logging.info(f"Aborted LLM request {request_id} successfully.") + except Exception as exc: + logging.warning(f"Failed to abort LLM request {request_id}: {exc}") + + # Abort EarTTS if applicable + if self.use_vllm_eartts and self.model_eartts_interface is not None: + abort_fn = getattr(self.model_eartts_interface, "abort_request", None) + if callable(abort_fn): + try: + if abort_fn(request_id): + success = True + logging.info(f"Aborted EarTTS request {request_id} successfully.") + except Exception as exc: + logging.warning(f"Failed to abort EarTTS request {request_id}: {exc}") + + return success + + + def _maybe_apply_forced_turn_taking(self, t, gen_text, gen_asr): + """Apply forced turn-taking rules based on ASR channel tokens.""" + if not self.model_cfg.get("force_turn_taking", False): + return + + threshold = self.model_cfg.get("force_turn_taking_threshold", 40) + pad_window_steps = self.model_cfg.get("force_turn_taking_pad_window", 25) + + B = gen_text.size(0) + + for batch_idx in range(B): + lookback_start = max(0, t - threshold) + agent_text_window = gen_text[batch_idx, lookback_start:t] + current_asr_token = gen_asr[batch_idx, t] + + # ASR EOS or ~1 sec of pad tokens → insert agent BOS if not present in window + # Skip if we don't have enough tokens at the beginning + if t < pad_window_steps: + continue + + pad_lookback_start = t - pad_window_steps + asr_recent_tokens = gen_asr[batch_idx, pad_lookback_start:t] + has_pad_window = (asr_recent_tokens == self.model.stt_model.text_pad_id).all() if len(asr_recent_tokens) > 0 else False + + # Require that the pad window starts after a non-pad token + if has_pad_window and pad_lookback_start > 0: + token_before_window = gen_asr[batch_idx, pad_lookback_start 
    def _maybe_apply_forced_turn_taking(self, t: int, gen_text: torch.Tensor, gen_asr: torch.Tensor) -> None:
        """Apply forced turn-taking rules based on ASR channel tokens.

        Mutates ``gen_text`` in place at position ``t`` when a rule fires:
        inserts an agent BOS after a sustained run of ASR pad tokens, or an
        agent EOS when the ASR channel emits BOS (user started speaking).
        No-op unless ``force_turn_taking`` is enabled in the model config.

        Args:
            t: Current frame index.
            gen_text: [B, T] agent text-channel tokens (modified in place).
            gen_asr: [B, T] ASR (user) channel tokens (read only).
        """
        if not self.model_cfg.get("force_turn_taking", False):
            return

        # Lookback (in frames) for "already inserted BOS/EOS?" checks.
        threshold = self.model_cfg.get("force_turn_taking_threshold", 40)
        # Length (in frames) of the all-pad window that triggers agent BOS.
        # NOTE(review): at 80 ms/frame the default 25 is ~2 s, not the "~1 sec"
        # the inline note below says — confirm which is intended.
        pad_window_steps = self.model_cfg.get("force_turn_taking_pad_window", 25)

        B = gen_text.size(0)

        for batch_idx in range(B):
            lookback_start = max(0, t - threshold)
            agent_text_window = gen_text[batch_idx, lookback_start:t]
            current_asr_token = gen_asr[batch_idx, t]

            # ASR EOS or ~1 sec of pad tokens → insert agent BOS if not present in window
            # Skip if we don't have enough tokens at the beginning
            if t < pad_window_steps:
                continue

            pad_lookback_start = t - pad_window_steps
            asr_recent_tokens = gen_asr[batch_idx, pad_lookback_start:t]
            has_pad_window = (asr_recent_tokens == self.model.stt_model.text_pad_id).all() if len(asr_recent_tokens) > 0 else False

            # Require that the pad window starts after a non-pad token
            # (i.e. the user actually spoke and then went silent).
            if has_pad_window and pad_lookback_start > 0:
                token_before_window = gen_asr[batch_idx, pad_lookback_start - 1]
                has_pad_window = (token_before_window != self.model.stt_model.text_pad_id) and (token_before_window != self.model.stt_model.text_bos_id)
            elif has_pad_window and pad_lookback_start == 0:
                # If the pad window starts at position 0, it doesn't meet the requirement
                has_pad_window = False

            if has_pad_window:
                # Only force BOS if the agent has not started a turn recently.
                if not (agent_text_window == self.model.stt_model.text_bos_id).any():
                    gen_text[batch_idx, t] = self.model.stt_model.text_bos_id
                    logging.info(f"Forced turn-taking at frame {t}: inserted agent BOS (reason: pad window)")

            # ASR BOS → insert agent EOS if not present in window
            elif current_asr_token == self.model.stt_model.text_bos_id:
                if not (agent_text_window == self.model.stt_model.text_eos_id).any():
                    gen_text[batch_idx, t] = self.model.stt_model.text_eos_id
                    logging.info(f"Forced turn-taking at frame {t}: inserted agent EOS (reason: user started speaking)")
+ +Provides incremental mel-spectrogram encoding with optional CUDA graph +acceleration, so that only new audio needs to be processed each step +instead of re-encoding the entire buffer. +""" + +import copy +from dataclasses import dataclass + +import torch +from omegaconf import OmegaConf + +from nemo.utils import logging + + +@dataclass +class PerceptionCacheState: + """Cache state for streaming perception inference. + + Holds the cache tensors for the ASR encoder used in the perception module. + This enables cache-aware streaming inference without needing the full audio buffer. + """ + cache_last_channel: torch.Tensor | None = None + cache_last_time: torch.Tensor | None = None + cache_last_channel_len: torch.Tensor | None = None + + def is_initialized(self) -> bool: + """Check if the cache has been initialized.""" + return None not in [self.cache_last_channel, self.cache_last_time, self.cache_last_channel_len] + + +@dataclass +class PerceptionCUDAGraphState: + """State for CUDA graph-accelerated perception encoder. + + Holds separate graphs for first chunk (different size) and subsequent chunks. + Also holds static buffers for inputs/outputs to enable graph replay. 
+ """ + # CUDA graphs + graph_first: torch.cuda.CUDAGraph | None = None + graph_subsequent: torch.cuda.CUDAGraph | None = None + + # Static input buffers (for copying data before graph replay) + static_mel_first: torch.Tensor | None = None + static_mel_subsequent: torch.Tensor | None = None + static_mel_len_first: torch.Tensor | None = None + static_mel_len_subsequent: torch.Tensor | None = None + + # Static cache input buffers + static_cache_channel_in: torch.Tensor | None = None + static_cache_time_in: torch.Tensor | None = None + static_cache_channel_len_in: torch.Tensor | None = None + + # Static output buffers (results are written here during replay) + static_encoded_first: torch.Tensor | None = None + static_encoded_subsequent: torch.Tensor | None = None + static_encoded_len_first: torch.Tensor | None = None + static_encoded_len_subsequent: torch.Tensor | None = None + + # Static cache output buffers - SEPARATE for first and subsequent graphs + # (each graph writes to its own output tensors during replay) + static_cache_channel_out_first: torch.Tensor | None = None + static_cache_time_out_first: torch.Tensor | None = None + static_cache_channel_len_out_first: torch.Tensor | None = None + static_cache_channel_out_subsequent: torch.Tensor | None = None + static_cache_time_out_subsequent: torch.Tensor | None = None + static_cache_channel_len_out_subsequent: torch.Tensor | None = None + + def is_captured(self) -> bool: + """Check if graphs have been captured.""" + return self.graph_first is not None and self.graph_subsequent is not None + + +class PerceptionCacheManager: + """Manages cache-aware streaming perception encoding with optional CUDA graphs. + + This class encapsulates all perception cache setup, CUDA graph capture, + and the incremental encoding step. It is created by the inference wrapper + when ``use_perception_cache=True``. 
+ """ + + def __init__(self, model, device: torch.device, dtype: torch.dtype, use_cudagraph: bool = False): + self.model = model + self.device = device + self.dtype = dtype + self.use_cudagraph = use_cudagraph + + self.streaming_cfg = None + self.preprocessor = None + self.subsampling_factor = None + self.input_features = None + self.sampling_frames = None + self.cudagraph_state: PerceptionCUDAGraphState | None = None + + def setup(self) -> bool: + """Setup cache-aware streaming for the perception encoder. + + Returns: + True if setup succeeded, False if the encoder doesn't support streaming. + """ + from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder + + perception = self.model.stt_model.perception + encoder = perception.encoder + + if not isinstance(encoder, StreamingEncoder): + logging.warning("Perception encoder does not support streaming. Disabling perception cache.") + return False + + if encoder.streaming_cfg is None: + encoder.setup_streaming_params() + + self.streaming_cfg = encoder.streaming_cfg + + cfg = copy.deepcopy(perception.cfg) + OmegaConf.set_struct(cfg.preprocessor, False) + cfg.preprocessor.dither = 0.0 + cfg.preprocessor.pad_to = 0 + + self.preprocessor = perception.from_config_dict(cfg.preprocessor) + self.preprocessor.to(self.device) + + self.subsampling_factor = encoder.subsampling_factor + self.input_features = encoder._feat_in + + if hasattr(encoder, "pre_encode") and hasattr(encoder.pre_encode, "get_sampling_frames"): + self.sampling_frames = encoder.pre_encode.get_sampling_frames() + else: + self.sampling_frames = None + + logging.info(f"Perception cache setup complete:") + logging.info(f" Streaming config: chunk_size={self.streaming_cfg.chunk_size}, " + f"shift_size={self.streaming_cfg.shift_size}") + logging.info(f" Pre-encode cache size: {self.streaming_cfg.pre_encode_cache_size}") + logging.info(f" Subsampling factor: {self.subsampling_factor}") + + if self.use_cudagraph: + logging.info(f" Setting up CUDA graphs 
for perception encoder...") + self.capture_cudagraphs() + logging.info(f" CUDA graphs captured") + + return True + + def capture_cudagraphs(self): + """Capture CUDA graphs for perception encoder with both chunk sizes. + + Note: "chunk" in the streaming encoder config (chunk_size, shift_size, etc.) + follows NeMo's cache-aware streaming encoder API and is measured in + mel-spectrogram time-steps, not audio samples or seconds. + """ + encoder = self.model.stt_model.perception.encoder + perception = self.model.stt_model.perception + streaming_cfg = self.streaming_cfg + + if isinstance(streaming_cfg.chunk_size, list): + chunk_size_first = streaming_cfg.chunk_size[0] + chunk_size_subsequent = streaming_cfg.chunk_size[1] + else: + chunk_size_first = streaming_cfg.chunk_size + chunk_size_subsequent = streaming_cfg.chunk_size + + if isinstance(streaming_cfg.pre_encode_cache_size, list): + pre_encode_cache_first = streaming_cfg.pre_encode_cache_size[0] + pre_encode_cache_subsequent = streaming_cfg.pre_encode_cache_size[1] + else: + pre_encode_cache_first = streaming_cfg.pre_encode_cache_size + pre_encode_cache_subsequent = streaming_cfg.pre_encode_cache_size + + mel_len_first = chunk_size_first + pre_encode_cache_first + mel_len_subsequent = chunk_size_subsequent + pre_encode_cache_subsequent + + logging.info(f" CUDA graph mel lengths: first={mel_len_first}, subsequent={mel_len_subsequent}") + + cache_last_channel, cache_last_time, cache_last_channel_len = encoder.get_initial_cache_state( + batch_size=1 + ) + + state = PerceptionCUDAGraphState() + + state.static_mel_first = torch.zeros( + (1, self.input_features, mel_len_first), + dtype=torch.float32, device=self.device + ) + state.static_mel_subsequent = torch.zeros( + (1, self.input_features, mel_len_subsequent), + dtype=torch.float32, device=self.device + ) + state.static_mel_len_first = torch.tensor([mel_len_first], dtype=torch.long, device=self.device) + state.static_mel_len_subsequent = 
    def capture_cudagraphs(self):
        """Capture CUDA graphs for perception encoder with both chunk sizes.

        Note: "chunk" in the streaming encoder config (chunk_size, shift_size, etc.)
        follows NeMo's cache-aware streaming encoder API and is measured in
        mel-spectrogram time-steps, not audio samples or seconds.
        """
        encoder = self.model.stt_model.perception.encoder
        perception = self.model.stt_model.perception
        streaming_cfg = self.streaming_cfg

        # Streaming configs may carry [first, subsequent] pairs or single scalars.
        if isinstance(streaming_cfg.chunk_size, list):
            chunk_size_first = streaming_cfg.chunk_size[0]
            chunk_size_subsequent = streaming_cfg.chunk_size[1]
        else:
            chunk_size_first = streaming_cfg.chunk_size
            chunk_size_subsequent = streaming_cfg.chunk_size

        if isinstance(streaming_cfg.pre_encode_cache_size, list):
            pre_encode_cache_first = streaming_cfg.pre_encode_cache_size[0]
            pre_encode_cache_subsequent = streaming_cfg.pre_encode_cache_size[1]
        else:
            pre_encode_cache_first = streaming_cfg.pre_encode_cache_size
            pre_encode_cache_subsequent = streaming_cfg.pre_encode_cache_size

        mel_len_first = chunk_size_first + pre_encode_cache_first
        mel_len_subsequent = chunk_size_subsequent + pre_encode_cache_subsequent

        logging.info(f"  CUDA graph mel lengths: first={mel_len_first}, subsequent={mel_len_subsequent}")

        cache_last_channel, cache_last_time, cache_last_channel_len = encoder.get_initial_cache_state(
            batch_size=1
        )

        state = PerceptionCUDAGraphState()

        # Static buffers: CUDA graph replay requires fixed input/output addresses.
        state.static_mel_first = torch.zeros(
            (1, self.input_features, mel_len_first),
            dtype=torch.float32, device=self.device
        )
        state.static_mel_subsequent = torch.zeros(
            (1, self.input_features, mel_len_subsequent),
            dtype=torch.float32, device=self.device
        )
        state.static_mel_len_first = torch.tensor([mel_len_first], dtype=torch.long, device=self.device)
        state.static_mel_len_subsequent = torch.tensor([mel_len_subsequent], dtype=torch.long, device=self.device)

        if cache_last_channel is not None:
            state.static_cache_channel_in = cache_last_channel.clone()
        if cache_last_time is not None:
            state.static_cache_time_in = cache_last_time.clone()
        if cache_last_channel_len is not None:
            state.static_cache_channel_len_in = cache_last_channel_len.clone()

        # Warm up before capture so lazy initialization / allocator activity
        # happens outside the captured graphs.
        logging.info(f"  Warming up encoder for CUDA graph capture...")
        for _ in range(3):
            with torch.no_grad():
                _ = encoder.cache_aware_stream_step(
                    processed_signal=state.static_mel_first,
                    processed_signal_length=state.static_mel_len_first,
                    cache_last_channel=state.static_cache_channel_in.clone() if state.static_cache_channel_in is not None else None,
                    cache_last_time=state.static_cache_time_in.clone() if state.static_cache_time_in is not None else None,
                    cache_last_channel_len=state.static_cache_channel_len_in.clone() if state.static_cache_channel_len_in is not None else None,
                    keep_all_outputs=True,
                    drop_extra_pre_encoded=0,
                )
                _ = encoder.cache_aware_stream_step(
                    processed_signal=state.static_mel_subsequent,
                    processed_signal_length=state.static_mel_len_subsequent,
                    cache_last_channel=state.static_cache_channel_in.clone() if state.static_cache_channel_in is not None else None,
                    cache_last_time=state.static_cache_time_in.clone() if state.static_cache_time_in is not None else None,
                    cache_last_channel_len=state.static_cache_channel_len_in.clone() if state.static_cache_channel_len_in is not None else None,
                    keep_all_outputs=True,
                    drop_extra_pre_encoded=streaming_cfg.drop_extra_pre_encoded,
                )
        torch.cuda.synchronize()

        # Capture graph for FIRST chunk
        logging.info(f"  Capturing CUDA graph for first chunk (mel_len={mel_len_first})...")
        state.graph_first = torch.cuda.CUDAGraph()

        # Reset the static cache inputs to the pristine initial cache so the
        # capture sees the same values a fresh stream would.
        if state.static_cache_channel_in is not None:
            state.static_cache_channel_in.copy_(cache_last_channel)
        if state.static_cache_time_in is not None:
            state.static_cache_time_in.copy_(cache_last_time)
        if state.static_cache_channel_len_in is not None:
            state.static_cache_channel_len_in.copy_(cache_last_channel_len)

        with torch.cuda.graph(state.graph_first):
            (
                encoded_first,
                encoded_len_first,
                cache_channel_out_first,
                cache_time_out_first,
                cache_channel_len_out_first,
            ) = encoder.cache_aware_stream_step(
                processed_signal=state.static_mel_first,
                processed_signal_length=state.static_mel_len_first,
                cache_last_channel=state.static_cache_channel_in,
                cache_last_time=state.static_cache_time_in,
                cache_last_channel_len=state.static_cache_channel_len_in,
                keep_all_outputs=True,
                drop_extra_pre_encoded=0,
            )
            encoded_adapted_first, _ = perception.modality_adapter(audio_signal=encoded_first, length=encoded_len_first)
            encoded_chunk_first = perception.proj(encoded_adapted_first.transpose(1, 2))

        # The tensors produced inside the capture become the graph's static outputs.
        state.static_encoded_first = encoded_chunk_first
        state.static_encoded_len_first = encoded_len_first
        state.static_cache_channel_out_first = cache_channel_out_first
        state.static_cache_time_out_first = cache_time_out_first
        state.static_cache_channel_len_out_first = cache_channel_len_out_first

        # Capture graph for SUBSEQUENT chunks
        logging.info(f"  Capturing CUDA graph for subsequent chunks (mel_len={mel_len_subsequent})...")
        state.graph_subsequent = torch.cuda.CUDAGraph()

        if state.static_cache_channel_in is not None:
            state.static_cache_channel_in.copy_(cache_last_channel)
        if state.static_cache_time_in is not None:
            state.static_cache_time_in.copy_(cache_last_time)
        if state.static_cache_channel_len_in is not None:
            state.static_cache_channel_len_in.copy_(cache_last_channel_len)

        with torch.cuda.graph(state.graph_subsequent):
            (
                encoded_subsequent,
                encoded_len_subsequent,
                cache_channel_out_subsequent,
                cache_time_out_subsequent,
                cache_channel_len_out_subsequent,
            ) = encoder.cache_aware_stream_step(
                processed_signal=state.static_mel_subsequent,
                processed_signal_length=state.static_mel_len_subsequent,
                cache_last_channel=state.static_cache_channel_in,
                cache_last_time=state.static_cache_time_in,
                cache_last_channel_len=state.static_cache_channel_len_in,
                keep_all_outputs=True,
                drop_extra_pre_encoded=streaming_cfg.drop_extra_pre_encoded,
            )
            encoded_adapted_subsequent, _ = perception.modality_adapter(audio_signal=encoded_subsequent, length=encoded_len_subsequent)
            encoded_chunk_subsequent = perception.proj(encoded_adapted_subsequent.transpose(1, 2))

        state.static_encoded_subsequent = encoded_chunk_subsequent
        state.static_encoded_len_subsequent = encoded_len_subsequent
        state.static_cache_channel_out_subsequent = cache_channel_out_subsequent
        state.static_cache_time_out_subsequent = cache_time_out_subsequent
        state.static_cache_channel_len_out_subsequent = cache_channel_len_out_subsequent

        self.cudagraph_state = state
        logging.info(f"  CUDA graphs captured successfully")
    def step(
        self,
        audio_input: torch.Tensor,
        frame_idx: int,
        num_frames_per_chunk: int,
        perception_cache: PerceptionCacheState,
    ) -> tuple[torch.Tensor, PerceptionCacheState]:
        """
        Perform cache-aware perception encoding for streaming inference.

        Note: "chunk" in this method (chunk_size, mel_chunk, etc.) follows NeMo's
        cache-aware streaming encoder API and is measured in mel-spectrogram time-steps,
        not audio samples or seconds.

        This method computes the full mel spectrogram from the audio buffer, then slices
        it appropriately based on the frame index. It supports processing multiple
        "base steps" in a single call, where each base step processes (lookahead + 1) frames.

        Processing logic per sub-step:
        - First sub-step (sub_frame_idx == 0): take first chunk_size_first columns,
          prepend zeros for pre_encode_cache
        - Subsequent sub-steps (sub_frame_idx > 0): take chunk_size columns starting from
          (shift_size_first + (step_number-1)*shift_size), prepend pre_encode_cache_size
          columns from mel spec

        The method loops over sub-steps, running the encoder for each and concatenating
        the outputs. This allows num_frames_per_chunk to be a multiple of (lookahead + 1).

        Args:
            audio_input: Audio buffer tensor [B, T] (full buffer with all samples)
            frame_idx: Current frame index in the stream
            num_frames_per_chunk: Number of 80ms frames to process. Must be a multiple
                of (lookahead + 1), i.e., encoder._cfg.att_context_size[1] + 1
            perception_cache: Current cache state containing encoder caches

        Returns:
            Tuple of (encoded_output [B, T_out, D], updated_perception_cache)
            where T_out = num_frames_per_chunk (one output frame per input frame)
        """
        perception = self.model.stt_model.perception
        encoder = perception.encoder
        streaming_cfg = self.streaming_cfg

        # Full-buffer mel features; each sub-step below slices a window out of this.
        audio_len = torch.tensor([audio_input.shape[1]], dtype=torch.long, device=self.device)
        processed_signal, _ = self.preprocessor(
            input_signal=audio_input,
            length=audio_len,
        )

        # Streaming configs may carry [first, subsequent] pairs or single scalars.
        if isinstance(streaming_cfg.chunk_size, list):
            chunk_size_first = streaming_cfg.chunk_size[0]
            chunk_size = streaming_cfg.chunk_size[1]
        else:
            chunk_size_first = streaming_cfg.chunk_size
            chunk_size = streaming_cfg.chunk_size

        if isinstance(streaming_cfg.shift_size, list):
            shift_size_first = streaming_cfg.shift_size[0]
            shift_size = streaming_cfg.shift_size[1]
        else:
            shift_size_first = streaming_cfg.shift_size
            shift_size = streaming_cfg.shift_size

        if isinstance(streaming_cfg.pre_encode_cache_size, list):
            pre_encode_cache_size_first = streaming_cfg.pre_encode_cache_size[0]
            pre_encode_cache_size = streaming_cfg.pre_encode_cache_size[1]
        else:
            pre_encode_cache_size_first = streaming_cfg.pre_encode_cache_size
            pre_encode_cache_size = streaming_cfg.pre_encode_cache_size

        cache_last_channel = perception_cache.cache_last_channel
        cache_last_time = perception_cache.cache_last_time
        cache_last_channel_len = perception_cache.cache_last_channel_len

        # One "base step" = (lookahead + 1) frames; the call may cover several.
        base_step_size = encoder._cfg.att_context_size[1] + 1
        if num_frames_per_chunk % base_step_size != 0:
            raise ValueError(
                f"num_frames_per_chunk must be a multiple of (lookahead + 1) = {base_step_size}. "
                f"Got num_frames_per_chunk={num_frames_per_chunk}"
            )
        num_sub_steps = num_frames_per_chunk // base_step_size

        encoded_chunks = []

        for sub_step in range(num_sub_steps):
            sub_frame_idx = frame_idx + (sub_step * base_step_size)
            is_first_sub_step = (sub_frame_idx == 0)

            if is_first_sub_step:
                cur_chunk_size = chunk_size_first
                cur_pre_encode_cache_size = pre_encode_cache_size_first
                drop_extra_pre_encoded = 0

                mel_chunk = processed_signal[:, :, :cur_chunk_size]

                # No history exists yet — pad the pre-encode cache region with zeros.
                if cur_pre_encode_cache_size > 0:
                    zeros_pad = torch.zeros(
                        (processed_signal.size(0), self.input_features, cur_pre_encode_cache_size),
                        device=self.device,
                        dtype=processed_signal.dtype,
                    )
                    mel_chunk = torch.cat([zeros_pad, mel_chunk], dim=-1)
            else:
                cur_chunk_size = chunk_size
                cur_pre_encode_cache_size = pre_encode_cache_size
                drop_extra_pre_encoded = streaming_cfg.drop_extra_pre_encoded

                mel_T = processed_signal.shape[-1]

                step_number = sub_frame_idx // base_step_size
                chunk_start = shift_size_first + (step_number - 1) * shift_size
                chunk_end = chunk_start + cur_chunk_size

                # Clamp the window when it would run past the usable end of the
                # buffer, leaving room for the remaining sub-steps.
                offset = chunk_size - shift_size_first
                if chunk_end > mel_T - offset:
                    sub_steps_remaining = num_sub_steps - 1 - sub_step
                    chunk_end = mel_T - offset - sub_steps_remaining * shift_size
                    chunk_start = chunk_end - cur_chunk_size

                main_chunk = processed_signal[:, :, chunk_start:chunk_end]

                # Pre-encode cache comes from the mel history immediately before
                # the chunk; zero-pad on the left when not enough history exists.
                cache_start = max(0, chunk_start - cur_pre_encode_cache_size)
                cache_mel = processed_signal[:, :, cache_start:chunk_start]

                if cache_mel.shape[-1] < cur_pre_encode_cache_size:
                    zeros_pad = torch.zeros(
                        (cache_mel.size(0), cache_mel.size(1), cur_pre_encode_cache_size - cache_mel.shape[-1]),
                        device=self.device,
                        dtype=cache_mel.dtype,
                    )
                    cache_mel = torch.cat([zeros_pad, cache_mel], dim=-1)

                mel_chunk = torch.cat([cache_mel, main_chunk], dim=-1)

            chunk_lengths = torch.tensor([mel_chunk.shape[-1]], dtype=torch.long, device=self.device)

            if self.use_cudagraph and self.cudagraph_state is not None and self.cudagraph_state.is_captured():
                # CUDA graph path: copy inputs into static buffers, replay, then
                # clone the static outputs (replaying again would overwrite them).
                graph_state = self.cudagraph_state

                if is_first_sub_step:
                    graph_state.static_mel_first.copy_(mel_chunk)
                else:
                    graph_state.static_mel_subsequent.copy_(mel_chunk)

                if graph_state.static_cache_channel_in is not None and cache_last_channel is not None:
                    graph_state.static_cache_channel_in.copy_(cache_last_channel)
                if graph_state.static_cache_time_in is not None and cache_last_time is not None:
                    graph_state.static_cache_time_in.copy_(cache_last_time)
                if graph_state.static_cache_channel_len_in is not None and cache_last_channel_len is not None:
                    graph_state.static_cache_channel_len_in.copy_(cache_last_channel_len)

                if is_first_sub_step:
                    graph_state.graph_first.replay()
                    encoded_chunk = graph_state.static_encoded_first.clone()
                    cache_last_channel = graph_state.static_cache_channel_out_first.clone() if graph_state.static_cache_channel_out_first is not None else None
                    cache_last_time = graph_state.static_cache_time_out_first.clone() if graph_state.static_cache_time_out_first is not None else None
                    cache_last_channel_len = graph_state.static_cache_channel_len_out_first.clone() if graph_state.static_cache_channel_len_out_first is not None else None
                else:
                    graph_state.graph_subsequent.replay()
                    encoded_chunk = graph_state.static_encoded_subsequent.clone()
                    cache_last_channel = graph_state.static_cache_channel_out_subsequent.clone() if graph_state.static_cache_channel_out_subsequent is not None else None
                    cache_last_time = graph_state.static_cache_time_out_subsequent.clone() if graph_state.static_cache_time_out_subsequent is not None else None
                    cache_last_channel_len = graph_state.static_cache_channel_len_out_subsequent.clone() if graph_state.static_cache_channel_len_out_subsequent is not None else None

            else:
                # Eager path: run encoder + modality adapter + projection directly.
                (
                    encoded,
                    encoded_len,
                    cache_last_channel,
                    cache_last_time,
                    cache_last_channel_len,
                ) = encoder.cache_aware_stream_step(
                    processed_signal=mel_chunk,
                    processed_signal_length=chunk_lengths,
                    cache_last_channel=cache_last_channel,
                    cache_last_time=cache_last_time,
                    cache_last_channel_len=cache_last_channel_len,
                    keep_all_outputs=True,
                    drop_extra_pre_encoded=drop_extra_pre_encoded,
                )

                modality_adapter = perception.modality_adapter
                encoded_adapted, _ = modality_adapter(audio_signal=encoded, length=encoded_len)

                encoded_chunk = perception.proj(encoded_adapted.transpose(1, 2))

            encoded_chunks.append(encoded_chunk)

        if len(encoded_chunks) > 1:
            encoded_chunk = torch.cat(encoded_chunks, dim=1)
        else:
            encoded_chunk = encoded_chunks[0]

        new_perception_cache = PerceptionCacheState(
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
        )

        return encoded_chunk, new_perception_cache
class S2SPipelineInterface:
    """Base class for all streaming S2S pipelines.

    Intentionally minimal; mirrors the behaviour of the ``BasePipeline`` used
    by the streaming ASR pipelines. It provides a small in-memory *state pool*
    mapping ``stream_id`` to the per-stream objects (cache, running buffers,
    etc.) a concrete pipeline needs. Sub-classes implement
    :py:meth:`create_state` to construct a fresh state object.
    """

    def __init__(self) -> None:
        # Per-stream state objects, keyed by ``stream_id``.
        self._state_pool: dict[int, Any] = {}

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------
    def get_state(self, stream_id: int):
        """Return the state object for *stream_id* or *None* if it does not exist."""
        return self._state_pool.get(stream_id)

    def delete_state(self, stream_id: int) -> None:
        """Delete the state associated with *stream_id* (noop if missing)."""
        self._state_pool.pop(stream_id, None)

    def create_state(self, options: S2SRequestOptions | None = None):
        """Create and return a *new*, *empty* state object.

        Args:
            options: Per-stream request options (system prompt, sampling
                overrides, etc.). Stored on the state so they can be
                consulted throughout the stream's lifetime.

        Must be implemented by concrete pipelines.
        """
        raise NotImplementedError("`create_state()` must be implemented in a subclass.")

    def get_or_create_state(self, stream_id: int, options: S2SRequestOptions | None = None):
        """Return existing state for *stream_id*, creating one via :py:meth:`create_state` when absent."""
        if stream_id not in self._state_pool:
            self._state_pool[stream_id] = self.create_state(options)
        return self._state_pool[stream_id]

    # ------------------------------------------------------------------
    # Session helpers – identical to *BasePipeline*
    # ------------------------------------------------------------------
    def reset_session(self) -> None:
        """Clear the internal *state pool* – effectively resetting the pipeline."""
        self._state_pool.clear()

    def open_session(self) -> None:
        """Alias for :py:meth:`reset_session` to start a fresh streaming session."""
        self.reset_session()

    def close_session(self) -> None:
        """Alias for :py:meth:`reset_session` to end the current streaming session."""
        self.reset_session()
import math
import os
import time
from dataclasses import dataclass

import librosa
import soundfile as sf
import torch
from omegaconf import DictConfig
from torch import Tensor

from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import BatchedAudioBufferer
from nemo.collections.asr.inference.streaming.framing.request import Frame
from nemo.collections.asr.inference.utils.enums import RequestType
from nemo.collections.speechlm2.inference.model_wrappers.decode_state import NullTimingSummary
from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import (
    NemotronVoicechatInferenceWrapper,
)
from nemo.collections.speechlm2.inference.pipelines.s2s_pipeline_interface import S2SPipelineInterface
from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions
from nemo.collections.speechlm2.inference.streaming.framing.silence_padded_frame_streamer import (
    SilencePaddedContinuousBatchedFrameStreamer,
)
from nemo.collections.speechlm2.inference.streaming.state.s2s_context_manager import S2SContextManager
from nemo.collections.speechlm2.inference.streaming.state.s2s_state import S2SStreamingState
from nemo.collections.speechlm2.inference.utils.pipeline_utils import PipelineOutput
from nemo.collections.speechlm2.inference.utils.stepprogressbar import StepProgressBar
from nemo.collections.speechlm2.parts.text_utils import tokens_to_str
from nemo.utils import logging


@dataclass
class GenerateStepOutput:
    """Output of a single :meth:`StreamingS2SPipeline.generate_step` call
    for one stream.

    Analogous to ``TranscribeStepOutput`` in the ASR pipelines, this carries
    the **incremental** (new-this-step) audio and text so that callers don't
    have to diff against accumulated state. The underlying
    :class:`S2SStreamingState` still accumulates everything for batch/offline
    use.
    """

    stream_id: int
    audio: torch.Tensor
    text: str = ""
    asr_text: str = ""


class StreamingS2SPipeline(S2SPipelineInterface):
    """
    Streaming S2S pipeline.

    Processes audio chunk-by-chunk through a NemotronVoiceChat model wrapper,
    producing incremental text and speech output. Manages per-stream state,
    audio buffering, KV-cache prefill, and final file output.
    """

    def __init__(self, cfg: DictConfig, s2s_model: NemotronVoicechatInferenceWrapper):
        # ------------------------------------------------------------------
        # Model & device
        # ------------------------------------------------------------------
        self.s2s_model = s2s_model
        self.device = self.s2s_model.device
        self.decode_audio = self.s2s_model.decode_audio
        self.collect_debug = False

        # ------------------------------------------------------------------
        # Streaming configuration
        # ------------------------------------------------------------------
        # NOTE(review): when the "streaming" key is absent, ``cfg.get`` returns
        # a plain dict, on which ``getattr`` always yields the default. This is
        # harmless for an *empty* config, but a plain-dict config with actual
        # keys would be silently ignored — assumes cfg arrives as DictConfig;
        # TODO confirm against callers.
        self.streaming_cfg = cfg.get("streaming", {})
        self.input_sample_rate = getattr(self.streaming_cfg, "input_sample_rate", 16000)
        self.output_sample_rate = getattr(self.streaming_cfg, "output_sample_rate", 22050)
        self.batch_size = getattr(self.streaming_cfg, "batch_size", 1)
        self.max_len = getattr(self.streaming_cfg, "max_len", 200)
        if self.batch_size != 1:
            raise ValueError(
                "StreamingS2SPipeline currently supports only single-stream inference "
                "(streaming.batch_size must be 1)."
            )

        # ------------------------------------------------------------------
        # Chunk & buffer sizes
        # Terminology: "frame" = 80ms audio unit, "chunk" = 1 or more frames
        # A chunk is the amount of audio that is processed per inference step.
        # ------------------------------------------------------------------
        self.chunk_size_in_secs = getattr(self.streaming_cfg, "chunk_size_in_secs", 0.08)
        # Check that chunk_size_in_secs is a multiple of 0.08. Because of
        # floating point arithmetic the remainder could be either ~0 or ~0.08,
        # so we check for both cases.
        remainder = self.chunk_size_in_secs % 0.08
        if not (math.isclose(remainder, 0, abs_tol=1e-9) or math.isclose(remainder, 0.08, abs_tol=1e-9)):
            raise ValueError(f"Chunk size must be a multiple of 0.08s, but got {self.chunk_size_in_secs}")

        # FIX: use round() instead of int(). int() truncates, and e.g.
        # 0.24 / 0.08 == 2.9999999999999996 in IEEE-754 doubles, so int()
        # would yield 2 frames per chunk instead of 3 for the documented
        # recommended chunk size. round() is safe here because the
        # multiple-of-0.08 check above already guarantees the quotient is
        # within float noise of an integer.
        self.num_frames_per_chunk = round(self.chunk_size_in_secs / 0.08)

        # Buffer size determines how much audio is passed to the perception encoder.
        # Default: 5.68 seconds (71 * 0.08). This is the minimum valid buffer size
        # without the perception cache, i.e.
        # att_context_size[0] + att_context_size[1] + 1 frames = 70+0+1 = 71 frames.
        self.buffer_size_in_secs = getattr(self.streaming_cfg, "buffer_size_in_secs", 71 * 0.08)

        self.att_context_size = getattr(self.streaming_cfg, "att_context_size", [70, 0])

        # ------------------------------------------------------------------
        # Bufferer – reused from ASR utilities
        # ------------------------------------------------------------------
        self.bufferer = BatchedAudioBufferer(
            sample_rate=self.input_sample_rate,
            buffer_size_in_secs=self.buffer_size_in_secs,
        )

        # ------------------------------------------------------------------
        # System prompt & sampling defaults (from YAML s2s block)
        # ------------------------------------------------------------------
        s2s_cfg = cfg.get("s2s", {})
        self.system_prompt: str | None = getattr(s2s_cfg, "system_prompt", None)
        if self.system_prompt:
            logging.info(
                f"System prompt configured: {self.system_prompt[:100]}"
                f"{'...' if len(self.system_prompt) > 100 else ''}"
            )

        self._default_top_p: float | None = getattr(s2s_cfg, "top_p", None)
        self._default_temperature: float | None = getattr(s2s_cfg, "temperature", None)
        self._default_repetition_penalty: float | None = getattr(s2s_cfg, "repetition_penalty", None)

        # Context manager (slot-based KV-cache / decode context bookkeeping)
        self.context_manager = S2SContextManager(
            s2s_model=self.s2s_model,
            num_slots=self.batch_size,
            max_len=self.max_len,
        )

        # Output directory for generated files
        self.output_dir = getattr(cfg, "output_dir", "./generated")

        # Parse and validate the request type early, with a safe default;
        # only 'frame' is supported for s2s.
        req_type_cfg = getattr(self.streaming_cfg, "request_type", "frame")
        self.request_type = RequestType.from_str(req_type_cfg)
        if self.request_type is not RequestType.FRAME:
            raise ValueError(f"Request type {self.request_type} is not supported for s2s.")

        self._stream_has_prompt: bool = False

        # ------------------------------------------------------------------
        # Input audio padding (silence appended after real audio)
        # ------------------------------------------------------------------
        self.pad_audio_to_sec: float | None = cfg.get("pad_audio_to_sec", None)
        self.pad_silence_ratio: float | None = cfg.get("pad_silence_ratio", None)
        self.pad_audio_by_sec: float | None = cfg.get("pad_audio_by_sec", None)
        if sum(x is not None for x in [self.pad_audio_to_sec, self.pad_silence_ratio, self.pad_audio_by_sec]) > 1:
            raise ValueError("Set at most one of: pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec")

        super().__init__()

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------
    def create_state(self, options: S2SRequestOptions | None = None) -> S2SStreamingState:
        """Create new empty state with optional per-stream options."""
        dtype = getattr(self.s2s_model, "dtype", torch.float32)
        return S2SStreamingState(
            device=self.device,
            dtype=dtype,
            output_sample_rate=self.output_sample_rate,
            options=options or S2SRequestOptions(),
        )

    def _init_state(self, stream_id: int, options: S2SRequestOptions | None = None) -> None:
        """Initialize a new stream: resolve defaults, create state, create context, prefill.

        This is the S2S equivalent of ASR's ``init_state()`` in ``BasePipeline``.
        Called automatically by :meth:`generate_step` when a frame has
        ``is_first=True``.

        The method always runs stream initialization (state creation,
        context-manager allocation, KV-cache prefill). If the triggering
        frame also carries audio, :meth:`generate_step` will process it
        immediately after this method returns. For latency-sensitive
        deployments (real-time voice chat), callers should send the first
        frame with **empty audio** so that prefill completes before the
        user starts speaking — this prevents audio from queuing up during
        the expensive prefill phase.
        """
        # Normalize empty-string prompts to None so augment_with_defaults
        # fills in the YAML default instead of treating "" as "set".
        raw_opts = options or S2SRequestOptions()
        if raw_opts.system_prompt is not None and not raw_opts.system_prompt.strip():
            raw_opts = S2SRequestOptions(
                system_prompt=None,
                top_p=raw_opts.top_p,
                temperature=raw_opts.temperature,
                repetition_penalty=raw_opts.repetition_penalty,
            )
        opts = raw_opts.augment_with_defaults(
            default_system_prompt=self.system_prompt,
            default_top_p=self._default_top_p,
            default_temperature=self._default_temperature,
            default_repetition_penalty=self._default_repetition_penalty,
        )
        self.get_or_create_state(stream_id, options=opts)

        # NOTE(review): this re-creates the shared context manager on every
        # stream init, which resets ALL slots. Safe while batch_size == 1 is
        # enforced in __init__, but must become per-slot allocation before
        # batch_size > 1 is supported.
        self.context_manager = S2SContextManager(
            s2s_model=self.s2s_model,
            num_slots=self.batch_size,
            max_len=self.max_len,
        )

        # Prefill can take hundreds of ms, or even tens of seconds if the
        # prompt is long and the model is not warmed up.
        prompt = opts.system_prompt
        start_prefill = time.time()
        with torch.no_grad(), torch.inference_mode():
            self._prefill_system_prompt(stream_id, prompt)
            # FIX: guard the synchronize — calling torch.cuda.synchronize()
            # on a CPU-only environment raises. The sync only exists to make
            # the timing log below accurate on CUDA devices.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        logging.info(f"_init_state: stream_id={stream_id}, prefill={1000*(time.time()-start_prefill):.1f}ms")

        # Tells generate_step_for_frames whether the KV cache already contains
        # a system prompt, so it can choose the right first-frame embedding
        # (PAD tokens if prefilled, BOS tokens if not). Consumed and
        # cleared on the first audio frame.
        self._stream_has_prompt = bool(prompt)

    def generate_step_for_frames(self, frames: list[Frame], buffers: list[Tensor]) -> list[GenerateStepOutput]:
        """Generate speech for audio Frames using a shared ContextManager.

        This is the S2S equivalent of ASR's ``transcribe_step_for_frames``
        in ``BasePipeline``. Like its ASR counterpart, it is never called
        directly — :meth:`generate_step` (the public API, analogous to
        ``transcribe_step``) handles stream init and then delegates here
        for the actual audio processing.

        Stream initialization (state, context, prefill) is always handled
        by :meth:`_init_state` *before* this method is called.
        """
        if len(frames) == 0:
            return []

        stream_ids = [f.stream_id for f in frames]
        eos_flags = [f.is_last for f in frames]

        logging.debug(f"stream_ids={stream_ids} eos_flags={eos_flags}")

        if len(frames) != 1:
            raise NotImplementedError("NemotronVoicechatInferenceWrapper currently supports batch_size == 1")

        # Consume the one-shot "KV cache already holds a prompt" flag.
        has_prompt = self._stream_has_prompt
        self._stream_has_prompt = False

        request_id = self._request_id_for_stream(stream_ids[0])

        context, _ = self.context_manager.get_context(stream_ids)

        audio_buffer = buffers[0]
        if audio_buffer.dim() == 1:
            audio_buffer = audio_buffer.unsqueeze(0)
        audio_buffer = audio_buffer.to(self.s2s_model.device, dtype=self.s2s_model.dtype)

        # Sampling overrides were resolved by _init_state via augment_with_defaults
        # and stored on state.options. Build the dict for infer_one_step.
        pipeline_state = self.get_state(stream_ids[0])
        if pipeline_state is None:
            raise RuntimeError(
                f"No state initialized for stream {stream_ids[0]}. "
                "Clients must send an is_first=True frame before streaming audio."
            )
        sampling_params = {
            k: getattr(pipeline_state.options, k)
            for k in ("top_p", "temperature", "repetition_penalty")
            if getattr(pipeline_state.options, k, None) is not None
        }

        result = self.s2s_model.infer_one_step(
            audio_input=audio_buffer,
            num_frames_per_chunk=self.num_frames_per_chunk,
            state=context,
            request_id=request_id,
            has_prompt=has_prompt,
            return_debug=self.collect_debug,
            sampling_params=sampling_params or None,
        )

        if self.collect_debug and result.debug is not None:
            state = self.get_or_create_state(stream_ids[0])
            if not hasattr(state, "debug_steps"):
                state.debug_steps = []
            state.debug_steps.append(result.debug)

        # Persist updated cache & clean finished streams
        self.context_manager.update_context(stream_ids, result, self.num_frames_per_chunk)

        # Save token tensors and timing from the decode context before it is
        # destroyed by reset_slots.
        timing_by_stream: dict[int, object] = {}
        for stream_id, eos_flag in zip(stream_ids, eos_flags):
            if eos_flag:
                ctx = self.context_manager.slot_contexts[
                    self.context_manager.streamidx2slotidx[stream_id]
                ]
                if ctx is not None:
                    state = self.get_or_create_state(stream_id)
                    state.save_token_tensors(ctx.gen_text, ctx.gen_asr_text, ctx.frame_idx,
                                             gen_function_text=ctx.gen_function_text)
                    timing_by_stream[stream_id] = ctx.timing

        self.context_manager.reset_slots(stream_ids, eos_flags)

        # Log summary and clean up finished streams
        for stream_id, eos_flag in zip(stream_ids, eos_flags):
            if eos_flag:
                state = self.get_state(stream_id)
                audio_sec = state.audio_buffer.shape[-1] / self.output_sample_rate if self.output_sample_rate > 0 else 0
                logging.info(
                    f"Stream {stream_id} finished: {state.final_total_frames} frames, "
                    f"{audio_sec:.1f}s audio, "
                    f"agent: {state.output_text_str!r}, user: {state.output_asr_text_str!r}"
                )

                # Compact pad-visible summary (pad tokens rendered as '·')
                token_data = state.get_token_tensors()
                if token_data is not None:
                    gen_text, gen_asr_text, total_frames, _ = token_data
                    tokenizer = self.s2s_model.tokenizer
                    pad_id = self.s2s_model.model.stt_model.text_pad_id
                    lengths = torch.tensor([total_frames], dtype=torch.long)
                    raw_agent = tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0]
                    raw_user = tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0]
                    # NOTE(review): the pad-token literal below appears as an
                    # empty string in this copy of the file; str.replace('', '·')
                    # inserts '·' between every character, which is almost
                    # certainly not intended — confirm the original pad-token
                    # string (likely a special unicode glyph lost in transit).
                    compact_agent = raw_agent.replace('', '·')
                    compact_user = raw_user.replace('', '·')
                    logging.info(f"Stream {stream_id} agent (with padding): {compact_agent}")
                    logging.info(f"Stream {stream_id} user (with padding): {compact_user}")

                # Timing summary (no-op when profile_timing is off)
                timing_by_stream.get(stream_id, NullTimingSummary()).log_summary(
                    label=f"Stream {stream_id}",
                    chunk_ms=self.chunk_size_in_secs * 1000,
                )

                self.bufferer.rm_bufferer(stream_id)
                self._abort_stream_request(stream_id)
                # Note: We keep the state in _state_pool until finalization to save audio.
                # It will be cleaned up in close_session()

        # Split the batch-level InferenceStepResult into per-frame outputs.
        # Two things happen for each frame:
        #   1. The incremental audio/text is appended to S2SStreamingState
        #      (the pipeline-level accumulator that persists across steps —
        #      used by run() to build the final PipelineOutput).
        #   2. A GenerateStepOutput is built with the same incremental data
        #      and returned to the caller (used by server integrations to
        #      stream partial results to clients without diffing state).
        outputs: list[GenerateStepOutput] = []
        for idx, frame in enumerate(frames):
            state = self.get_state(frame.stream_id)
            audio = result.decoded_audio[idx:idx + 1] if result.decoded_audio is not None else None
            text = result.predicted_text_strs[idx] if result.predicted_text_strs else ""
            asr_text = result.asr_predicted_text_strs[idx] if result.asr_predicted_text_strs else ""

            state.append_step_output(audio, text=text, asr_text=asr_text)

            outputs.append(GenerateStepOutput(
                stream_id=frame.stream_id,
                audio=audio if audio is not None else torch.empty(1, 0),
                text=text,
                asr_text=asr_text,
            ))
        return outputs

    _WARMUP_FALLBACK_PROMPT = "Mock system prompt for warmup."

    def warmup(self, system_prompt: str | None = None) -> None:
        """Run a throwaway inference cycle to warm up the entire pipeline.

        The very first call through each stage incurs one-time overhead
        (e.g. CUDA graph compilation, memory pool allocation,
        DynamicCache initialization, torch.compile). Sending a silence
        frame with ``is_first=True`` exercises the full path — prefill,
        perception, LLM decode, TTS, and codec — so the first real
        client request is fast.

        Args:
            system_prompt: Prompt text to use for warmup. Falls back to
                the YAML-configured ``self.system_prompt``, then to a
                short fallback string so the LLM prefill path is always
                exercised.
        """
        prompt = system_prompt if system_prompt is not None else self.system_prompt
        if not prompt:
            prompt = self._WARMUP_FALLBACK_PROMPT
            logging.info(f"No system prompt configured — using fallback prompt for warmup: \"{prompt}\"")

        warmup_stream_id = -1
        chunk_samples = int(self.chunk_size_in_secs * self.input_sample_rate)

        logging.info("Running pipeline warmup (prefill + one silence chunk)...")
        t0 = time.time()

        warmup_frame = Frame(
            samples=torch.zeros(chunk_samples),
            stream_id=warmup_stream_id,
            is_first=True,
            is_last=True,
            options=S2SRequestOptions(system_prompt=prompt),
        )
        self.generate_step([warmup_frame])

        # Tear down everything so the engine is clean for real traffic
        self.reset_session()
        self._stream_has_prompt = False

        logging.info(f"Pipeline warmup complete in {time.time() - t0:.3f}s")

    def generate_step(self, frames: list[Frame]) -> list[GenerateStepOutput]:
        """Main streaming API — handles both init and audio processing.

        Mirrors ASR's ``transcribe_step``: on ``is_first`` frames, the
        stream is initialized via :meth:`_init_state` (state creation,
        context-manager allocation, KV-cache prefill). If the frame also
        carries input audio, it is processed in the same call. If there
        is no input audio (e.g. a server prefill-only request), the
        method returns after init without running inference.

        Returns one :class:`GenerateStepOutput` per input frame carrying
        the **incremental** output audio and text produced by this step.
        The output audio tensor may be empty when no waveform is produced
        (prefill-only frames with no input audio, or when
        ``decode_audio=False``).

        For latency-sensitive deployments, send the ``is_first`` frame
        with **empty audio** so that the expensive prefill completes
        before the user starts speaking. For batch/offline usage the
        first frame can carry real audio — init and first-chunk
        processing simply happen back-to-back in one call.
        """
        # Init phase — like ASR's `if request.is_first: self.init_state(...)`
        # Known limitation: _init_state runs prefill synchronously, so with
        # batch_size > 1 a long system prompt on one stream will block
        # decoding on all other streams in the same batch.
        for frame in frames:
            if frame.is_first:
                self._init_state(frame.stream_id, frame.options)

        # Decode phase.
        # Although generate_step_for_frames currently enforces batch_size==1,
        # this method already handles mixed batches (a mix of prefill-only
        # and audio-carrying frames) so it is ready for batch_size > 1.
        # We run inference only on the non-empty subset, then stitch the
        # outputs back into a list aligned 1:1 with the original *frames*.
        non_empty_frames = [f for f in frames if f.samples.numel() > 0]
        if not non_empty_frames:
            # All frames are prefill-only — nothing to decode.
            return [
                GenerateStepOutput(stream_id=f.stream_id, audio=torch.empty(1, 0))
                for f in frames
            ]

        buffers, left_paddings = self.bufferer.update(non_empty_frames)
        # This is a workaround for the fact that the audio buffer does left
        # padding, but the rest of the code requires no padding at all.
        buffers = [b[lp:] for b, lp in zip(buffers, left_paddings)]
        with torch.no_grad(), torch.inference_mode():
            step_outputs = self.generate_step_for_frames(non_empty_frames, buffers)

        # Fast path: every frame had audio, so step_outputs is already 1:1.
        if len(non_empty_frames) == len(frames):
            return step_outputs

        # Slow path (batch_size > 1 with a mix of prefill and audio frames):
        # fill in empty outputs for the prefill-only streams.
        output_by_stream: dict[int, GenerateStepOutput] = {o.stream_id: o for o in step_outputs}
        return [
            output_by_stream.get(
                f.stream_id,
                GenerateStepOutput(stream_id=f.stream_id, audio=torch.empty(1, 0)),
            )
            for f in frames
        ]

    # ------------------------------------------------------------------
    # Finalization helpers
    # ------------------------------------------------------------------
    def _finalize_and_save_finished_streams(
        self,
        frames: list[Frame],
        audio_filepaths: list[str],
        saved_paths_by_stream: dict[int, str],
    ) -> None:
        """Finalize any streams that ended in this batch and save their outputs.

        Writes per-stream ``.wav`` (mono output), stereo input+output ``.wav``,
        and ``.txt`` / ``_asr.txt`` transcripts under ``self.output_dir``.
        Assumes ``stream_id`` indexes into *audio_filepaths* (true for streams
        created by :meth:`run`).
        """
        for frame in frames:
            if frame.is_last:
                stream_id = frame.stream_id
                state = self.get_or_create_state(stream_id)

                if hasattr(state, "finalize"):
                    state.finalize()

                in_path = audio_filepaths[stream_id]
                base = os.path.splitext(os.path.basename(in_path))[0]
                txt_dir = os.path.join(self.output_dir, "txt")
                os.makedirs(txt_dir, exist_ok=True)

                out_path = None
                if self.decode_audio:
                    # Squeeze (B=1,C=1) to 1D mono waveform for soundfile
                    generated_audio = state.audio_buffer
                    if generated_audio.dim() == 3 and generated_audio.size(0) == 1 and generated_audio.size(1) == 1:
                        generated_audio = generated_audio.squeeze(0).squeeze(0)
                    elif generated_audio.dim() == 2 and generated_audio.size(0) == 1:
                        generated_audio = generated_audio.squeeze(0)
                    generated_audio = generated_audio.to(torch.float32)

                    wav_dir = os.path.join(self.output_dir, "wav")
                    stereo_dir = os.path.join(self.output_dir, "stereo")
                    os.makedirs(wav_dir, exist_ok=True)
                    os.makedirs(stereo_dir, exist_ok=True)

                    out_path = os.path.join(wav_dir, f"{base}.wav")
                    if generated_audio.numel() > 0:
                        sf.write(out_path, generated_audio.detach().cpu().numpy(), self.output_sample_rate)

                    # Save a stereo file with input (ch0) and output (ch1)
                    input_np, _ = librosa.load(in_path, sr=self.output_sample_rate, mono=True)
                    input_audio = torch.from_numpy(input_np).to(torch.float32)
                    gen_cpu = generated_audio.detach().cpu().to(input_audio.dtype)

                    # Prepend silence to the output channel to account for the
                    # one-chunk processing delay: the server can't produce output
                    # until it has received a full input chunk.
                    delay_samples = int(self.chunk_size_in_secs * self.output_sample_rate)
                    silence = torch.zeros(delay_samples, dtype=gen_cpu.dtype)
                    gen_cpu = torch.cat([silence, gen_cpu], dim=-1)

                    # Pad the shorter channel so both have equal length
                    gen_len = int(gen_cpu.shape[-1])
                    in_len = int(input_audio.shape[-1])
                    max_len = max(gen_len, in_len)
                    if in_len < max_len:
                        input_audio = torch.cat([input_audio, torch.zeros(max_len - in_len, dtype=input_audio.dtype)], dim=-1)
                    if gen_len < max_len:
                        gen_cpu = torch.cat([gen_cpu, torch.zeros(max_len - gen_len, dtype=gen_cpu.dtype)], dim=-1)
                    stereo = torch.stack([input_audio, gen_cpu], dim=0).transpose(0, 1)
                    stereo_path = os.path.join(stereo_dir, f"{base}_input_output.wav")
                    sf.write(stereo_path, stereo.detach().cpu().numpy(), self.output_sample_rate)

                text_out = state.output_text_str
                if isinstance(text_out, str):
                    try:
                        with open(os.path.join(txt_dir, f"{base}.txt"), "w", encoding="utf-8") as f:
                            f.write(text_out)
                    except OSError:
                        logging.warning(f"Failed to write text output for {base}")

                asr_text_out = state.output_asr_text_str
                if isinstance(asr_text_out, str) and asr_text_out:
                    try:
                        with open(os.path.join(txt_dir, f"{base}_asr.txt"), "w", encoding="utf-8") as f:
                            f.write(asr_text_out)
                    except OSError:
                        logging.warning(f"Failed to write ASR text output for {base}")

                saved_paths_by_stream[stream_id] = out_path
                # Keep state in _state_pool until _build_pipeline_output;
                # it will be cleared on close_session().

    # ------------------------------------------------------------------
    # Session helpers (extend S2SPipelineInterface)
    # ------------------------------------------------------------------
    def reset_session(self) -> None:
        """Reset feature buffer and ContextManager together."""
        for stream_id in list(self.context_manager.streamidx2slotidx.keys()):
            self._abort_stream_request(stream_id)
        self.bufferer.reset()
        self.context_manager.reset()

        super().reset_session()  # clears state pool

    # ------------------------------------------------------------------
    # Orchestrator – mirrors recognizers' *run* method
    # ------------------------------------------------------------------
    def run(
        self,
        audio_filepaths: list[str],
        options: list[S2SRequestOptions] | None = None,
        progress_bar: StepProgressBar | None = None,
    ) -> PipelineOutput:
        """Stream all *audio_filepaths* through the pipeline and save outputs.

        Saves one generated ``.wav`` per input under ``self.output_dir`` and
        returns their paths in ``PipelineOutput.texts``.

        Args:
            audio_filepaths: Paths to input audio files.
            options: Per-stream request options (system prompt, sampling, etc.).
            progress_bar: Optional :class:`StepProgressBar` for per-step
                progress with per-stream postfix.
        """
        if options is None:
            options = [S2SRequestOptions() for _ in audio_filepaths]

        streamer = SilencePaddedContinuousBatchedFrameStreamer(
            n_frames_per_stream=1,
            frame_size_in_secs=self.chunk_size_in_secs,
            sample_rate=self.input_sample_rate,
            batch_size=self.batch_size,
            pad_last_frame=True,
            pad_to_sec=self.pad_audio_to_sec,
            pad_by_sec=self.pad_audio_by_sec,
            pad_ratio=self.pad_silence_ratio,
        )
        streamer.set_audio_filepaths(audio_filepaths, options)

        os.makedirs(self.output_dir, exist_ok=True)
        saved_paths_by_stream: dict[int, str] = {}

        self.open_session()
        for frames in streamer:
            # generate_step returns per-step GenerateStepOutput objects
            # (useful for server integrations that stream partial results
            # to clients). Here we rely on the accumulated state instead,
            # which _finalize_and_save_finished_streams reads on is_last.
            self.generate_step(frames)
            self._finalize_and_save_finished_streams(frames, audio_filepaths, saved_paths_by_stream)

            if progress_bar is not None:
                for f in frames:
                    progress_bar.step(f.stream_id)

        if progress_bar is not None:
            progress_bar.finish()

        output = self._build_pipeline_output(audio_filepaths, saved_paths_by_stream)
        self.close_session()
        return output

    def _build_pipeline_output(
        self,
        audio_filepaths: list[str],
        saved_paths_by_stream: dict[int, str],
    ) -> PipelineOutput:
        """Assemble final ``PipelineOutput`` from accumulated per-stream state.

        Must be called *before* ``close_session()`` since it reads from the
        state pool.
        """
        texts = []
        words = []
        asr_texts = []
        texts_with_timestamps = []
        asr_texts_with_timestamps = []
        raw_texts = []
        raw_asr_texts = []
        token_texts = []
        token_asr_texts = []
        token_function_texts = []
        token_lengths = []
        audio_paths = []

        tokenizer = self.s2s_model.tokenizer
        pad_id = self.s2s_model.model.stt_model.text_pad_id

        for idx in range(len(audio_filepaths)):
            state = self.get_or_create_state(idx)
            text_value = state.output_text_str or saved_paths_by_stream.get(idx, "")
            texts.append(text_value)
            audio_paths.append(saved_paths_by_stream.get(idx))
            words.append(list(state.output_words))
            asr_texts.append(state.output_asr_text_str)

            token_data = state.get_token_tensors()
            if token_data is not None:
                gen_text, gen_asr_text, total_frames, gen_function_text = token_data
                token_texts.append(gen_text)
                token_asr_texts.append(gen_asr_text)
                token_function_texts.append(gen_function_text)
                token_lengths.append(total_frames)
                lengths = torch.tensor([total_frames], dtype=torch.long)
                texts_with_timestamps.append(
                    tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0]
                )
                asr_texts_with_timestamps.append(
                    tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=True)[0]
                )
                raw_texts.append(
                    tokens_to_str(gen_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0]
                )
                raw_asr_texts.append(
                    tokens_to_str(gen_asr_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0]
                )
                if gen_function_text is not None:
                    fc_text = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, eval_text_turn_taking=False)[0]
                    fc_text_raw = tokens_to_str(gen_function_text, lengths, tokenizer=tokenizer, pad_id=pad_id, keep_pad=True)[0]
                    logging.info(f"Function calling channel: {fc_text}, fc_text_raw: {fc_text_raw}")
            else:
                token_texts.append(None)
                token_asr_texts.append(None)
                token_function_texts.append(None)
                token_lengths.append(None)
                texts_with_timestamps.append("")
                asr_texts_with_timestamps.append("")
                raw_texts.append("")
                raw_asr_texts.append("")

        debug_data = []
        if self.collect_debug:
            for idx in range(len(audio_filepaths)):
                state = self.get_or_create_state(idx)
                debug_data.append(getattr(state, "debug_steps", []))

        return PipelineOutput(
            texts=texts,
            words=words,
            asr_texts=asr_texts,
            texts_with_timestamps=texts_with_timestamps,
            asr_texts_with_timestamps=asr_texts_with_timestamps,
            raw_texts=raw_texts,
            raw_asr_texts=raw_asr_texts,
            token_texts=token_texts,
            token_asr_texts=token_asr_texts,
            token_function_texts=token_function_texts,
            token_lengths=token_lengths,
            audio_filepaths=audio_paths,
            debug_data=debug_data if debug_data else None,
        )

    def _prefill_system_prompt(self, stream_id: int, system_prompt: str | None = None) -> torch.Tensor | None:
        """Prefill the system prompt for a new stream.

        This prepares the system prompt embeddings and processes them through
        the LLM to update the KV cache before audio streaming begins.
        Also prefills the TTS model with speaker embeddings when using vLLM EarTTS.

        Args:
            stream_id: The stream identifier.
            system_prompt: The system prompt text for this stream. If *None*,
                TTS prefill still runs (for vLLM EarTTS) but no LLM prompt
                is injected.

        Note on TTS prefill codes:
            The TTS prefill generates output codes, but these should NOT be used
            to initialize context.tts_code for inference. The batch approach uses
            first_tts_code_input (INPUT codes from speaker reference) instead.
            Using prefill OUTPUT codes causes audio quality issues (mumbling).

        Returns:
            torch.Tensor | None: The TTS prefill output codes if vLLM EarTTS prefill
            happened, None otherwise. These are returned for logging/debugging but
            should NOT be used to update context.tts_code.
        """
        request_id = self._request_id_for_stream(stream_id)
        engine_type = getattr(self.s2s_model, "engine_type", "native")
        tts_output_code = None

        # Prefill TTS with speaker embedding via model_eartts_interface
        use_vllm_eartts = "vllm_eartts" in engine_type.lower()
        if use_vllm_eartts:
            eartts = self.s2s_model.model_eartts_interface
            tts_init_inputs = getattr(self.s2s_model, "tts_init_inputs", None)
            tts_prompt_token_ids = getattr(self.s2s_model, "tts_prompt_token_ids", None)
            if tts_init_inputs is not None and tts_prompt_token_ids is not None:
                logging.info(f"Prefilling TTS speaker embedding for stream {stream_id}...")
                start_tts_prefill = time.time()
                with torch.no_grad():
                    tts_result = eartts.prefill_prompt(
                        tts_init_inputs,
                        prompt_token_ids=tts_prompt_token_ids,
                        request_id=request_id,
                    )
                # Capture the generated codes to sync context with vLLM state
                if hasattr(tts_result, 'codes') and tts_result.codes is not None:
                    tts_output_code = tts_result.codes.detach().clone()
                    logging.debug(f"TTS prefill generated codes shape: {tts_output_code.shape}")
                logging.info(f"TTS speaker embedding prefilled in {time.time() - start_tts_prefill:.3f}s")
            else:
                logging.warning("TTS init inputs not available, skipping TTS prefill")

        if not system_prompt:
            return tts_output_code

        # Prefill LLM with system prompt via model_llm_interface
        logging.info(f"Prefilling system prompt for stream {stream_id}...")
        start_get_prompt_embeddings = time.time()
        prompt_embedded, prompt_len = self.s2s_model._prepare_system_prompt_embeddings(system_prompt)
        logging.debug(f"Time taken to get prompt embeddings: {time.time() - start_get_prompt_embeddings:.3f}s")

        if prompt_embedded is None:
            logging.warning("System prompt embedding returned None, skipping prefill")
            return tts_output_code

        use_vllm_llm = "vllm_llm" in engine_type.lower()
        llm = self.s2s_model.model_llm_interface

        if use_vllm_llm:
            logging.info(f"Prefilling {prompt_len} prompt embeddings for vLLM LLM...")
            start_prefill = time.time()
            with torch.no_grad():
                llm.prefill_prompt(prompt_embedded, request_id=request_id)
            logging.info(f"System prompt prefilled ({prompt_len} tokens) in {time.time() - start_prefill:.3f}s")
        else:
            context, _ = self.context_manager.get_context([stream_id])
            if context.llm_cache is not None:
                # Native cache mode: process prompt through LLM to warm up KV cache
                with torch.no_grad():
                    cache_pos = torch.arange(prompt_len, device=self.s2s_model.device)
                    ans = llm.prefill_prompt(
                        prompt_embedded,
                        cache=context.llm_cache,
                        cache_position=cache_pos,
                    )
                context.llm_cache = ans.get("cache", context.llm_cache)
                context.llm_cache_position_offset = prompt_len
                logging.info(f"System prompt processed, cache updated ({prompt_len} tokens, offset={prompt_len})")
            else:
                for t in range(prompt_len):
                    context.input_embeds_history.append(prompt_embedded[:, t:t + 1, :])
                logging.info(f"Added {prompt_len} prompt embeddings to input_embeds_history")

        return tts_output_code

    def _request_id_for_stream(self, stream_id: int) -> str:
        """Map a stream id to the engine-level request id (stringified)."""
        return str(stream_id)

    def _abort_stream_request(self, stream_id: int) -> None:
        """Best-effort abort of the engine request backing *stream_id* (noop if unsupported)."""
        request_id = self._request_id_for_stream(stream_id)
        abort_fn = getattr(self.s2s_model, "abort_request", None)
        if callable(abort_fn):
            try:
                abort_fn(request_id)
            except Exception as exc:
                logging.warning(f"Failed to abort request {request_id} for stream {stream_id}: {exc}")
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/streaming/framing/__init__.py b/nemo/collections/speechlm2/inference/streaming/framing/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/framing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py new file mode 100644 index 000000000000..999985f02610 --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/framing/s2s_request_options.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class S2SRequestOptions: + """Immutable per-stream options for S2S inference. + + Attached to the first ``Frame`` of each stream via the ``options`` + field so that the pipeline can read per-stream configuration at the + start of every new audio stream. Frozen so that options cannot be + accidentally modified after the stream is initialised. + + All fields default to ``None``, which means "use the pipeline-level + default". Call :meth:`augment_with_defaults` to fill ``None`` fields + with pipeline-level values — this mirrors the + ``ASRRequestOptions.augment_with_defaults()`` pattern used by the + ASR streaming pipelines. 
+ """ + + system_prompt: str | None = None + + top_p: float | None = None # (0, 1] + temperature: float | None = None # >= 0 + repetition_penalty: float | None = None # > 0 + + def __post_init__(self) -> None: + if self.top_p is not None and not (0.0 < self.top_p <= 1.0): + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}") + if self.temperature is not None and self.temperature < 0.0: + raise ValueError(f"temperature must be >= 0, got {self.temperature}") + if self.repetition_penalty is not None and self.repetition_penalty <= 0.0: + raise ValueError(f"repetition_penalty must be > 0, got {self.repetition_penalty}") + + @staticmethod + def _with_default(value: Any, default: Any) -> Any: + """Return *value* when it is not ``None``, otherwise *default*.""" + return default if value is None else value + + def augment_with_defaults( + self, + default_system_prompt: str | None = None, + default_top_p: float | None = None, + default_temperature: float | None = None, + default_repetition_penalty: float | None = None, + ) -> S2SRequestOptions: + """Return a new options instance with ``None`` fields filled from defaults. + + This is the S2S equivalent of ``ASRRequestOptions.augment_with_defaults``. + """ + return S2SRequestOptions( + system_prompt=self._with_default(self.system_prompt, default_system_prompt), + top_p=self._with_default(self.top_p, default_top_p), + temperature=self._with_default(self.temperature, default_temperature), + repetition_penalty=self._with_default(self.repetition_penalty, default_repetition_penalty), + ) diff --git a/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_frame_streamer.py b/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_frame_streamer.py new file mode 100644 index 000000000000..edca88bfd31f --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/framing/silence_padded_frame_streamer.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. 
class SilencePaddedContinuousBatchedFrameStreamer(ContinuousBatchedFrameStreamer):
    """``ContinuousBatchedFrameStreamer`` variant that can append trailing
    silence to every audio file.

    When any of the ``pad_*`` options is set, each newly opened
    ``MonoStream`` is wrapped in a :class:`SilencePaddedStream`, which
    transparently yields extra silence frames once the real audio ends.
    With no padding configured the class behaves exactly like its base.
    """

    def __init__(
        self,
        *,
        pad_to_sec: float | None = None,
        pad_by_sec: float | None = None,
        pad_ratio: float | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.pad_to_sec = pad_to_sec
        self.pad_by_sec = pad_by_sec
        self.pad_ratio = pad_ratio

    @property
    def _needs_padding(self) -> bool:
        # Padding is active as soon as any one of the three knobs is set.
        return (
            self.pad_to_sec is not None
            or self.pad_by_sec is not None
            or self.pad_ratio is not None
        )

    def add_stream(self) -> None:
        """Open the next audio file (if any) and register it with the multi-streamer."""
        sid = self.stream_id
        if sid >= self.n_audio_files:
            # Every audio file has already been scheduled.
            return

        base_stream = MonoStream(
            self.sample_rate,
            self.frame_size_in_secs,
            stream_id=sid,
            pad_last_frame=self.pad_last_frame,
        )

        if not self._needs_padding:
            stream = base_stream
        else:
            stream = SilencePaddedStream(
                base_stream,
                chunk_size_in_secs=self.frame_size_in_secs,
                pad_to_sec=self.pad_to_sec,
                pad_by_sec=self.pad_by_sec,
                pad_ratio=self.pad_ratio,
            )

        filepath = self.audio_filepaths[sid]
        self.sid2filepath[sid] = filepath
        self.elapsed_durations[sid] = 0.0
        stream.load_audio(filepath, self.options[sid])

        self.multi_streamer.add_stream(stream, stream_id=sid)
        self.stream_id = sid + 1
        self.update_progress_bar()
class SilencePaddedStream(Stream):
    """Wraps a ``MonoStream`` and appends silence frames after the real audio
    to reach a target duration.

    The pipeline's ``run()`` loop sees a single, longer stream — no frame
    mutation or side-channel silence injection is needed. ``MultiStream``
    keeps the stream alive until the final silence frame sets ``is_last=True``.

    Exactly one of ``pad_to_sec`` / ``pad_by_sec`` / ``pad_ratio`` is expected;
    when several are set, ``_padding_secs`` applies them in priority order
    ``pad_to_sec`` > ``pad_ratio`` > ``pad_by_sec``.
    """

    def __init__(
        self,
        inner: MonoStream,
        chunk_size_in_secs: float,
        pad_to_sec: float | None = None,
        pad_by_sec: float | None = None,
        pad_ratio: float | None = None,
    ):
        # Reuse the inner stream's id so the wrapper is indistinguishable
        # from the wrapped stream to downstream consumers.
        super().__init__(inner.stream_id)
        self.inner = inner
        self.chunk_size_in_secs = chunk_size_in_secs
        self.pad_to_sec = pad_to_sec
        self.pad_by_sec = pad_by_sec
        self.pad_ratio = pad_ratio
        # True once the wrapped stream's last real frame has been emitted.
        self._inner_exhausted = False
        # Number of silence frames still owed after the real audio; set in load_audio().
        self._silence_frames_remaining = 0

    def load_audio(self, audio, options=None):
        """Load audio into the inner stream and compute the silence frame budget."""
        self.inner.load_audio(audio, options)
        audio_secs = self.inner.n_samples / self.inner.rate
        remaining = self._padding_secs(audio_secs)
        # At least one silence frame is emitted whenever any padding was requested,
        # even if the remaining duration rounds down to zero frames.
        self._silence_frames_remaining = (
            max(1, round(remaining / self.chunk_size_in_secs)) if remaining > 0 else 0
        )

    def _padding_secs(self, elapsed: float) -> float:
        """Seconds of silence to append after *elapsed* seconds of real audio.

        Priority order: pad_to_sec, then pad_ratio, then pad_by_sec.
        Returns 0.0 when no padding option is configured.
        """
        if self.pad_to_sec is not None:
            return max(0.0, self.pad_to_sec - elapsed)
        if self.pad_ratio is not None:
            return elapsed * self.pad_ratio
        if self.pad_by_sec is not None:
            return self.pad_by_sec
        return 0.0

    def __iter__(self):
        # Reset the inner iterator; note the silence budget is NOT reset here —
        # it is established by load_audio().
        self.inner.__iter__()
        self._inner_exhausted = False
        return self

    def __next__(self) -> list[Frame]:
        if not self._inner_exhausted:
            frames = next(self.inner)
            # NOTE(review): only frames[0] is inspected — assumes MonoStream
            # yields single-frame lists; confirm against MonoStream.__next__.
            frame = frames[0]
            if frame.is_last and self._silence_frames_remaining > 0:
                # Re-emit the final real frame with is_last cleared so the
                # consumer keeps the stream open for the silence frames.
                modified = Frame(
                    samples=frame.samples,
                    stream_id=frame.stream_id,
                    is_first=frame.is_first,
                    is_last=False,
                    length=frame.length,
                    options=frame.options,
                )
                self._inner_exhausted = True
                return [modified]
            return frames

        if self._silence_frames_remaining > 0:
            self._silence_frames_remaining -= 1
            # All-zero samples; the very last silence frame carries is_last=True.
            return [Frame(
                samples=torch.zeros(self.inner.frame_size),
                stream_id=self.stream_id,
                is_first=False,
                is_last=(self._silence_frames_remaining == 0),
                length=self.inner.frame_size,
            )]

        raise StopIteration
100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/state/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py new file mode 100644 index 000000000000..4bb0c6dc1e0b --- /dev/null +++ b/nemo/collections/speechlm2/inference/streaming/state/s2s_context_manager.py @@ -0,0 +1,152 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class S2SContextManager:
    """Maps stream ids to per-stream decode contexts held in a fixed pool of slots.

    Each active stream occupies one slot holding a ``StreamingDecodeState``
    created via ``s2s_model.create_decode_state(max_len)``. Slots are recycled
    through a ``Queue`` when streams end (``reset_slots``) and, as a last
    resort, force-reclaimed in ``_ensure_slot`` when the pool is exhausted.
    """

    def __init__(
        self,
        s2s_model,
        num_slots: int,
        max_len: int,
    ):
        # s2s_model must provide create_decode_state(max_len); device/dtype
        # are read defensively and default to CPU / float32 when absent.
        self.s2s_model = s2s_model
        self.num_slots = num_slots

        self.max_len = max_len
        self.device = getattr(self.s2s_model, "device", torch.device("cpu"))
        self.dtype = getattr(self.s2s_model, "dtype", torch.float32)

        self.reset()

    def reset(self) -> None:
        """Reset all bookkeeping for a new streaming session."""
        self.streamidx2slotidx: dict[int, int] = {}
        self.slotidx2streamidx: dict[int, int] = {}
        self.free_slots = Queue(self.num_slots)
        for i in range(self.num_slots):
            self.free_slots.put(i)
        self.slot_contexts: list[StreamingDecodeState | None] = [None] * self.num_slots

    def _create_context(self) -> StreamingDecodeState:
        """Allocate a fresh context backed by the realtime inference model."""
        if not hasattr(self.s2s_model, "create_decode_state"):
            raise RuntimeError("s2s_model must provide create_decode_state(max_len)")
        return self.s2s_model.create_decode_state(self.max_len)

    def _ensure_slot(self, stream_id: int) -> int:
        """Return the slot index for *stream_id*, allocating a slot if needed.

        NOTE(review): the emergency path below force-releases ALL slots,
        including those of streams that may still be active — their decode
        state is lost. Confirm this trade-off is acceptable for the server
        deployment scenario.
        """
        if stream_id not in self.streamidx2slotidx:
            if self.free_slots.empty():
                # Emergency cleanup: force-release all slots for a fresh start
                # This handles cases where previous streams didn't end properly
                # (e.g., exceptions, client disconnects, missing is_last=True)
                logging.warning(f"No free slots available - forcing cleanup of all {self.num_slots} slots")
                orphaned_streams = list(self.slotidx2streamidx.values())
                if orphaned_streams:
                    logging.warning(f"Orphaned streams being cleaned up: {orphaned_streams}")
                for slot_idx in range(self.num_slots):
                    self.reset_slot(slot_idx)
            slot_idx = self.free_slots.get()
            # Ensure the slot is completely clean before assigning to new stream
            if self.slot_contexts[slot_idx] is not None:
                logging.warning(f"Slot {slot_idx} was not properly cleaned. Forcing cleanup.")
                self.slot_contexts[slot_idx] = None
            self.streamidx2slotidx[stream_id] = slot_idx
            self.slotidx2streamidx[slot_idx] = stream_id
        return self.streamidx2slotidx[stream_id]

    def reset_slot(self, slot_idx: int) -> None:
        """Release a slot back to the pool."""
        # Silently ignore out-of-range indices so the emergency cleanup loop
        # in _ensure_slot can call this unconditionally.
        if slot_idx < 0 or slot_idx >= self.num_slots:
            return
        # Set to None to break reference and allow garbage collection
        self.slot_contexts[slot_idx] = None
        stream_id = self.slotidx2streamidx.get(slot_idx)
        if stream_id is not None:
            del self.slotidx2streamidx[slot_idx]
            del self.streamidx2slotidx[stream_id]
        self.free_slots.put(slot_idx)

    def update_context(
        self,
        stream_ids: list[int],
        step_result: InferenceStepResult,
        num_frames: int,
    ) -> None:
        """Advance frame counter and set subword mask after an inference step.

        All cache and tensor mutations (dynamic_cache, past_key_values, code,
        perception_cache, codec_cache, gen_text, gen_asr_text, etc.) are
        already applied in-place on the ``StreamingDecodeState`` by
        ``infer_one_step``. This method only bumps ``frame_idx`` and marks
        the subword mask for the newly generated frames.
        """
        if len(stream_ids) == 0:
            return
        if len(stream_ids) != 1:
            raise NotImplementedError("update_context currently supports batch_size == 1")

        stream_id = stream_ids[0]
        slot_idx = self.streamidx2slotidx.get(stream_id)
        if slot_idx is None:
            raise RuntimeError(f"Stream {stream_id} is not registered in the context manager")

        context = self.slot_contexts[slot_idx]
        if context is None:
            # Lazily create a context when the slot was registered but
            # get_context() was never called for this stream.
            context = self._create_context()
            self.slot_contexts[slot_idx] = context

        start_idx = context.frame_idx
        end_idx = start_idx + num_frames
        # gen_text's second dimension bounds the total decodable frames.
        if end_idx > context.gen_text.shape[1]:
            raise RuntimeError(
                "Context maximum length exceeded. Consider increasing `streaming.max_len` in the configuration."
            )

        context.frame_idx = end_idx

        if context.subword_mask is not None:
            context.subword_mask[:, start_idx:end_idx] = True

    def reset_slots(self, stream_ids: list[int], eos_flags: list[bool]) -> None:
        """Release contexts for streams that signalled end-of-stream."""
        if len(stream_ids) != len(eos_flags):
            raise ValueError("stream_ids and eos_flags must have the same length")
        for stream_id, eos_flag in zip(stream_ids, eos_flags):
            if eos_flag and stream_id in self.streamidx2slotidx:
                self.reset_slot(self.streamidx2slotidx[stream_id])

    def get_context(self, stream_ids: list[int]) -> tuple[StreamingDecodeState, dict[int, int]]:
        """Return the cached context associated with the provided stream ids.

        An empty ``stream_ids`` list yields a throwaway context that is not
        registered in any slot.
        """
        if len(stream_ids) == 0:
            return self._create_context(), {}
        if len(stream_ids) != 1:
            raise NotImplementedError("get_context currently supports batch_size == 1")

        stream_id = stream_ids[0]
        slot_idx = self._ensure_slot(stream_id)

        if self.slot_contexts[slot_idx] is None:
            self.slot_contexts[slot_idx] = self._create_context()

        return self.slot_contexts[slot_idx], {slot_idx: 0}
@dataclass
class S2SStreamingState:
    """Pipeline-level output accumulator for a single S2S stream.

    Collects generated audio samples, text strings, and word timings
    produced by each inference step. Kept alive by the pipeline's
    ``_state_pool`` until ``close_session()`` so the final
    :class:`PipelineOutput` can be assembled.

    This is *not* the model-level decode state (KV caches, token
    workspaces) -- that is :class:`StreamingDecodeState` in
    ``model_wrappers/decode_state.py``.
    """

    # Required init metadata
    device: torch.device
    dtype: torch.dtype
    output_sample_rate: int

    # Per-stream request options (system prompt, sampling overrides, etc.)
    options: S2SRequestOptions = field(default_factory=S2SRequestOptions)

    # Growing audio buffer — shape (1, T), appended each step
    audio_buffer: torch.Tensor = field(init=False)

    # Accumulated agent response text (built incrementally per step)
    output_text_str: str = ""
    # Accumulated ASR (user) text
    output_asr_text_str: str = ""
    # Word-level timings for the agent response
    output_words: list[Word] = field(default_factory=list)

    # Snapshots of full token-ID tensors, saved from StreamingDecodeState
    # before the decode context is destroyed at end-of-stream.
    # Used for post-hoc tokens_to_str conversion.
    final_gen_text: torch.Tensor | None = None
    final_gen_asr_text: torch.Tensor | None = None
    final_gen_function_text: torch.Tensor | None = None
    final_total_frames: int = 0

    def __post_init__(self) -> None:
        # Depends on self.device and self.dtype, so can't be a field default.
        self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype)

    def reset(self) -> None:
        """Reset all accumulated outputs to initial state."""
        self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype)
        self.output_text_str = ""
        self.output_asr_text_str = ""
        self.output_words.clear()
        self.final_gen_text = None
        self.final_gen_asr_text = None
        self.final_gen_function_text = None
        self.final_total_frames = 0

    def append_step_output(
        self,
        audio: torch.Tensor,
        text: str | None = None,
        asr_text: str | None = None,
    ) -> None:
        """Append generated audio and optional text from one inference step.

        Word timing is approximate: the whole *text* chunk is assigned the
        time span of the audio appended in this same step.
        """
        if audio is None:
            return
        if not isinstance(audio, torch.Tensor):
            raise TypeError("audio must be a torch.Tensor")

        # Normalize to shape (1, T) before concatenation.
        append_tensor = audio
        if append_tensor.dim() > 1:
            append_tensor = append_tensor.reshape(1, -1)
        elif append_tensor.dim() == 1:
            append_tensor = append_tensor.unsqueeze(0)
        # Record the pre/post sample counts; they define the word time span below.
        prior_samples = int(self.audio_buffer.shape[-1])
        appended_samples = int(append_tensor.shape[-1])
        self.audio_buffer = torch.cat(
            [self.audio_buffer, append_tensor.to(self.device, dtype=self.dtype)], dim=-1
        )

        if isinstance(text, str) and text:
            self.output_text_str += text
            if appended_samples > 0 and self.output_sample_rate > 0:
                start_t = float(prior_samples) / float(self.output_sample_rate)
                end_t = float(prior_samples + appended_samples) / float(self.output_sample_rate)
                # conf=1.0: generated text carries no confidence estimate.
                self.output_words.append(Word(text=text, start=start_t, end=end_t, conf=1.0))

        if isinstance(asr_text, str) and asr_text:
            self.output_asr_text_str += asr_text

    def save_token_tensors(
        self,
        gen_text: torch.Tensor,
        gen_asr_text: torch.Tensor,
        total_frames: int,
        gen_function_text: torch.Tensor | None = None,
    ) -> None:
        """Snapshot the full token-ID tensors from the decode context before it is destroyed."""
        # Clone + move to CPU so the snapshots outlive the (GPU-resident) decode state.
        self.final_gen_text = gen_text[:, :total_frames].clone().cpu()
        self.final_gen_asr_text = gen_asr_text[:, :total_frames].clone().cpu()
        self.final_total_frames = total_frames
        self.final_gen_function_text = (
            gen_function_text[:, :total_frames].clone().cpu()
            if gen_function_text is not None else None
        )

    def get_token_tensors(self) -> tuple | None:
        """Return (gen_text, gen_asr_text, total_frames, gen_function_text) or None if not saved."""
        if self.final_gen_text is None:
            return None
        return self.final_gen_text, self.final_gen_asr_text, self.final_total_frames, self.final_gen_function_text

    def clear_audio_buffer(self) -> None:
        """Clear the audio buffer (e.g. after sending audio to a client)."""
        self.audio_buffer = torch.empty((1, 0), device=self.device, dtype=self.dtype)
def _num_frames(filepath: str) -> int:
    """Return the frame count of an audio file, closing the handle promptly."""
    # sf.SoundFile leaks an open OS file handle unless explicitly closed;
    # the context manager guarantees closure even on error.
    with sf.SoundFile(filepath) as snd:
        return snd.frames


def _duration_in_secs(filepath: str) -> float:
    """Return the duration of an audio file in seconds, closing the handle promptly."""
    with sf.SoundFile(filepath) as snd:
        return snd.frames / snd.samplerate


def prepare_audio_data(
    audio_file: str,
    default_system_prompt: str | None = None,
    sort_by_duration: bool = True,
) -> tuple[list[str], list[S2SRequestOptions], list[str | None]]:
    """Load audio filepaths and per-stream options from a folder, single file, or manifest.

    When the input is a JSON manifest, each line may contain::

        {"audio_filepath": "clip.wav", "text": "...", "system_prompt": "..."}

    If ``system_prompt`` is absent on a line, *default_system_prompt* is used.

    Args:
        audio_file: Path to a directory of ``.wav`` files, a single ``.wav``
            file, or a ``.json``/``.jsonl`` manifest.
        default_system_prompt: Fallback system prompt for entries that do not
            carry their own.
        sort_by_duration: When True, sort all returned lists by audio length
            (shortest first).

    Returns:
        ``(filepaths, options, ground_truths)`` -- parallel lists of audio paths,
        per-stream request options, and ground-truth texts (``None`` when unavailable).

    Raises:
        ValueError: If *audio_file* is neither a folder, a ``.wav`` file, nor a manifest.
    """
    audio_file = audio_file.strip()
    if not os.path.isabs(audio_file):
        audio_file = os.path.abspath(audio_file)

    options: list[S2SRequestOptions] = []
    ground_truths: list[str | None] = []

    if os.path.isdir(audio_file):
        # sorted(): os.listdir order is filesystem-dependent; a deterministic
        # listing keeps output order reproducible across runs and machines.
        filepaths = [os.path.join(audio_file, x) for x in sorted(os.listdir(audio_file)) if x.endswith(".wav")]
        options = [S2SRequestOptions(system_prompt=default_system_prompt) for _ in filepaths]
        ground_truths = [None] * len(filepaths)
    elif audio_file.endswith(".wav"):
        filepaths = [audio_file]
        options = [S2SRequestOptions(system_prompt=default_system_prompt)]
        ground_truths = [None]
    elif audio_file.endswith((".json", ".jsonl")):
        samples = []
        # Explicit UTF-8: manifests may contain non-ASCII text fields.
        with open(audio_file, 'r', encoding='utf-8') as f:
            # Iterate the file directly instead of materializing readlines().
            for line in f:
                if line.strip():
                    samples.append(json.loads(line))
        filepaths = [get_full_path(entry["audio_filepath"], audio_file) for entry in samples]
        options = [
            S2SRequestOptions(
                system_prompt=entry.get("system_prompt", default_system_prompt),
            )
            for entry in samples
        ]
        ground_truths = [entry.get("text", None) for entry in samples]
    else:
        raise ValueError(f"audio_file `{audio_file}` needs to be a folder, audio file, or manifest file")

    if sort_by_duration:
        # _num_frames closes each file handle; the previous inline
        # sf.SoundFile(fp).frames leaked one handle per file.
        durations = [_num_frames(fp) for fp in filepaths]
        order = sorted(range(len(filepaths)), key=lambda i: durations[i])
        filepaths = [filepaths[i] for i in order]
        options = [options[i] for i in order]
        ground_truths = [ground_truths[i] for i in order]

    return filepaths, options, ground_truths


def calculate_durations_incl_padding(
    audio_filepaths: list[str],
    pad_audio_to_sec: float | None = None,
    pad_silence_ratio: float | None = None,
    pad_audio_by_sec: float | None = None,
) -> list[float]:
    """Return per-file durations in seconds, accounting for silence padding.

    At most one padding argument may be set; when none are set this
    returns the raw audio durations.

    Raises:
        ValueError: If more than one padding argument is provided.
    """
    if sum(x is not None for x in [pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec]) > 1:
        raise ValueError("Set at most one of: pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec")
    durations = []
    for fp in audio_filepaths:
        # _duration_in_secs closes the file handle (previously leaked).
        dur = _duration_in_secs(fp)
        if pad_audio_to_sec is not None:
            dur = max(dur, pad_audio_to_sec)
        elif pad_silence_ratio is not None:
            dur *= (1 + pad_silence_ratio)
        elif pad_audio_by_sec is not None:
            dur += pad_audio_by_sec
        durations.append(dur)
    return durations


def dump_output(
    audio_filepaths: list[str],
    output: PipelineOutput,
    output_dir: str,
    options: list[S2SRequestOptions],
    ground_truths: list[str | None],
) -> None:
    """Dump inference results to output_processed.json, output_raw.json, and per-file CTM.

    ``output_processed.json`` uses the canonical S2S processed-output schema
    (timestamps in pred_text via ``<|t|>`` / ``<$t$>``).

    ``output_raw.json`` preserves all tokens, including pad tokens.

    Args:
        audio_filepaths: Input audio paths, parallel to the ``output`` lists.
        output: The assembled pipeline output.
        output_dir: Directory to write the JSON manifests and ``ctm/`` folder into.
        options: Per-stream request options (``system_prompt`` is recorded per line).
        ground_truths: Ground-truth texts, ``None`` where unavailable.
    """
    output_processed_path = os.path.join(output_dir, "output_processed.json")
    output_raw_path = os.path.join(output_dir, "output_raw.json")
    output_ctm_dir = os.path.join(output_dir, "ctm")

    os.makedirs(output_ctm_dir, exist_ok=True)

    # Substitute empty parallel lists when a field was not populated.
    asr_texts_ts = output.asr_texts_with_timestamps or [None] * len(audio_filepaths)
    texts_ts = output.texts_with_timestamps or [""] * len(audio_filepaths)
    raw_texts = output.raw_texts or [""] * len(audio_filepaths)
    raw_asr_texts = output.raw_asr_texts or [""] * len(audio_filepaths)

    # Explicit UTF-8: records are written with ensure_ascii=False, which can
    # fail under a non-UTF-8 platform-default encoding.
    with open(output_processed_path, 'w', encoding='utf-8') as f_proc, \
            open(output_raw_path, 'w', encoding='utf-8') as f_raw:
        for audio_filepath, words, opts, gt, pred_text_ts, pred_src_text_ts, pred_text_raw, pred_src_text_raw in zip(
            audio_filepaths, output.words, options, ground_truths,
            texts_ts, asr_texts_ts, raw_texts, raw_asr_texts,
        ):
            stem = os.path.splitext(os.path.basename(audio_filepath))[0]
            ctm_filepath = os.path.abspath(os.path.join(output_ctm_dir, f"{stem}.ctm"))
            with open(ctm_filepath, 'w', encoding='utf-8') as ctm_fout:
                for word in words:
                    ctm_line = f"A {round(word.start, 2)} {round(word.duration, 2)} {word.text} {word.conf}"
                    ctm_fout.write(f"{stem} {ctm_line}\n")

            pred_audio_path = os.path.join(output_dir, "wav", f"{stem}.wav")

            record_processed = {
                "id": stem,
                "target_text": "",
                "pred_audio": pred_audio_path,
                "src_text": gt or "",
                "pred_src_text": pred_src_text_ts or "",
                "pred_text": pred_text_ts or "",
                "system_prompt": opts.system_prompt or "",
            }
            json.dump(record_processed, f_proc, ensure_ascii=False)
            f_proc.write('\n')
            # Flush per record so partial results survive a crash mid-run.
            f_proc.flush()

            record_raw = {
                "id": stem,
                "target_text": "",
                "pred_audio": pred_audio_path,
                "src_text": gt or "",
                "pred_src_text": pred_src_text_raw or "",
                "pred_text": pred_text_raw or "",
                "system_prompt": opts.system_prompt or "",
            }
            json.dump(record_raw, f_raw, ensure_ascii=False)
            f_raw.write('\n')
            f_raw.flush()
b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py new file mode 100644 index 000000000000..0425fe1cdd6b --- /dev/null +++ b/nemo/collections/speechlm2/inference/utils/pipeline_utils.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import torch +from whisper_normalizer.english import EnglishTextNormalizer + +from nemo.collections.asr.inference.utils.text_segment import Word + +_whisper_normalizer = EnglishTextNormalizer() + + +def clean_pred_text(text: str) -> str: + """Clean prediction text for fair WER comparison. + + First strips model-specific tokens (turn markers, timestamps, pad tokens) + that the Whisper normalizer doesn't know about, then applies + ``EnglishTextNormalizer`` — the same normalizer used by the offline eval + metrics in ``speechlm2.parts.metrics.wer``. + """ + if not text: + return "" + # Strip model-specific tokens + text = text.lstrip('^') + text = re.sub(r'', '', text) + text = re.sub(r'<\$[\d.]+\$>', '', text) + text = re.sub(r'<\|[\d.]+\|>', '', text) + text = re.sub(r'', '', text) + # Normalize with Whisper's EnglishTextNormalizer (same as offline eval) + return _whisper_normalizer(text) + + +class PipelineOutput: + """ + Class to store the output of the S2S pipeline. 
+ """ + + def __init__( + self, + texts: list[str] | None = None, + words: list[list[Word]] | None = None, + asr_texts: list[str] | None = None, + texts_with_timestamps: list[str] | None = None, + asr_texts_with_timestamps: list[str] | None = None, + raw_texts: list[str] | None = None, + raw_asr_texts: list[str] | None = None, + token_texts: list[torch.Tensor | None] | None = None, + token_asr_texts: list[torch.Tensor | None] | None = None, + token_function_texts: list[torch.Tensor | None] | None = None, + token_lengths: list[int | None] | None = None, + audio_filepaths: list[str | None] | None = None, + debug_data: list[list] | None = None, + ): + if texts is None and words is None: + raise ValueError("At least one of the 'texts' or 'words' should be provided.") + self.texts = texts + self.words = words + self.asr_texts = asr_texts + self.texts_with_timestamps = texts_with_timestamps + self.asr_texts_with_timestamps = asr_texts_with_timestamps + self.raw_texts = raw_texts + self.raw_asr_texts = raw_asr_texts + self.token_texts = token_texts + self.token_asr_texts = token_asr_texts + self.token_function_texts = token_function_texts + self.token_lengths = token_lengths + self.audio_filepaths = audio_filepaths + self.debug_data = debug_data diff --git a/nemo/collections/speechlm2/inference/utils/stepprogressbar.py b/nemo/collections/speechlm2/inference/utils/stepprogressbar.py new file mode 100644 index 000000000000..a21b25fc8039 --- /dev/null +++ b/nemo/collections/speechlm2/inference/utils/stepprogressbar.py @@ -0,0 +1,73 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class StepProgressBar:
    """Per-step tqdm progress bar shared across one or more audio streams.

    Every call to :meth:`step` advances the global bar by one and refreshes
    the postfix with the calling stream's own progress
    (e.g. ``stream 2: 45/127``).  Instances are normally created through
    :meth:`from_audio_filepaths`.
    """

    def __init__(self, total_steps: int, steps_per_stream: dict[int, int] | None = None):
        # A single global bar; per-stream totals drive only the postfix text.
        self._bar = tqdm(total=total_steps, desc="Inference", unit="step", dynamic_ncols=True)
        self._steps_per_stream = steps_per_stream or {}
        self._stream_progress: dict[int, int] = {}

    def step(self, stream_id: int) -> None:
        """Record one inference step for *stream_id* and advance the bar."""
        done = self._stream_progress.get(stream_id, 0) + 1
        self._stream_progress[stream_id] = done
        stream_total = self._steps_per_stream.get(stream_id)
        if stream_total is not None:
            # refresh=False: the update(1) below redraws the bar anyway.
            self._bar.set_postfix_str(
                f"stream {stream_id}: {done}/{stream_total}",
                refresh=False,
            )
        self._bar.update(1)

    def finish(self) -> None:
        """Close the underlying tqdm bar."""
        self._bar.close()

    @classmethod
    def from_audio_filepaths(
        cls,
        audio_filepaths: list[str],
        chunk_size_in_secs: float,
        pad_audio_to_sec: float | None = None,
        pad_silence_ratio: float | None = None,
        pad_audio_by_sec: float | None = None,
    ) -> "StepProgressBar":
        """Build a bar whose step totals come from the (padded) audio durations."""
        durations = calculate_durations_incl_padding(
            audio_filepaths, pad_audio_to_sec, pad_silence_ratio, pad_audio_by_sec,
        )
        # One stream per file; steps = ceil(duration / chunk size).
        per_stream = {
            idx: math.ceil(dur / chunk_size_in_secs) for idx, dur in enumerate(durations)
        }
        return cls(
            total_steps=sum(per_stream.values()),
            steps_per_stream=per_stream,
        )
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py b/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py new file mode 100644 index 000000000000..620f6bf20d66 --- /dev/null +++ b/nemo/collections/speechlm2/inference/vllm/scripts/convert_eartts_checkpoint.py @@ -0,0 +1,278 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
import json
import argparse

import torch
from omegaconf import OmegaConf, DictConfig
from safetensors.torch import save_file, load_file
from transformers import AutoConfig
from nemo.utils import logging

from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS


def parse_args():
    """CLI arguments: NeMo config.json, checkpoint path, and output directory."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--outdir", type=str, required=True)
    return parser.parse_args()


def convert(outdir, config, model_path):
    """Convert a NeMo DuplexEARTTS checkpoint into a vLLM-loadable layout.

    Writes into ``outdir``:
      - ``model.safetensors`` + ``model.safetensors.index.json`` with the
        weights regrouped into embedding / backbone / sampler modules,
      - ``config.json`` flattened for the vLLM ``EarTTSForCausalLM`` model,
      - ``speaker_latents/<name>.pt`` for any precomputed speaker latents.

    Args:
        outdir: Output directory (created if missing).
        config: Path to the NeMo checkpoint's config.json.
        model_path: Path to the checkpoint weights (safetensors).
    """
    os.makedirs(outdir, exist_ok=True)

    # Load the speech-generation sub-config from the full NeMo config.
    with open(config, "r") as f:
        full_config = json.load(f)
    config_dict = full_config["model"]["speech_generation"]
    cfg = DictConfig(config_dict)
    # Config modifications that are needed to run inference.
    # (Fix: the use_unshifthed_prompt flag was previously set twice.)
    cfg.model.tts_config.use_unshifthed_prompt = True
    cfg.data.add_audio_prompt_after_description = True
    cfg.model.subword_mask_exactly_as_eartts = False
    cfg.model.context_hidden_mask_exactly_as_eartts = False
    cfg.model.tts_config.disable_eos_prediction = True
    cfg.model.inference_force_speech_silence_on_eos = True
    cfg.model.use_word_sep_tokenizer = False
    cfg.model.num_delay_speech_tokens = 0
    cfg.data.source_sample_rate = 22050
    cfg.data.target_sample_rate = 22050
    cfg.model.pretrained_model = None

    # Resolve tokenizer name from the STT config (same tokenizer is shared by TTS).
    _pretrained_tokenizer_name = (
        full_config.get("model", {}).get("stt", {}).get("model", {}).get("pretrained_llm", None)
    )
    if _pretrained_tokenizer_name is None:
        raise ValueError(
            "Cannot determine tokenizer: 'pretrained_llm' not found in "
            "config.json -> model -> stt -> model. Check the checkpoint."
        )

    model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True)).eval()
    # Get subword encoder vocabs and config.
    subword_id_to_char_ids = model.tts_model.embed_subword.subword_id_to_char_ids
    char_vocab = model.tts_model.embed_subword.char_vocab
    # Sizes for the embedding layers that convert subword ids to char ids.
    vocab_size = len(subword_id_to_char_ids)
    max_char_len = max(len(char_ids) for char_ids in subword_id_to_char_ids.values())
    hidden_size = cfg.model.tts_config.backbone_config.hidden_size

    # Load checkpoint weights (safetensors only).
    weights = load_file(model_path)
    # Select TTS-model weights, strip off one nested prefix layer.
    # NOTE(review): this slices len("tts_model.") chars off the front of any
    # key that merely *contains* "tts_model." — confirm all matching keys
    # actually start with that prefix.
    weights = {k[len("tts_model."):]: v for k, v in weights.items() if "tts_model." in k}

    # Duplicate weights for rvq embeddings and embed code.
    rvq_embs_weight = weights["tts_model.rvq_embs"].clone()  # 31 x codebook_size x latent_size
    # Pad one extra codebook entry: 31 x (codebook_size + 1) x latent_size.
    rvq_embs_weight_pad = torch.nn.functional.pad(rvq_embs_weight, [0, 0, 0, 1])
    embed_code_weight = weights["tts_model.embed_code.weight"].clone()  # latent_size x hidden_size

    # ======================
    # Embedding module weights.
    bos_emb = weights["tts_model.bos_emb"]
    null_emb = weights["tts_model.null_emb"]
    embed_subwords_weight = torch.zeros(
        (vocab_size, max_char_len), dtype=bos_emb.dtype, device=bos_emb.device
    )
    embed_subwords_mask_weight = torch.zeros(
        (vocab_size, max_char_len), dtype=bos_emb.dtype, device=bos_emb.device
    )
    for subword_id_str, char_ids_lst in subword_id_to_char_ids.items():
        subword_id = int(subword_id_str)
        char_ids = torch.tensor(
            char_ids_lst, dtype=bos_emb.dtype, device=bos_emb.device
        )
        embed_subwords_weight[subword_id, : len(char_ids)] = char_ids
        embed_subwords_mask_weight[subword_id, : len(char_ids)] = 1

    # Weights for the embedding model that runs outside of the eartts.
    embedding_module_weights = {}
    embedding_module_weights["bos_emb"] = bos_emb
    embedding_module_weights["null_emb"] = null_emb

    # The embedding transformer has a lot of weights.
    for key, weight in weights.items():
        if "tts_model.embed_subword" in key:
            key = key[len("tts_model."):]
            # bos_eos_emb and subword_flag_emb are moved outside embed_subword.
            if key.startswith("embed_subword.bos_eos_emb.") or key.startswith("embed_subword.subword_flag_emb."):
                key = key[len("embed_subword."):]
            embedding_module_weights[key] = weight
    for key, weight in weights.items():
        if "tts_model.gated_fusion_audio_text" in key:
            key = key[len("tts_model."):]
            embedding_module_weights[key] = weight
    if "tts_model.audio_prompt_projection_W" in weights:
        embedding_module_weights["audio_prompt_projection_W"] = weights["tts_model.audio_prompt_projection_W"]
    embedding_module_weights["embed_subword.embed_subwords.weight"] = embed_subwords_weight
    embedding_module_weights["embed_subword.embed_subwords_mask.weight"] = embed_subwords_mask_weight
    for i in range(rvq_embs_weight_pad.shape[0]):
        embedding_module_weights[f"rvq_embs.{i}.weight"] = rvq_embs_weight_pad[i]
    embedding_module_weights["embed_code.weight"] = embed_code_weight
    embedding_module_weights = {
        f"total_emb.{k}": v for k, v in embedding_module_weights.items()
    }

    # ======================
    # Gemma backbone weights.
    backbone_module_weights = {k[len("tts_model."):]: v for k, v in weights.items() if k.startswith("tts_model.backbone.")}
    # Placeholder token embedding; real embeddings are supplied externally.
    backbone_module_weights["backbone.embed_tokens.weight"] = torch.randn(1, hidden_size, dtype=bos_emb.dtype, device=bos_emb.device)

    # ======================
    # Sampler weights.
    used_keys = ["rvq_embs", "embed_code", "mog_head"]
    sampler_weights = {k[len("tts_model."):]: v for k, v in weights.items() if any(k.startswith(f"tts_model.{key}") for key in used_keys)}
    sampler_weights = {"sampler." + k: v for k, v in sampler_weights.items()}

    # Combine embedding module and backbone module weights.
    weights = {**embedding_module_weights, **backbone_module_weights, **sampler_weights}
    weights = {"model." + k: v for k, v in weights.items()}

    # Save weights + a single-shard index file.
    safetensors_path = os.path.join(outdir, "model.safetensors")
    save_file(weights, safetensors_path)
    logging.info("Saved weights for vllm model")
    weight_map = {name: "model.safetensors" for name in weights.keys()}
    index = {
        "metadata": {
            "total_size": sum(w.numel() * w.element_size() for w in weights.values())
        },
        "weight_map": weight_map,
    }
    index_path = os.path.join(outdir, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)
    logging.info("Saved model index")

    # Save config.json.
    flat_config = {"architectures": ["EarTTSForCausalLM"], "model_type": "eartts"}
    # Not using vocab size of the backbone model.
    flat_config["vocab_size"] = 1

    # Parse backbone config exactly as NeMo does to get all defaults from transformers.
    backbone_type = cfg.model.tts_config.get("backbone_type", None)
    backbone_config_dict = OmegaConf.to_container(
        cfg.model.tts_config.backbone_config, resolve=True
    ) if cfg.model.tts_config.get("backbone_config") else {}

    # Create AutoConfig the same way NeMo does - this fills in all defaults.
    parsed_backbone_config = AutoConfig.for_model(backbone_type, **backbone_config_dict)

    # Store the backbone type for vllm to use.
    flat_config["backbone_type"] = backbone_type

    # Forward all backbone configs from the parsed AutoConfig (includes defaults).
    for key in [
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
        "max_position_embeddings",
        "rope_theta",
        "rope_local_base_freq",
        "sliding_window",
        "layer_types",
    ]:
        if hasattr(parsed_backbone_config, key):
            value = getattr(parsed_backbone_config, key)
            # Convert to list if it's a tuple or other iterable (except str).
            if hasattr(value, '__iter__') and not isinstance(value, (str, dict)):
                value = list(value)
            flat_config[key] = value
    # Forward overall configs.
    for key in ["latent_size", "codebook_size", "num_quantizers", "exponent"]:
        flat_config[key] = cfg.model.tts_config[key]
    # Forward mog head configs.
    for key in ["num_layers", "low_rank", "num_predictions", "min_log_std", "eps"]:
        flat_config[f"mog_{key}"] = cfg.model.tts_config.mog_head_config[key]

    # Forward inference configs (with name mapping for vLLM model).
    # num_iter is hardcoded to 8 in native model's _get_generation_config.
    flat_config["num_iter"] = 8
    flat_config["noise_scale"] = cfg.model.get("inference_noise_scale", 0.8)
    flat_config["top_p_or_k"] = cfg.model.get("inference_top_p_or_k", 0.8)
    flat_config["guidance_scale"] = cfg.model.get("inference_guidance_scale", 0.5)

    # Configuration of the embedding module.
    flat_config["emb_backbone_config"] = OmegaConf.to_container(
        cfg.model.tts_config.cas_config.backbone_config, resolve=True
    )
    flat_config["emb_backbone_type"] = cfg.model.tts_config.cas_config.backbone_type
    flat_config["emb_vocab_size"] = vocab_size
    flat_config["emb_char_vocab_size"] = len(char_vocab)
    flat_config["max_char_len"] = max_char_len

    # Configuration of flag embeddings.
    flat_config["pretrained_tokenizer_name"] = _pretrained_tokenizer_name
    flat_config["use_subword_flag_emb"] = cfg.model.tts_config.use_subword_flag_emb
    flat_config["use_bos_eos_emb"] = cfg.model.tts_config.use_bos_eos_emb
    flat_config["use_gated_fusion_for_text_audio"] = cfg.model.tts_config.use_gated_fusion_for_text_audio
    flat_config["use_audio_prompt_frozen_projection"] = cfg.model.tts_config.use_audio_prompt_frozen_projection
    # Hardcode enabling guidance so emb is created and application
    # of cfg is captured into a cuda graph.
    flat_config["enable_guidance"] = True

    # Configuring custom inputs/outputs.
    flat_config["custom_input_specs"] = [
        {
            "name": "acoustic_tokens",
            "dim": flat_config["num_quantizers"],
            "dtype": "int32",
        },
        {"name": "text_tokens", "dtype": "int32"},
        {"name": "text_mask"},
        {"name": "bos_mask"},
        {"name": "speaker_latent", "dim": flat_config["hidden_size"]},
    ]
    flat_config["custom_outputs"] = ["acoustic_tokens"]

    with open(os.path.join(outdir, "config.json"), "w") as f:
        json.dump(flat_config, f, indent=2)
    logging.info("Saved vllm config")

    # Extract and save pre-computed speaker latents (audio_prompt_latents.*)
    # from the NeMo checkpoint so they can be used at inference time.
    speaker_latents_dir = os.path.join(outdir, "speaker_latents")
    all_weights = load_file(model_path)
    found_latents = False
    for key, tensor in all_weights.items():
        if "audio_prompt_latents." in key:
            speaker_name = key.split("audio_prompt_latents.")[-1]
            os.makedirs(speaker_latents_dir, exist_ok=True)
            latent_path = os.path.join(speaker_latents_dir, f"{speaker_name}.pt")
            torch.save(tensor, latent_path)
            logging.info(f"Saved speaker latent '{speaker_name}' to {latent_path} (shape={tensor.shape})")
            found_latents = True
    if not found_latents:
        logging.warning(
            "No audio_prompt_latents found in checkpoint. "
            "speaker_name will not work unless latents are added."
        )


if __name__ == "__main__":
    args = parse_args()
    convert(args.outdir, args.config, args.model)
def load_checkpoint(checkpoint_path: str) -> dict[str, torch.Tensor]:
    """
    Load a checkpoint from safetensors or PyTorch format.

    Args:
        checkpoint_path: Checkpoint file (.safetensors or .pt/.pth), or a
            directory containing ``model.safetensors``.

    Returns:
        Dictionary of tensor names to tensors
    """
    # A directory is assumed to hold a single-shard safetensors file.
    if os.path.isdir(checkpoint_path):
        checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")

    if checkpoint_path.endswith('.safetensors'):
        logging.info(f"Loading safetensors from {checkpoint_path}")
        return load_file(checkpoint_path)

    logging.info(f"Loading PyTorch checkpoint from {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # Unwrap the common container formats, in priority order; otherwise
    # assume the file already is a bare state dict.
    for container_key in ('state_dict', 'model'):
        if container_key in ckpt:
            return ckpt[container_key]
    return ckpt


def filter_tensors(
    state_dict: dict[str, torch.Tensor],
    prefixes_to_keep: list[str]
) -> dict[str, torch.Tensor]:
    """
    Filter tensors, keeping only those whose name starts with one of the
    given prefixes.

    Args:
        state_dict: Full state dictionary
        prefixes_to_keep: List of prefixes to keep (e.g., ["stt_model.llm", "stt_model.asr_head"])

    Returns:
        Filtered state dictionary
    """
    kept: dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        if any(name.startswith(prefix) for prefix in prefixes_to_keep):
            kept[name] = tensor
            logging.debug(f"Keeping: {name} with shape {tensor.shape}")
        else:
            logging.debug(f"Skipping: {name}")

    logging.info(f"Total tensors kept: {len(kept)}")
    return kept
def convert_nemo_to_hf_format(
    checkpoint_path: str,
    output_dir: str,
    config_path: str | None = None,
    pretrained_llm: str | None = None,
    tensors_to_keep: list[str] | None = None,
    dtype: str = "float32",
) -> None:
    """
    Convert a NeMo STT checkpoint to HuggingFace format loadable by vLLM.

    Args:
        checkpoint_path: Path to the NeMo checkpoint (.safetensors or .pt)
        output_dir: Directory to save the converted checkpoint
        config_path: Path to config.json (if None, will look in same dir as checkpoint)
        pretrained_llm: HuggingFace model name to get base config from.
            NOTE(review): when config.json exists and contains
            model.stt.model.pretrained_llm, that value silently overrides
            any name passed here — confirm this precedence is intended.
        tensors_to_keep: List of tensor prefixes to keep (default: all stt_model.* tensors)
        dtype: Data type recorded in the custom input spec.
            NOTE(review): tensors themselves are saved unconverted.
    """
    # Default selection: the LLM backbone plus the STT heads/embeddings.
    if tensors_to_keep is None:
        tensors_to_keep = [
            "stt_model.llm",
            "stt_model.lm_head",
            "stt_model.asr_head",
            "stt_model.embed_asr_tokens",
            "stt_model.embed_tokens",
        ]

    # Resolve config.json next to the checkpoint unless given explicitly.
    if config_path is None:
        ckpt_dir = checkpoint_path if os.path.isdir(checkpoint_path) else os.path.dirname(checkpoint_path)
        config_path = os.path.join(ckpt_dir, "config.json")

    if os.path.exists(config_path):
        logging.info(f"Loading config from {config_path}")
        with open(config_path, "r") as f:
            config = json.load(f)
        try:
            pretrained_llm = config["model"]["stt"]["model"]["pretrained_llm"]
            logging.info(f"Found pretrained_llm in config: {pretrained_llm}")
        except KeyError:
            if pretrained_llm is None:
                raise ValueError(
                    "Could not find pretrained_llm in config and none provided via argument"
                )
    elif pretrained_llm is None:
        raise ValueError(
            f"Config file not found at {config_path} and pretrained_llm not provided"
        )

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Base config comes from the pretrained model; extend it with the
    # custom input/output tensor specs consumed by the vLLM wrapper.
    logging.info(f"Loading base config from {pretrained_llm}")
    base_config = AutoConfig.from_pretrained(pretrained_llm, trust_remote_code=True)
    base_config.update({
        "custom_input_specs": [
            {
                "name": "combined_embeds",
                "dtype": dtype,
                "dim": base_config.hidden_size
            }
        ],
        "custom_outputs": ["text_logits", "asr_tokens", "asr_logits"]
    })

    logging.info(f"Loading tokenizer from {pretrained_llm}")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_llm, trust_remote_code=True)

    logging.info(f"Loading checkpoint from {checkpoint_path}")
    state_dict = load_checkpoint(checkpoint_path)

    logging.info(f"Filtering tensors to keep prefixes: {tensors_to_keep}")
    filtered_state_dict = filter_tensors(state_dict, tensors_to_keep)
    if len(filtered_state_dict) == 0:
        raise ValueError(
            f"No tensors found with prefixes {tensors_to_keep}. "
            f"Available prefixes: {set(k.split('.')[0] for k in state_dict.keys())}"
        )

    # Persist tensors, config, and tokenizer side by side.
    output_model_path = out_path / "model.safetensors"
    logging.info(f"Saving tensors to {output_model_path}")
    save_file(filtered_state_dict, str(output_model_path))

    output_config_path = out_path / "config.json"
    logging.info(f"Saving config to {output_config_path}")
    base_config.save_pretrained(str(out_path))

    logging.info(f"Saving tokenizer to {out_path}")
    tokenizer.save_pretrained(str(out_path))

    logging.info(f"Conversion completed successfully! Output saved to: {out_path}")
Output saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert NeMo STT checkpoint to HuggingFace format for vLLM" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to NeMo checkpoint file (.safetensors or .pt/.pth)", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory to save converted checkpoint", + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to config.json (default: same directory as checkpoint)", + ) + parser.add_argument( + "--pretrained-llm", + type=str, + default=None, + help="HuggingFace model name to use as base (default: read from config)", + ) + parser.add_argument( + "--tensors-to-keep", + type=str, + nargs="+", + default=None, + help="Tensor prefixes to keep (default: all stt_model.* backbone llm related tensors)", + ) + parser.add_argument( + "--dtype", + type=str, + default="float32", + choices=["float32", "float16", "bfloat16", "fp32", "fp16", "bf16"], + help="Target dtype for tensors (default: float32)", + ) + + args = parser.parse_args() + + convert_nemo_to_hf_format( + checkpoint_path=args.checkpoint, + output_dir=args.output_dir, + config_path=args.config, + pretrained_llm=args.pretrained_llm, + tensors_to_keep=args.tensors_to_keep, + dtype=args.dtype, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py b/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py new file mode 100644 index 000000000000..3c570953acf1 --- /dev/null +++ b/nemo/collections/speechlm2/inference/vllm/streaming_llm_engine.py @@ -0,0 +1,493 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Async vLLM engine wrapper for models that use custom input tensors. + +Wraps vLLM's ``AsyncLLM`` for streaming step-by-step inference with +``custom_inputs`` (used by both DuplexSTT/LLM and EarTTS backends). +``engine_kind`` selects EarTTS-only runtime settings (guidance scale, attention +backend env); it does not imply an inheritance relationship between TTS and LLM. +""" + +import os +import json +import torch +from typing import Any, AsyncGenerator, Literal +from dataclasses import dataclass +from enum import Enum + +from vllm.v1.engine.async_llm import AsyncLLM +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.attention.selector import _cached_get_attn_backend + +from nemo.utils import logging + +class StreamStatus(Enum): + IDLE = "idle" + ACTIVE = "active" + ABORTED = "aborted" + FINISHED = "finished" + + +@dataclass +class GenerationResult: + token_id: int + is_finished: bool + custom_outputs: dict[str, torch.Tensor] | None = None + finish_reason: str | None = None + total_tokens: int = 0 + + +@dataclass +class RequestState: + """State for a single generation request.""" + request_id: str + status: StreamStatus + generated_tokens: list + generation_iterator: AsyncGenerator | None = None + + +class CustomInputAsyncVLLMEngine: + """ + Wrapper for vLLM ``AsyncLLM`` with custom input tensor specifications. + + KV cache is managed entirely inside the AsyncLLM engine -- callers do not + allocate, pass, or update cache objects. The engine uses PagedAttention to + manage GPU memory automatically. 
Per-request state (generated tokens, + generation iterator, status) is tracked in ``self.requests`` dicts. + + Provides: + - Start/stop streaming with custom embedding inputs + - Generate one token at a time via ``generate_next_token`` + - Abort / restart individual requests by ``request_id`` + """ + + def __init__( + self, + model_path: str, + max_model_len: int = 10240, + max_num_batched_tokens: int = 768, + gpu_memory_utilization: float = 0.8, + trust_remote_code: bool = True, + dtype: str = "bfloat16", + skip_tokenizer_init: bool = False, + engine_kind: Literal["llm", "eartts"] = "llm", + **sampling_kwargs + ): + """ + Initialize the async vLLM engine wrapper. + + Args: + model_path: Path to the vLLM-compatible model checkpoint (required) + max_model_len: Maximum sequence length (default: 10240) + max_num_batched_tokens: Maximum tokens processed per forward pass. + Controls prefill chunk size and max concurrent decode streams. + Default: 768. + gpu_memory_utilization: GPU memory utilization ratio (default: 0.8) + trust_remote_code: Whether to trust remote code (default: True) + dtype: Data type for embeddings (default: "bfloat16") + engine_kind: ``"llm"`` for DuplexSTT-style models; ``"eartts"`` applies + EarTTS-specific ``guidance_scale`` and attention/runtime env during init. + **sampling_kwargs: Additional vLLM sampling parameters. + Note: vLLM is configured for greedy decoding internally + (temperature=0, ignore_eos=True). Actual text sampling + (top-p, repetition penalty) is applied post-hoc by + ModelInterface._sample_text_token, not by vLLM. 
+ """ + self.model_path = model_path + self.max_model_len = max_model_len + self.max_num_batched_tokens = max_num_batched_tokens + self.gpu_memory_utilization = gpu_memory_utilization + self.trust_remote_code = trust_remote_code + self.dtype = dtype + self.skip_tokenizer_init = skip_tokenizer_init + self.engine_kind = engine_kind + + # Engine state + self.engine: AsyncLLM | None = None + + # Request state tracking - supports multiple concurrent requests + self.requests: dict[str, RequestState] = {} + + # vLLM sampling is disabled (skip_sampling=True) because we handle + # token selection ourselves for both the LLM and EarTTS paths. + # For LLM: text sampling (top-p, repetition penalty) is applied + # post-hoc by ModelInterface._sample_text_token on the returned logits. + # For EarTTS: acoustic tokens are produced by the model's own forward + # pass (RVQ codebook prediction), not by vLLM's sampler. + default_sampling = { + "max_tokens": 100000, # Set very high to prevent stopping - use abort to stop explicitly + "skip_sampling": True, + "ignore_eos": True, + } + default_sampling.update(sampling_kwargs) + self.sampling_params = SamplingParams(**default_sampling) + + if self.engine_kind == "eartts": + guidance_scale = self._read_guidance_scale_from_config() + self.sampling_params.guidance_scale = guidance_scale + logging.info( + f"CustomInputAsyncVLLMEngine initialized (engine_kind=eartts, " + f"guidance_scale={guidance_scale}, model={model_path})" + ) + else: + logging.info( + f"CustomInputAsyncVLLMEngine initialized (engine_kind=llm, model={model_path})" + ) + + def _read_guidance_scale_from_config(self) -> float: + """Read guidance_scale from the converted vLLM model's config.json.""" + config_path = os.path.join(self.model_path, "config.json") + if os.path.isfile(config_path): + with open(config_path, "r") as f: + cfg = json.load(f) + value = cfg.get("guidance_scale", None) + if value is not None: + logging.info(f"Read guidance_scale={value} from {config_path}") + 
return float(value) + logging.warning( + f"guidance_scale not found in {config_path}, using default 0.5." + ) + return 0.5 + + async def initialize(self): + """Initialize the vLLM engine with custom input specifications.""" + if self.engine is not None: + logging.info("Engine already initialized!") + return + + logging.info("Initializing vLLM engine...") + + eartts_env = self.engine_kind == "eartts" + if eartts_env: + # Force TRITON_ATTN backend for EarTTS + os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN" + # TF32 matmul precision to match TTS training ("medium"). + # torch.set_float32_matmul_precision is process-local and does NOT + # propagate to vLLM's spawned worker processes; this CUDA-level env + # var is inherited by child processes. + os.environ["NVIDIA_TF32_OVERRIDE"] = "1" + _cached_get_attn_backend.cache_clear() + + try: + engine_args = AsyncEngineArgs( + model=self.model_path, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_batched_tokens, + gpu_memory_utilization=self.gpu_memory_utilization, + trust_remote_code=self.trust_remote_code, + mamba_ssm_cache_dtype="float32", + dtype=self.dtype, + skip_tokenizer_init=self.skip_tokenizer_init, + enable_prefix_caching=False + ) + + # Custom input/output specs are defined in the model's config file + vllm_config = engine_args.create_engine_config() + self.custom_input_specs = vllm_config.model_config.custom_input_specs + + # Initialize the engine + self.engine = AsyncLLM.from_vllm_config(vllm_config) + + logging.info("Engine initialized with custom input specs:") + for spec in self.custom_input_specs: + logging.info(f" - {spec}") + finally: + if eartts_env: + os.environ.pop("VLLM_ATTENTION_BACKEND", None) + os.environ.pop("NVIDIA_TF32_OVERRIDE", None) + _cached_get_attn_backend.cache_clear() + + def _get_safe_prompt_tokens(self, length: int = 10) -> list[int]: + """Generate safe prompt tokens that won't cause immediate EOS.""" + if self.engine and hasattr(self.engine, 'tokenizer') 
and self.engine.tokenizer: + bos_id = getattr(self.engine.tokenizer, 'bos_token_id', 1) + else: + bos_id = 1 + + # Mix of BOS + safe alphanumeric tokens + safe_tokens = [bos_id] + list(range(50, 59)) # tokens 50-58 are usually safe + return (safe_tokens * ((length // len(safe_tokens)) + 1))[:length] + + async def start_generation( + self, + request_id: str = "speech_stream" + ) -> bool: + """ + Start a new streaming generation session. + + Args: + request_id: Unique identifier for this generation request + + Returns: + bool: True if generation started successfully + """ + if self.engine is None: + raise RuntimeError("Engine not initialized! Call initialize() first.") + + # Check if request already exists + if request_id in self.requests: + existing_state = self.requests[request_id] + if existing_state.status == StreamStatus.ACTIVE: + logging.warning(f"Request {request_id} is already active. Aborting it first.") + await self.abort_generation(request_id) + + logging.info(f"Starting generation session with request_id: {request_id}") + + # Create new request state + self.requests[request_id] = RequestState( + request_id=request_id, + status=StreamStatus.ACTIVE, + generated_tokens=[], + generation_iterator=None # Will be created on first generate_next_token call + ) + return True + + async def generate_next_token(self, input_tensors: list[torch.Tensor], + prompt_token_ids: list[int] | None = None, + request_id: str = "speech_stream") -> GenerationResult | None: + """ + Generate the next token using the provided input embedding. + + Args: + input_tensors: List of tensors for generating the next token + prompt_token_ids: Optional list of token IDs for the system prompt + request_id: Unique identifier for this generation request + + Returns: + GenerationResult or None if generation is finished/aborted + """ + if request_id not in self.requests: + raise RuntimeError(f"Request {request_id} not found. 
Call start_generation() first.") + + request_state = self.requests[request_id] + + if request_state.status != StreamStatus.ACTIVE: + logging.warning(f"Generation not active for request {request_id} (status: {request_state.status})") + return None + + assert len(input_tensors) == len(self.custom_input_specs), f"Expected {len(self.custom_input_specs)} input tensors, got {len(input_tensors)}" + + if self.engine is None: + raise RuntimeError("Engine not initialized") + + custom_inputs = {} + max_length = 1 + for i, spec in enumerate(self.custom_input_specs): + input_dtype = spec.dtype + if input_dtype is None: + input_dtype = "float32" # Default dtype + if spec.dim is not None and spec.dim != input_tensors[i].shape[-1]: + raise ValueError(f"Input tensor dimension mismatch for {spec.name}: expected {spec.dim}, got {input_tensors[i].shape[-1]}") + custom_inputs[spec.name] = input_tensors[i].to(dtype=getattr(torch, input_dtype)).cpu() + max_length = max(max_length, input_tensors[i].shape[0]) + + try: + # If this is the first call, initialize the generation + if request_state.generation_iterator is None: + # Create initial inputs with a single safe prompt token + # this will not be used for generation, just to initialize the model state + prompt_tokens = prompt_token_ids if prompt_token_ids is not None else self._get_safe_prompt_tokens(max_length) + assert len(prompt_tokens) == max_length, f"Prompt tokens length {len(prompt_tokens)} does not match input length {max_length}" + inputs = { + "prompt_token_ids": prompt_tokens, + "custom_inputs": custom_inputs + } + + logging.info(f"Initializing generation for request {request_id} with first embedding") + + # Start generation + request_state.generation_iterator = self.engine.generate( + inputs, + self.sampling_params, + request_id=request_id + ) + + # If this is not the first call, append the current embedding for processing + elif request_state.generation_iterator is not None and len(request_state.generated_tokens) > 0: + try: 
+ await self.engine.append_request( + request_id=request_id, + custom_inputs=custom_inputs + ) + except ValueError as e: + if "not found" in str(e): + logging.warning(f"Request {request_id} was removed from vLLM engine. Marking as finished.") + request_state.status = StreamStatus.FINISHED + return None + else: + raise RuntimeError(f"Error appending to request {request_id}: {e}") + + # Get next output from the generation + output = await request_state.generation_iterator.__anext__() + + # Extract new tokens + current_tokens = output.outputs[0].token_ids + if len(current_tokens) > len(request_state.generated_tokens): + new_tokens = current_tokens[len(request_state.generated_tokens):] + assert len(new_tokens) == 1, f"Expected exactly one new token, got {len(new_tokens)}" + new_tokens = current_tokens[-1:] + request_state.generated_tokens.extend(new_tokens) + + # Get the latest token + latest_token = new_tokens[-1] + + # Check if finished + if output.finished: + request_state.status = StreamStatus.FINISHED + finish_reason = output.outputs[0].finish_reason + logging.warning(f"Request {request_id} finished after {len(request_state.generated_tokens)} tokens. 
Reason: {finish_reason}") + return GenerationResult( + token_id=latest_token, + custom_outputs=output.outputs[0].custom_outputs if hasattr(output.outputs[0], 'custom_outputs') else None, + is_finished=True, + finish_reason=finish_reason, + total_tokens=len(request_state.generated_tokens) + ) + else: + return GenerationResult( + token_id=latest_token, + custom_outputs=output.outputs[0].custom_outputs if hasattr(output.outputs[0], 'custom_outputs') else None, + is_finished=False, + total_tokens=len(request_state.generated_tokens) + ) + else: + # No new tokens generated + if output.finished: + logging.warning("No new tokens but finished!") + logging.warning(output.outputs[0].finish_reason) + request_state.status = StreamStatus.FINISHED + return None + + except StopAsyncIteration: + # Generation ended + request_state.status = StreamStatus.FINISHED + return None + except Exception as e: + logging.error(f"Error in generate_next_token for request {request_id}: {e}") + request_state.status = StreamStatus.FINISHED + return None + + async def abort_generation(self, request_id: str = "speech_stream") -> bool: + """ + Abort a specific generation request. 
+ + Args: + request_id: Unique identifier for the generation request to abort + + Returns: + bool: True if abort was successful + """ + if request_id not in self.requests: + logging.warning(f"Request {request_id} not found") + return False + + request_state = self.requests[request_id] + + if request_state.status != StreamStatus.ACTIVE: + logging.info(f"Request {request_id} is {request_state.status.value}, cleaning up state") + # Just remove the state, no need to abort + del self.requests[request_id] + return True + + try: + await self.engine.abort(request_id) + request_state.status = StreamStatus.ABORTED + del self.requests[request_id] + logging.info(f"Aborted generation for request: {request_id}") + return True + except Exception as e: + logging.error(f"Error aborting generation for request {request_id}: {e}") + return False + + async def shutdown(self): + """Shutdown the engine and cleanup resources.""" + # Abort all active requests + for request_id, request_state in list(self.requests.items()): + if request_state.status == StreamStatus.ACTIVE: + await self.abort_generation(request_id) + + if self.engine is not None: + logging.info("Shutting down engine...") + self.engine.shutdown() + self.engine = None + logging.info("Engine shutdown complete.") + + # Clear all request states + self.requests.clear() + + def get_status(self, request_id: str | None = None) -> dict[str, Any]: + """Get current status information. + + Args: + request_id: If provided, return status for specific request. + If None, return status for all requests. 
+ + Returns: + Status information dictionary + """ + if request_id is not None: + if request_id not in self.requests: + return {"error": f"Request {request_id} not found"} + + request_state = self.requests[request_id] + return { + "request_id": request_id, + "status": request_state.status.value, + "tokens_generated": len(request_state.generated_tokens), + "latest_tokens": request_state.generated_tokens[-5:] if request_state.generated_tokens else [] + } + else: + # Return summary of all requests + return { + "total_requests": len(self.requests), + "requests": { + rid: { + "status": state.status.value, + "tokens_generated": len(state.generated_tokens) + } + for rid, state in self.requests.items() + } + } + + async def __aenter__(self): + """Async context manager entry.""" + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.shutdown() + + +def create_engine(engine_type: str = "llm", **kwargs) -> CustomInputAsyncVLLMEngine: + """ + Factory function to create a CustomInputAsyncVLLMEngine instance. + + Args: + engine_type: ``"llm"`` or ``"eartts"`` (maps to ``engine_kind``). + **kwargs: Passed to the engine (model_path, max_model_len, etc.). + Returns: + A configured ``CustomInputAsyncVLLMEngine``. 
+ """ + + if engine_type == "eartts": + return CustomInputAsyncVLLMEngine(engine_kind="eartts", **kwargs) + elif engine_type == "llm": + return CustomInputAsyncVLLMEngine(engine_kind="llm", **kwargs) + else: + raise ValueError(f"Unsupported engine_type: {engine_type}") + diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index e1e982f71a63..4a9aa1f49830 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -111,8 +111,13 @@ def __init__(self, cfg: dict) -> None: # compute samples per frame self.source_samples_per_frame = int(self.source_sample_rate * cfg.data.frame_length) - # get codec silence tokens - codec_silence_tokens = self.get_codec_silence_frame() + # get codec silence tokens (skip when codec has random weights — the + # buffer will be overwritten from the checkpoint) + if self.cfg.get('pretrained_codec_model', None) is not None: + codec_silence_tokens = self.get_codec_silence_frame() + else: + num_q = self.tts_model.config.num_quantizers + codec_silence_tokens = torch.zeros(num_q, dtype=torch.long) self.register_buffer("codec_silence_tokens", codec_silence_tokens) # cached for quicker audio decoding diff --git a/nemo/collections/speechlm2/models/duplex_stt_model.py b/nemo/collections/speechlm2/models/duplex_stt_model.py index d15becb2cbc8..4d7c322036eb 100644 --- a/nemo/collections/speechlm2/models/duplex_stt_model.py +++ b/nemo/collections/speechlm2/models/duplex_stt_model.py @@ -14,7 +14,8 @@ import copy import os import re - +import warnings +from pathlib import Path import torch from lightning import LightningModule from omegaconf import DictConfig @@ -45,6 +46,7 @@ from nemo.collections.speechlm2.parts.pretrained import ( load_pretrained_hf, maybe_load_pretrained_models, + resolve_pretrained_config, set_model_dict_for_partial_init, setup_speech_encoder, ) @@ -59,7 +61,7 @@ def 
maybe_rename_llm_kwargs_for_nemotron(kwargs: dict, model_cfg) -> dict: return kwargs cache = kwargs.pop("past_key_values") if cache is not None: - cache_key = model_cfg.get("cache_key", "past_key_values") + cache_key = model_cfg.get("cache_key", "cache_params") kwargs[cache_key] = cache return kwargs @@ -79,16 +81,19 @@ def __init__(self, cfg: dict) -> None: self.predict_user_text = self.cfg.get("predict_user_text", False) + pretrained_weights, tokenizer_path = resolve_pretrained_config(self.cfg) + # Load LLM first llm = load_pretrained_hf( self.cfg.pretrained_llm, - pretrained_weights=self.cfg.pretrained_weights, + pretrained_weights=pretrained_weights, trust_remote_code=self.cfg.get("trust_remote_code", True), + use_meta_device=self.cfg.get("use_meta_device", False), ).train() # Initialize tokenizer with optional special tokens from config self.tokenizer = AutoTokenizer( - self.cfg.pretrained_llm, + tokenizer_path, use_fast=True, bos_token=self.cfg.get("bos_token", None), eos_token=self.cfg.get("eos_token", None), @@ -109,10 +114,17 @@ def __init__(self, cfg: dict) -> None: self.asr_head = copy.deepcopy(self.lm_head) self.embed_asr_tokens = copy.deepcopy(self.embed_tokens) + self.use_function_head = self.cfg.get("use_function_head", False) + if self.use_function_head: + self.function_head = copy.deepcopy(self.lm_head) + logging.info("[Function Calling] Initialized function_head (deep copy of lm_head)") + else: + self.function_head = None + maybe_install_lora(self) # Load the pretrained streaming ASR model - setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights) + setup_speech_encoder(self, pretrained_weights=pretrained_weights) maybe_load_pretrained_models(self) @@ -122,6 +134,24 @@ def __init__(self, cfg: dict) -> None: # Initialize streaming inference engine self.streaming_inference = DuplexSTTStreamingInference(self) + def save_pretrained( + self, + save_directory: str | Path, + **kwargs, + ) -> str | None: + """Save model and also export 
LLM artifacts (config + tokenizer) for offline inference.""" + result = super().save_pretrained(save_directory, **kwargs) + + try: + llm_dir = Path(save_directory) / "llm_artifacts" + llm_dir.mkdir(parents=True, exist_ok=True) + self.tokenizer.tokenizer.save_pretrained(str(llm_dir)) + logging.info(f"Saved LLM tokenizer to {llm_dir}") + except Exception as e: + warnings.warn(f"Failed to save LLM tokenizer: {e}. Inference will fall back to downloading from HF.") + + return result + @property def text_vocab_size(self): """Return the size of the text tokenizer.""" @@ -167,12 +197,15 @@ def forward( self, input_embeds: Tensor, cache=None, + cache_position=None, ) -> dict[str, Tensor]: """ Text prediction only (audio_loss_weight=0). """ kwargs = dict(inputs_embeds=input_embeds, past_key_values=cache, use_cache=cache is not None, return_dict=True) kwargs = maybe_rename_llm_kwargs_for_nemotron(kwargs, self.cfg) + if cache_position is not None: + kwargs["cache_position"] = cache_position out = self.llm(**kwargs) B, T = input_embeds.shape[:2] @@ -191,9 +224,22 @@ def forward( if self.cfg.get("inference_eos_boost", None): text_logits[:, :, self.text_eos_id] += self.cfg.inference_eos_boost + if self.predict_user_text: + if self.cfg.get("inference_user_pad_boost", None): + asr_logits[:, :, self.text_pad_id] += self.cfg.inference_user_pad_boost + if self.cfg.get("inference_user_bos_boost", None): + asr_logits[:, :, self.text_bos_id] += self.cfg.inference_user_bos_boost + if self.cfg.get("inference_user_eos_boost", None): + asr_logits[:, :, self.text_eos_id] += self.cfg.inference_user_eos_boost + ans = {"text_logits": text_logits} if self.predict_user_text: ans["asr_logits"] = asr_logits + if self.function_head is not None: + function_in = out['last_hidden_state'] + if function_in.dtype != self.function_head.weight.dtype: + function_in = function_in.to(self.function_head.weight.dtype) + ans["function_logits"] = self.function_head(function_in) if cache is not None: if 'Nemotron' 
in self.cfg.pretrained_llm: @@ -728,3 +774,148 @@ def load_state_dict(self, state_dict, strict: bool = True): logging.info("Error loading model state_dict !! Retrying with partial initialization!") model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) return super().load_state_dict(model_dict, strict=False) + + def _create_nemotron_cache(self, batch_size: int = 1): + """Create a HybridMambaAttentionDynamicCache for Nemotron hybrid Mamba2/Attention models.""" + import importlib + cache_cls = None + llm = self.llm + if hasattr(llm, '_orig_mod'): + llm = llm._orig_mod + model_module = type(llm).__module__ + if model_module: + mod = importlib.import_module(model_module) + cache_cls = getattr(mod, "HybridMambaAttentionDynamicCache", None) + if cache_cls is None: + raise RuntimeError( + "Could not find HybridMambaAttentionDynamicCache in the Nemotron model's module. " + "Ensure the model was loaded with trust_remote_code=True." + ) + + # Newer transformers defines key_cache/value_cache as read-only + # properties on DynamicCache, but the Nemotron __init__ tries to set + # them as regular attributes. Create a patched subclass that shadows + # the properties so the assignment succeeds. 
+ needs_patch = any( + isinstance(cls.__dict__.get(attr), property) + for cls in cache_cls.__mro__ + for attr in ('key_cache', 'value_cache') + ) + if needs_patch: + patched = type(cache_cls.__name__, (cache_cls,), { + 'key_cache': None, + 'value_cache': None, + }) + else: + patched = cache_cls + + config = llm.config + cache = patched( + config=config, + batch_size=batch_size, + dtype=llm.dtype if hasattr(llm, 'dtype') else torch.float32, + device=self.device, + ) + if not hasattr(cache, 'conv_kernel_size'): + cache.conv_kernel_size = config.conv_kernel + + intermediate_size = config.mamba_num_heads * config.mamba_head_dim + conv_dim = intermediate_size + 2 * config.n_groups * config.ssm_state_size + if conv_dim != intermediate_size: + conv_kernel = config.conv_kernel + dtype = llm.dtype if hasattr(llm, 'dtype') else torch.float32 + for i, pattern in enumerate(config.hybrid_override_pattern): + if pattern == "M" and i < len(cache.conv_states): + cache.conv_states[i] = torch.zeros( + batch_size, conv_dim, conv_kernel, device=self.device, dtype=dtype, + ) + + self._patch_nemotron_cache_bugs(cache) + self._patch_nemotron_block_forward(llm) + return cache + + @staticmethod + def _patch_nemotron_block_forward(llm): + """Patch NemotronHBlock.forward to pass cache and mask to attention layers. + + The upstream HF code only passes cache_position to the attention mixer, + omitting past_key_value and attention_mask. Without this fix, attention + layers never see the KV cache and only attend to the current token. 
+ """ + import types + + if hasattr(llm, '_orig_mod'): + llm = llm._orig_mod + + layers = getattr(llm, 'layers', None) + if layers is None: + return + + def _patched_block_forward( + self, + hidden_states, + cache_params=None, + cache_position=None, + attention_mask=None, + ): + with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + if self.block_type == "mamba": + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position + ) + elif self.block_type == "attention": + hidden_states = self.mixer( + hidden_states, + past_key_value=cache_params, + attention_mask=attention_mask, + cache_position=cache_position, + ) + hidden_states = hidden_states[0] + elif self.block_type == "mlp": + hidden_states = self.mixer(hidden_states) + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + hidden_states = residual + hidden_states + return hidden_states + + patched_count = 0 + for layer in layers: + block_type = getattr(layer, 'block_type', None) + if block_type == "attention": + layer.forward = types.MethodType(_patched_block_forward, layer) + patched_count += 1 + + if patched_count > 0: + logging.info(f"Patched {patched_count} NemotronHBlock attention layers to pass KV cache") + + @staticmethod + def _patch_nemotron_cache_bugs(cache): + """Patch bugs in HybridMambaAttentionDynamicCache from the HF model code. + + The upstream code references self.conv_states.device and self.ssm_states.device, + but these are Python lists, not tensors. We monkey-patch the affected methods. 
+ """ + import types + + def update_conv_state(self, layer_idx, new_conv_state, cache_init=False): + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to( + self.conv_states[layer_idx].device + ) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx, new_ssm_state): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + cache.update_conv_state = types.MethodType(update_conv_state, cache) + cache.update_ssm_state = types.MethodType(update_ssm_state, cache) diff --git a/nemo/collections/speechlm2/models/nemotron_voicechat.py b/nemo/collections/speechlm2/models/nemotron_voicechat.py index 8c849a5c1548..760e44b6df42 100644 --- a/nemo/collections/speechlm2/models/nemotron_voicechat.py +++ b/nemo/collections/speechlm2/models/nemotron_voicechat.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import gc +import json import os +import warnings from pathlib import Path -from typing import Optional, Union - import torch from huggingface_hub import CONFIG_NAME from lightning import LightningModule @@ -120,8 +120,10 @@ def __init__(self, cfg: dict) -> None: # Load Duplex TTS model self.tts_model = DuplexEARTTS(OmegaConf.to_container(self.cfg.speech_generation, resolve=True)) - # reset silence tokens to avoid inference issues - self.tts_model.codec_silence_tokens = self.tts_model.get_codec_silence_frame() + # reset silence tokens to avoid inference issues (skip when codec + # has random weights — the buffer will be loaded from checkpoint) + if self.tts_model.cfg.get('pretrained_codec_model', None) is not None: + self.tts_model.codec_silence_tokens = self.tts_model.get_codec_silence_frame() self.target_fps = self.tts_model.target_fps # compute source fps self.source_fps = self.source_sample_rate / ( @@ -131,6 +133,39 @@ def __init__(self, cfg: dict) -> None: self._use_fsdp = False self._use_tp = False + def save_pretrained( + self, + save_directory: str | Path, + **kwargs, + ) -> str | None: + """Save model and export LLM artifacts (tokenizer + perception config) for offline inference.""" + result = super().save_pretrained(save_directory, **kwargs) + + # Save tokenizer for offline loading + try: + llm_dir = Path(save_directory) / "llm_artifacts" + llm_dir.mkdir(parents=True, exist_ok=True) + self.stt_model.tokenizer.tokenizer.save_pretrained(str(llm_dir)) + logging.info(f"Saved LLM tokenizer to {llm_dir}") + except Exception as e: + warnings.warn(f"Failed to save LLM tokenizer: {e}. Inference will fall back to downloading from HF.") + + # Save full perception config at the top level of config.json so that + # resolve_pretrained_config() can skip pretrained ASR/LLM downloads. 
+ try: + config_path = Path(save_directory) / "config.json" + if config_path.exists(): + with open(config_path) as f: + config = json.load(f) + config["perception"] = OmegaConf.to_container(self.stt_model.cfg.perception, resolve=True) + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + logging.info(f"Saved perception config to {config_path}") + except Exception as e: + warnings.warn(f"Failed to save perception config: {e}") + + return result + def init_from_model_from_ckpt(self, checkpoint_path): if checkpoint_path is not None: checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu')['state_dict'] @@ -144,13 +179,13 @@ def _from_pretrained( cls, *, model_id: str, - revision: Optional[str], - cache_dir: Optional[Union[str, Path]], + revision: str | None, + cache_dir: str | Path | None, force_download: bool, - proxies: Optional[dict], - resume_download: Optional[bool], + proxies: dict | None, + resume_download: bool | None, local_files_only: bool, - token: Union[str, bool, None], + token: str | bool | None, map_location: str = "cpu", strict: bool = False, **model_kwargs, @@ -183,6 +218,25 @@ def _from_pretrained( # Skip loading child module weights natively model_kwargs['cfg']['pretrained_weights'] = False + # Propagate pretrained_weights=False into nested configs so child + # modules skip downloading pretrained ASR, LLM, and codec models. 
+ cfg = model_kwargs['cfg'] + try: + stt_model_cfg = cfg['model']['stt']['model'] + stt_model_cfg['pretrained_weights'] = False + stt_model_cfg['use_meta_device'] = True + if 'perception' in cfg: + stt_model_cfg['perception'] = cfg['perception'] + logging.info("Injected saved perception config into STT model config") + except (KeyError, TypeError): + logging.warning("Could not propagate pretrained_weights=False into nested STT config") + try: + tts_model_cfg = cfg['model']['speech_generation']['model'] + tts_model_cfg['pretrained_model'] = None + tts_model_cfg['pretrained_codec_model'] = None + except (KeyError, TypeError): + logging.warning("Could not nullify pretrained TTS/codec paths in nested TTS config") + # Instantiate the empty model skeleton model = cls(model_kwargs['cfg']) @@ -204,12 +258,32 @@ def _from_pretrained( if resolved_weights_file is None: raise RuntimeError(f"Missing model.safetensors file for {model_id=}") - # Stream the weights safely using your custom memory-efficient loader! + # Stream the weights from safetensors ckpt_dir = os.path.dirname(resolved_weights_file) model.init_from_safetensors_ckpt(ckpt_dir) return model + def _replace_tensor(self, full_key, value): + """Replace a parameter or buffer on its parent module. + + Meta-device tensors (from torch.device('meta')) have no storage, so the + usual ``target.data.copy_(tensor)`` raises an error. Instead we walk + the module tree to find the parent and swap the entry directly in + ``module._parameters`` or ``module._buffers``. 
+ """ + parts = full_key.split(".") + module = self + for part in parts[:-1]: + module = getattr(module, part) + name = parts[-1] + if name in module._parameters: + module._parameters[name] = torch.nn.Parameter( + value, requires_grad=module._parameters[name].requires_grad + ) + elif name in module._buffers: + module._buffers[name] = value + def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): """ Memory-efficient streaming safetensors loader with dynamic @@ -251,6 +325,8 @@ def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): if target.shape != tensor.shape: logging.warning(f"Shape mismatch for {key}: " f"model {target.shape} vs ckpt {tensor.shape}") + elif target.is_meta: + self._replace_tensor(prefix + key, tensor) else: target.data.copy_(tensor) @@ -263,6 +339,8 @@ def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): logging.warning( f"Buffer shape mismatch for {key}: " f"model {target.shape} vs ckpt {tensor.shape}" ) + elif target.is_meta: + self._replace_tensor(prefix + key, tensor) else: target.data.copy_(tensor) @@ -281,6 +359,35 @@ def init_from_safetensors_ckpt(self, ckpt_path, prefix=""): if missing_keys: logging.warning(f"{len(missing_keys)} keys in checkpoint not found in model") + # Fail if any *parameters* are still on meta device — those genuinely + # need weights from the checkpoint and their absence is an error. + meta_params = [n for n, p in self.named_parameters() if p.is_meta] + if meta_params: + raise RuntimeError( + f"{len(meta_params)} parameters still on meta device after checkpoint load " + f"(missing from checkpoint): {meta_params[:20]}" + ) + + # Buffers on meta device are typically non-persistent computed values + # (e.g. rotary_emb.inv_freq registered with persistent=False) that + # save_pretrained / safetensors intentionally omit. Reinitialize + # them by moving the owning module to CPU then back, which triggers + # the buffer's factory function. 
+ meta_buffers = [(n, b) for n, b in self.named_buffers() if b.is_meta] + if meta_buffers: + logging.info( + f"Reinitializing {len(meta_buffers)} non-persistent meta buffer(s): " + f"{[n for n, _ in meta_buffers[:10]]}" + ) + for buf_name, buf in meta_buffers: + parts = buf_name.split(".") + module = self + for part in parts[:-1]: + module = getattr(module, part) + module.to_empty(device="cpu") + module.to(device="cpu") + logging.info("Meta buffers reinitialised on CPU") + gc.collect() def training_step(self, batch: dict, batch_idx: int): @@ -437,6 +544,7 @@ def offline_inference( incremental_audio_decoding: bool = False, generation_config: dict = None, guidance_enabled: bool = True, + return_logits: bool = False, ) -> dict[str, torch.Tensor]: """ Runs full offline duplex speech-to-speech inference. @@ -485,6 +593,12 @@ def offline_inference( guidance_enabled (bool, optional): Enables classifier-free guidance. + return_logits (bool, optional): + When True, collect per-step text and ASR logits and + include them in the returned dict as ``"text_logits"`` + (B, T, V_text) and ``"asr_logits"`` (B, T, V_asr). + Useful for parity testing against incremental inference. + Returns: dict[str, torch.Tensor]: @@ -508,6 +622,12 @@ def offline_inference( Tensor (B,) — waveform lengths in samples (if decode_audio=True). + • "text_logits" (only when return_logits=True): + Tensor (B, T, V_text) — per-step text head logits. + + • "asr_logits" (only when return_logits=True): + Tensor (B, T, V_asr) — per-step ASR head logits. + Notes: • Uses streaming inference backend of DuplexSTTModel. • Uses autoregressive codec generation from DuplexEARTTS. 
@@ -525,6 +645,10 @@ def offline_inference( B = inference_state["B"] T = inference_state["T"] + if return_logits: + _text_logits = [ans["text_logits"][:, -1].detach()] + _asr_logits = [ans["asr_logits"][:, -1].detach()] if "asr_logits" in ans else [] + # if speaker_name is provided uses it, if not uses the speaker_audio provided, if speaker_audio is None load it from inference_speaker_reference if speaker_audio is None: speaker_name = self.cfg.get("inference_speaker_name", None) @@ -572,7 +696,12 @@ def offline_inference( # Autoregressive loop for t in range(1, T): # do one step inference on Duplex STT model - _ = self.stt_model.streaming_inference._step_inference(t, inference_state, ans) + ans = self.stt_model.streaming_inference._step_inference(t, inference_state, ans) + + if return_logits: + _text_logits.append(ans["text_logits"][:, -1].detach()) + if "asr_logits" in ans: + _asr_logits.append(ans["asr_logits"][:, -1].detach()) # do one step inference on Duplex TTS model # current subword id is always seem @@ -609,7 +738,7 @@ def offline_inference( audio_pred = torch.cat([audio_pred, audio_pred_i], dim=1) audio_pred_len += audio_pred_i_len - logging.info(f"Autoregressive inference step: {t} of {T} !") + logging.debug(f"Autoregressive inference step: {t} of {T} !") # Trim back to local length if padded if self._use_fsdp and T > inference_state["T_local"]: @@ -629,6 +758,11 @@ def offline_inference( ans["audio"] = audio_pred.squeeze(1) ans["audio_len"] = audio_pred_len + if return_logits: + ans["text_logits"] = torch.stack(_text_logits, dim=1) + if _asr_logits: + ans["asr_logits"] = torch.stack(_asr_logits, dim=1) + return ans def load_state_dict(self, state_dict, strict: bool = True): diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 879416915cc7..a32ca5740811 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -903,6 
+903,8 @@ def __init__( if self.use_bos_eos_emb: self.bos_eos_emb = BOSEOSEmbedding(tokenizer, self.hidden_size) + self.use_tts_subword_cache = False + def prepare_inputs(self, subword_ids: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]: """ Converts a batch of subword IDs into a padded batch of character IDs. @@ -937,6 +939,10 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te """ Performs the forward pass to get character-aware subword embeddings. + When use_tts_subword_cache is True and the module is in eval mode, a + per-subword-ID cache skips the expensive char encoding + backbone + + pooling path for previously seen tokens. + Args: subword_ids (Tensor): A tensor of subword IDs. Shape: `[batch, seq_len]`. subword_mask (Tensor | None): A boolean mask for padding. Defaults to None. @@ -947,6 +953,19 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te if subword_mask is None: subword_mask = torch.ones_like(subword_ids, dtype=torch.bool) + # Inference cache: return cached embeddings if all valid IDs have been seen + if not self.training and self.use_tts_subword_cache: + if not hasattr(self, '_inference_cache'): + self._inference_cache = {} + valid_ids = torch.masked_select(subword_ids, subword_mask).tolist() + if all(sid in self._inference_cache for sid in valid_ids): + cached = torch.stack([self._inference_cache[sid] for sid in valid_ids]) + out = torch.zeros( + subword_ids.shape + (cached.size(-1),), device=subword_ids.device, dtype=cached.dtype + ) + out[subword_mask] = cached + return out + # 1. 
Convert subword IDs to character IDs char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) @@ -978,6 +997,13 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te if self.use_bos_eos_emb: subword_embeds = self.bos_eos_emb(subword_embeds, subword_ids) + # Cache results for future lookups + if not self.training and self.use_tts_subword_cache: + valid_ids = torch.masked_select(subword_ids, subword_mask).tolist() + valid_embeds = subword_embeds[subword_mask].detach() + for idx, sid in enumerate(valid_ids): + self._inference_cache[sid] = valid_embeds[idx] + return subword_embeds @@ -1146,15 +1172,13 @@ def depthsum_embedding(self, code: Tensor) -> Tensor: ret: [b, t, h] """ b, t, d = code.size() - _, v, h = self.rvq_embs.size() - device = code.device - - ret = torch.zeros((b, t, h), device=device) embs = F.pad(self.rvq_embs, [0, 0, 0, 1]) - for i in range(d): - emb = embs[i] - ret = ret + F.embedding(code[..., i], emb) - return ret + v_padded = embs.shape[1] + offsets = torch.arange(d, device=code.device).view(1, 1, d) * v_padded + flat_indices = (code + offsets).reshape(b * t * d) + flat_embs = embs.reshape(d * v_padded, -1) + gathered = F.embedding(flat_indices, flat_embs).reshape(b, t, d, -1) + return gathered.sum(dim=2) def prepare_training_inputs(self, code: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Prepares masked and dropped-out versions of the code for training.""" @@ -1638,11 +1662,13 @@ def generate_step( num_maskings = torch.ceil(masking_rates * self.config.num_quantizers).long() ks = num_maskings - F.pad(num_maskings[1:], [0, 0, 0, 1]) + ks_list = ks.squeeze(-1).tolist() # 4. 
Iteratively unmask the continuous part of the code cnt = 0 - for i, k in enumerate(ks): - if torch.all(k == 0): + for i, k_val in enumerate(ks_list): + k_val = int(k_val) + if k_val == 0: continue # Prepare input for the MoG head @@ -1652,7 +1678,7 @@ def generate_step( mog_input_embeds = self.embed_code(self.depthsum_embedding(code)) if self.config.random_target_masking: - mog_input_embeds += self.embed_target_mask(cnt + k - 1) + mog_input_embeds += self.embed_target_mask(cnt + k_val - 1) if guidance_scale_i > 0.0: mog_input_embeds = torch.cat( [mog_input_embeds + hidden_states, mog_input_embeds + uncond_hidden_states], 0 @@ -1666,8 +1692,8 @@ def generate_step( top_p_or_k=top_p_or_k_i, ) z = mog_mu + torch.exp(mog_logs) * torch.randn_like(mog_mu) * noise_scale_i - code = depthsum_encoding_step(self.rvq_embs, z, code, cnt, k[0].item()) - cnt += k[0].item() + code = depthsum_encoding_step(self.rvq_embs, z, code, cnt, k_val) + cnt += k_val return code, lm_logits, eos_flag def load_state_dict(self, state_dict, strict: bool = True): diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index bc158ce24c42..626ac25bebdd 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import json from contextlib import contextmanager from pathlib import Path from typing import Dict @@ -42,7 +43,11 @@ def load_pretrained_nemo(cls, model_path_or_name: str): def load_pretrained_hf( - model_path_or_name: str, pretrained_weights: bool = True, dtype=torch.float32, trust_remote_code: bool = False + model_path_or_name: str, + pretrained_weights: bool = True, + dtype=torch.float32, + trust_remote_code: bool = False, + use_meta_device: bool = False, ): """ Load pretrained HuggingFace AutoModelForCausalLM. @@ -55,6 +60,8 @@ def load_pretrained_hf( pretrained_weights: Whether to load pretrained weights (True) or random init (False) dtype: Data type for the model trust_remote_code: Whether to trust remote code when loading model (needed for some models like Nemotron) + use_meta_device: If True, create the model on the meta device (no memory allocation). + The caller must handle materializing meta tensors from a checkpoint. """ if pretrained_weights: return AutoModelForCausalLM.from_pretrained( @@ -62,9 +69,47 @@ def load_pretrained_hf( ) else: config = AutoConfig.from_pretrained(model_path_or_name, trust_remote_code=trust_remote_code) + if use_meta_device: + with torch.device('meta'): + return AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=trust_remote_code) return AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=trust_remote_code) +def resolve_pretrained_config(cfg): + """Resolve pretrained config when pretrained_s2s_model points to an HF checkpoint. + + When the HF checkpoint contains a config.json with perception config, this function: + - Sets pretrained_weights to False (weights will be loaded from pretrained_s2s_model) + - Loads the perception config from the HF checkpoint into cfg + - Resolves the tokenizer path to local llm_artifacts if available + + Args: + cfg: DictConfig with model configuration (modified in-place for perception config). 
+ + Returns: + Tuple of (pretrained_weights, tokenizer_path). + """ + tokenizer_path = cfg.pretrained_llm + pretrained_weights = cfg.pretrained_weights + pretrained_s2s = cfg.get("pretrained_s2s_model", None) + if pretrained_s2s is not None: + hf_config_path = Path(pretrained_s2s) / "config.json" + if hf_config_path.exists(): + with open(hf_config_path) as f: + hf_config = json.load(f) + if "perception" in hf_config: + pretrained_weights = False + with open_dict(cfg): + cfg.perception = hf_config["perception"] + logging.info(f"Loaded perception config from {hf_config_path}, skipping pretrained downloads") + # Use local tokenizer if available + llm_artifacts_dir = Path(pretrained_s2s) / "llm_artifacts" + if llm_artifacts_dir.is_dir(): + tokenizer_path = str(llm_artifacts_dir) + logging.info(f"Using local tokenizer from {llm_artifacts_dir}") + return pretrained_weights, tokenizer_path + + @contextmanager def move_embedding(model): """Temporarily restores the embedding layer into HF LLM. Supports LoRA models.""" diff --git a/nemo/collections/speechlm2/parts/text_utils.py b/nemo/collections/speechlm2/parts/text_utils.py index 8c2a36facf9b..b5805ea84627 100644 --- a/nemo/collections/speechlm2/parts/text_utils.py +++ b/nemo/collections/speechlm2/parts/text_utils.py @@ -16,6 +16,62 @@ from nemo.collections.common.tokenizers import AutoTokenizer +def _decode_tokens_with_specials( + token_strings: list[str], + tokenizer, + pad_token_str: str = '', + keep_pad: bool = False, +) -> str: + """Decode token strings with proper byte-level BPE handling. + + Groups consecutive non-special tokens and decodes each group via + ``tokenizer.tokens_to_text()`` (HF ``convert_tokens_to_string``), which + properly reverses byte-level BPE encoding (e.g. ``âĢĻ`` -> ``'``). + Special tokens (BOS, EOS, PAD) are never passed to + ``convert_tokens_to_string``. BOS/EOS are always kept as literal + strings so that turn boundaries are visible. PAD tokens are kept + only when *keep_pad* is True. 
+ + Args: + token_strings: Raw token strings from ``tokenizer.ids_to_tokens()``. + tokenizer: Tokenizer with ``tokens_to_text``, ``bos_token``, and + ``eos_token`` attributes (NeMo ``AutoTokenizer`` or similar). + pad_token_str: String representation of the pad token. + keep_pad: If True, preserve all special tokens as literal strings + in the output. If False, strip them. + """ + bos = getattr(tokenizer, 'bos_token', None) + eos = getattr(tokenizer, 'eos_token', None) + + # All tokens that must not go through convert_tokens_to_string. + special_tokens = {pad_token_str} + if bos: + special_tokens.add(bos) + if eos: + special_tokens.add(eos) + + result_parts: list[str] = [] + segment: list[str] = [] + + for tok in token_strings: + if tok in special_tokens: + if segment: + result_parts.append(tokenizer.tokens_to_text(segment)) + segment = [] + if tok == pad_token_str: + if keep_pad: + result_parts.append(tok) + else: + result_parts.append(tok) + else: + segment.append(tok) + + if segment: + result_parts.append(tokenizer.tokens_to_text(segment)) + + return ''.join(result_parts) + + def tokens_to_str( tokens: torch.Tensor, lengths: torch.Tensor, @@ -23,9 +79,10 @@ def tokens_to_str( pad_id: int, eval_text_turn_taking: bool = False, show_eot_timestamps: bool = False, + keep_pad: bool = False, ) -> list[str]: """ - Convert token IDs to text strings, filtering out special tokens. + Convert token IDs to text strings with proper byte-level BPE decoding. Args: tokens: Token IDs tensor (B, T) @@ -34,14 +91,18 @@ def tokens_to_str( pad_id: Pad token ID to filter out eval_text_turn_taking: If True, insert timestamps at bos/eos positions show_eot_timestamps: If True, also insert timestamps at end-of-text (first pad after BOS) + keep_pad: If True, preserve all special tokens (including pad) as literal + strings in the output. Useful for "raw" output that shows the full + token stream. If False (default), special tokens are stripped. 
Returns: List of decoded text strings """ + pad_token_str = tokenizer.ids_to_tokens([pad_id])[0] ans = [] - # Helper function to filter special tokens from token IDs - # This filtering is applied regardless of eval_text_turn_taking mode + # Helper function to filter special tokens from token IDs. + # This filtering is applied regardless of eval_text_turn_taking mode. def filter_special_tokens(token_ids): # Filter out pad token_ids = token_ids[token_ids != pad_id] @@ -102,8 +163,9 @@ def filter_special_tokens(token_ids): out_str.append(tokenizer.ids_to_text(remaining_ids)) ans.append(" ".join(out_str)) else: - # For non-turn-taking mode: filter out ALL special tokens, return only pure text hyp_ids = hyp_ids[:hyp_len] - hyp_ids = filter_special_tokens(hyp_ids) - ans.append(tokenizer.ids_to_text(hyp_ids)) + toks = tokenizer.ids_to_tokens(hyp_ids.tolist()) + ans.append( + _decode_tokens_with_specials(toks, tokenizer, pad_token_str=pad_token_str, keep_pad=keep_pad) + ) return ans diff --git a/tests/collections/speechlm2/conftest.py b/tests/collections/speechlm2/conftest.py new file mode 100644 index 000000000000..d37130392370 --- /dev/null +++ b/tests/collections/speechlm2/conftest.py @@ -0,0 +1,280 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared fixtures for speechlm2 tests.""" + +from __future__ import annotations + +import json +import os + +# nemotron_voicechat_pipeline_{parity,nocrash} tests set +# torch.use_deterministic_algorithms(True), which requires CuBLAS to have a +# deterministic workspace. CuBLAS reads this env var only once — at +# initialization (first CUDA matmul in the process) — so it must be set here, +# before any fixture or test triggers CUDA work. The setting is harmless for +# non-deterministic tests: it only reserves 32 KB of extra GPU workspace and +# has no effect unless deterministic mode is active. +os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + +import numpy as np +import pytest +import soundfile as sf +import torch + +from nemo.collections.speechlm2.models import NemotronVoiceChat + +_pretrained_llm = "TinyLlama/TinyLlama_v1.1" +if os.path.exists("/home/TestData/speechlm/pretrained_models"): + _pretrained_llm = "/home/TestData/speechlm/pretrained_models/TinyLlama--TinyLlama_v1.1" + + +def _tiny_voicechat_config(*, predict_user_text: bool = True, streaming_encoder: bool = False) -> dict: + """Return a minimal NemotronVoiceChat config with random weights. + + Args: + predict_user_text: Enable ASR head for user text prediction. + streaming_encoder: When True, configure the conformer encoder for + cache-aware streaming (causal convolutions, chunked_limited + attention) matching the real checkpoint. When False, use + default (non-causal) settings suitable for offline tests. 
+ """ + encoder_cfg: dict = { + "_target_": "nemo.collections.asr.modules.ConformerEncoder", + "feat_in": 80, + "d_model": 512, + "n_heads": 8, + "n_layers": 1, + "subsampling_factor": 8, + } + if streaming_encoder: + encoder_cfg.update({ + "subsampling": "dw_striding", + "causal_downsampling": True, + "att_context_size": [70, 0], + "att_context_style": "chunked_limited", + "conv_kernel_size": 9, + "conv_context_size": "causal", + }) + + return { + "model": { + "scoring_asr": "stt_en_fastconformer_transducer_large", + "stt": { + "model": { + "pretrained_llm": _pretrained_llm, + "pretrained_weights": False, + "predict_user_text": predict_user_text, + "audio_loss_weight": 1, + "text_loss_weight": 3, + "source_sample_rate": 16000, + "validation_save_path": "/tmp/test_duplex_stt_logs", + "perception": { + "_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule", + "preprocessor": { + "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", + "features": 80, + }, + "encoder": encoder_cfg, + "modality_adapter": { + "_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector", + "d_model": 512, + }, + "output_dim": 2048, + }, + "optimizer": {"_target_": "torch.optim.AdamW"}, + }, + "data": {"source_sample_rate": 16000}, + "exp_manager": {"explicit_log_dir": "/tmp/test_duplex_stt_logs"}, + }, + "speech_generation": { + "model": { + "pretrained_lm_name": _pretrained_llm, + "pretrained_ae_dir": None, + "pretrained_tts_model": None, + "scoring_asr": "stt_en_fastconformer_transducer_large", + "freeze_params": [r"^audio_codec\..+$", r"^embed_tokens\..+$"], + "bos_token": "", + "eos_token": "", + "pad_token": "", + "audio_codec_run_dtype": "float32", + "prevent_freeze_params": [], + "audio_save_path": "", + "inference_guidance_scale": 0.5, + "inference_noise_scale": 0.8, + "inference_top_p_or_k": 0.8, + "inference_guidance_enabled": False, + "subword_mask_exactly_as_eartts": False, + 
"context_hidden_mask_exactly_as_eartts": False, + "optimizer": { + "_target_": "torch.optim.AdamW", + "lr": 4e-5, + "betas": [0.9, 0.98], + "weight_decay": 0, + "foreach": True, + }, + "lr_scheduler": { + "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", + "warmup_steps": 2500, + "min_lr": 1e-6, + "max_steps": 100_000_000, + }, + "codec_config": { + "latent_size": 512, + "n_fft": 16, + "hop_length": 4, + "base_hidden_size": 384, + "channel_mult": [1, 2, 4], + "rates": [7, 7, 9], + "num_blocks": 3, + "kernel_size": 7, + "groups": 1, + "codebook_size": 1024, + "num_quantizers": 31, + "wav_to_token_ratio": 1764, + }, + "tts_config": { + "use_gated_fusion_for_text_audio": True, + "disable_eos_prediction": True, + "use_bos_eos_emb": True, + "use_subword_flag_emb": True, + "num_delay_speech_tokens": 2, + "backbone_type": "gemma3_text", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "attention_dropout": 0.1, + "use_cache": False, + }, + "latent_size": 512, + "codebook_size": 1024, + "num_quantizers": 31, + "context_hidden_size": None, + "cas_config": { + "backbone_type": "t5gemma", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "is_encoder_decoder": False, + "encoder": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "use_cache": False, + "attention_dropout": 0.1, + }, + }, + }, + "mog_head_config": { + "intermediate_size": 4608, + "num_layers": 3, + "low_rank": 64, + "num_predictions": 1024, + "min_log_std": -4.0, + "eps": 1e-6, + }, + "p_uncond": 0.1, + "label_smoothing": 0.01, + "max_training_rate": 0.8, + "quantizer_dropout": 0.5, + "random_target_masking": False, + "exponent": 3.0, + }, + }, + "data": { 
+ "add_text_bos_and_eos_in_each_turn": True, + "add_audio_prompt": True, + "audio_prompt_duration": 3.0, + "frame_length": 0.08, + "source_sample_rate": 16000, + "target_sample_rate": 22050, + }, + "exp_manager": {"explicit_log_dir": "/tmp/test_duplex_stt_logs"}, + }, + }, + "data": { + "frame_length": 0.08, + "source_sample_rate": 16000, + "target_sample_rate": 22050, + "input_roles": ["user", "User"], + "output_roles": ["agent", "Assistant", "assistant", "Agent"], + }, + "exp_manager": {"explicit_log_dir": "/tmp/test_parity_logs"}, + } + + +@pytest.fixture(scope="session") +def tiny_model_artifacts(tmp_path_factory): + """Build a tiny NemotronVoiceChat with random weights, write test audio files. + + Session-scoped so the model is built only once across all test modules. + The fixture returns only file paths (immutable), so sharing is safe. + + Returns ``(model_dir, audio_path, speaker_ref_path)``. + """ + base = tmp_path_factory.mktemp("tiny_model") + + audio_path = str(base / "test_audio.wav") + sf.write(audio_path, np.random.RandomState(42).randn(3 * 16000).astype(np.float32), 16000) + + speaker_ref_path = str(base / "speaker_ref.wav") + sf.write(speaker_ref_path, np.random.RandomState(99).randn(22050).astype(np.float32), 22050) + + cfg = _tiny_voicechat_config(streaming_encoder=True) + model = NemotronVoiceChat(cfg) + model.to("cuda") + model.eval() + + model_dir = str(base / "model") + model.save_pretrained(model_dir) + + # save_pretrained writes the tokenizer to llm_artifacts/, but config.json + # still references the HF hub name (e.g. "TinyLlama/TinyLlama_v1.1"). + # Save the LLM model config alongside the tokenizer so llm_artifacts/ + # is a complete local model reference, then rewrite config.json to point + # at it. This avoids HuggingFace network requests on every from_pretrained. 
+ llm_artifacts = os.path.join(model_dir, "llm_artifacts") + model.stt_model.llm.config.save_pretrained(llm_artifacts) + cfg["model"]["stt"]["model"]["pretrained_llm"] = llm_artifacts + cfg["model"]["speech_generation"]["model"]["pretrained_lm_name"] = llm_artifacts + with open(os.path.join(model_dir, "config.json"), "w") as f: + json.dump(cfg, f) + + del model + torch.cuda.empty_cache() + + return model_dir, audio_path, speaker_ref_path + + +@pytest.fixture(scope="session") +def tiny_voicechat_model(): + """Build a tiny NemotronVoiceChat model (predict_user_text=False). + + Used by ``test_nemotron_voicechat.py`` for validation and offline + generation tests. + """ + cfg = _tiny_voicechat_config(predict_user_text=False) + model = NemotronVoiceChat(cfg) + if torch.cuda.is_available(): + model.to("cuda") + return model diff --git a/tests/collections/speechlm2/test_nemotron_voicechat.py b/tests/collections/speechlm2/test_nemotron_voicechat.py index 4f2b431e4f25..5afd176f08fc 100644 --- a/tests/collections/speechlm2/test_nemotron_voicechat.py +++ b/tests/collections/speechlm2/test_nemotron_voicechat.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os - import pytest import torch from lhotse import CutSet, SupervisionSegment @@ -21,218 +19,11 @@ from nemo.collections.common.data.utils import move_data_to_device from nemo.collections.speechlm2 import DuplexSTTDataset -from nemo.collections.speechlm2.models import NemotronVoiceChat - -if torch.cuda.is_available(): - torch.set_default_device('cuda') - - -pretrained_llm = "TinyLlama/TinyLlama_v1.1" -if os.path.exists("/home/TestData/speechlm/pretrained_models"): - pretrained_llm = "/home/TestData/speechlm/pretrained_models/TinyLlama--TinyLlama_v1.1" - -# STT sampling rate -source_sample_rate = 16000 -# TTS sampling rate -target_sample_rate = 22050 - - -def create_model( - predict_user_text=False, - force_use_noise_augmentation=False, - old_noise_prob=0.0, - old_noise_min_snr=0.0, - old_noise_max_snr=0.0, -): - """Helper function to create a model with configurable settings.""" - test_stt_cfg = { - "model": { - "pretrained_llm": pretrained_llm, - "pretrained_weights": False, - "audio_loss_weight": 1, - "text_loss_weight": 3, - "source_sample_rate": source_sample_rate, - "validation_save_path": "/tmp/test_duplex_stt_logs", - "perception": { - "_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule", - "preprocessor": { - "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", - "features": 80, - }, - "encoder": { - "_target_": "nemo.collections.asr.modules.ConformerEncoder", - "feat_in": 80, - "d_model": 512, - "n_heads": 8, - "n_layers": 1, - "subsampling_factor": 8, - }, - "modality_adapter": { - "_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector", - "d_model": 512, - }, - "output_dim": 2048, - }, - "predict_user_text": predict_user_text, - "force_use_noise_augmentation": force_use_noise_augmentation, - "old_noise_prob": old_noise_prob, - "old_noise_min_snr": old_noise_min_snr, - "old_noise_max_snr": old_noise_max_snr, - "optimizer": {"_target_": "torch.optim.AdamW"}, - }, - "data": 
{ - "source_sample_rate": 16000, - }, - "exp_manager": { - "explicit_log_dir": "/tmp/test_duplex_stt_logs", - }, - } - - test_tts_config = { - "model": { - "pretrained_lm_name": pretrained_llm, - "pretrained_ae_dir": None, - "pretrained_tts_model": None, - "scoring_asr": "stt_en_fastconformer_transducer_large", - "freeze_params": [ - r"^audio_codec\..+$", # Keep audio codec frozen as it only provides supervision for training. - r"^embed_tokens\..+$", # Keep embed_tokens frozen as done in eartts - ], - "bos_token": "", - "eos_token": "", - "pad_token": "", - "audio_codec_run_dtype": "float32", - "prevent_freeze_params": [], - "audio_save_path": "", - "inference_guidance_scale": 0.5, - "inference_noise_scale": 0.8, - "inference_top_p_or_k": 0.8, - "inference_guidance_enabled": False, - "subword_mask_exactly_as_eartts": False, - "context_hidden_mask_exactly_as_eartts": False, - "optimizer": { - "_target_": "torch.optim.AdamW", - "lr": 4e-5, - "betas": [0.9, 0.98], - "weight_decay": 0, - "foreach": True, - }, - "lr_scheduler": { - "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", - "warmup_steps": 2500, - "min_lr": 1e-6, - "max_steps": 100_000_000, - }, - "codec_config": { - "latent_size": 512, - "n_fft": 16, - "hop_length": 4, - "base_hidden_size": 384, - "channel_mult": [1, 2, 4], - "rates": [7, 7, 9], - "num_blocks": 3, - "kernel_size": 7, - "groups": 1, - "codebook_size": 1024, - "num_quantizers": 31, - "wav_to_token_ratio": 1764, - }, - "tts_config": { - "use_gated_fusion_for_text_audio": True, - "disable_eos_prediction": True, - "use_bos_eos_emb": True, - "use_subword_flag_emb": True, - "num_delay_speech_tokens": 2, - "backbone_type": "gemma3_text", - "backbone_model_class": None, - "backbone_config_class": None, - "backbone_config": { - "hidden_size": 1152, - "intermediate_size": 4608, - "num_hidden_layers": 1, - "num_attention_heads": 16, - "num_key_value_heads": 16, - "head_dim": 72, - "attention_dropout": 0.1, - "use_cache": False, - }, - 
"latent_size": 512, - "codebook_size": 1024, - "num_quantizers": 31, - "context_hidden_size": None, - "cas_config": { - "backbone_type": "t5gemma", - "backbone_model_class": None, - "backbone_config_class": None, - "backbone_config": { - "is_encoder_decoder": False, - "encoder": { - "hidden_size": 1152, - "intermediate_size": 4608, - "num_hidden_layers": 1, - "num_attention_heads": 16, - "num_key_value_heads": 16, - "head_dim": 72, - "use_cache": False, - "attention_dropout": 0.1, - }, - }, - }, - "mog_head_config": { - "intermediate_size": 4608, - "num_layers": 3, - "low_rank": 64, - "num_predictions": 1024, - "min_log_std": -4.0, - "eps": 1e-6, - }, - "p_uncond": 0.1, - "label_smoothing": 0.01, - "max_training_rate": 0.8, - "quantizer_dropout": 0.5, - "random_target_masking": False, - "exponent": 3.0, - }, - }, - "data": { - "add_text_bos_and_eos_in_each_turn": True, - "add_audio_prompt": True, - "audio_prompt_duration": 3.0, - "frame_length": 0.08, - "source_sample_rate": source_sample_rate, - "target_sample_rate": target_sample_rate, - }, - "exp_manager": { - "explicit_log_dir": "/tmp/test_duplex_stt_logs", - }, - } - - test_config = { - "model": { - "scoring_asr": "stt_en_fastconformer_transducer_large", - "stt": test_stt_cfg, - "speech_generation": test_tts_config, - }, - "data": { - "frame_length": 0.08, - "source_sample_rate": source_sample_rate, - "target_sample_rate": target_sample_rate, - "input_roles": ["user", "User"], - "output_roles": ["agent", "Assistant", "assistant", "Agent"], - }, - "exp_manager": { - "explicit_log_dir": "/tmp/test_nemotron_voicechat_logs", - }, - } - model = NemotronVoiceChat(test_config) - if torch.cuda.is_available(): - model.to("cuda") - return model @pytest.fixture(scope="session") -def model(): - return create_model(predict_user_text=False) +def model(tiny_voicechat_model): + return tiny_voicechat_model @pytest.fixture(scope="session") @@ -240,7 +31,7 @@ def dataset(model): return DuplexSTTDataset( 
model.stt_model.tokenizer, frame_length=0.08, - source_sample_rate=source_sample_rate, + source_sample_rate=16000, input_roles=["user"], output_roles=["assistant"], ) diff --git a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py new file mode 100644 index 000000000000..8d9a7d1db068 --- /dev/null +++ b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""No-crash pipeline tests for NemotronVoiceChat streaming inference. + +Exercises ``StreamingS2SPipeline.run()`` with a tiny random-weight model +under various config combinations. Each test verifies only that the +pipeline completes without raising — no output quality checks. 
+ +Run from the NeMo repo root:: + + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_nemotron_voicechat_pipeline_nocrash.py -v -s +""" + +from __future__ import annotations + +import os +import tempfile +from typing import Any + +import pytest +import torch +from omegaconf import OmegaConf + +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline +from nemo.collections.speechlm2.inference.utils.stepprogressbar import StepProgressBar + +_CONF_YAML = os.path.join( + os.path.dirname(__file__), + "../../../examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml", +) +_MOCK_SYSTEM_PROMPT = "This is a mock prompt for the test" + +# Safe defaults so tests run quickly with the tiny model. +# Individual tests override specific keys via OmegaConf.merge. +_TEST_DEFAULTS = { + "s2s": { + "engine_type": "native", + "compute_dtype": "float32", + "deterministic": False, + "decode_audio": False, + "use_perception_cache": False, + "use_perception_cudagraph": False, + "use_llm_cache": False, + "system_prompt": None, + "top_p": 1.0, + "repetition_penalty": 1.0, + "temperature": 1.0, + }, + "streaming": { + "chunk_size_in_secs": 0.08, + "buffer_size_in_secs": 71 * 0.08, + }, +} + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_no_crash_pipeline( + model_path: str, + audio_path: str, + output_dir: str, + *overrides: dict[str, Any], +) -> StreamingS2SPipeline: + """Build a :class:`StreamingS2SPipeline` for no-crash testing. + + Loads the YAML base config, applies ``_TEST_DEFAULTS``, then merges + each dict in *overrides* on top (in order). 
Nested dicts are + recursively merged by OmegaConf, so ``{"s2s": {"top_p": 0.9}}`` + overrides only that key while keeping all other s2s defaults. + """ + cfg = OmegaConf.load(_CONF_YAML) + cfg = OmegaConf.merge( + cfg, + _TEST_DEFAULTS, + {"audio_file": audio_path, "output_dir": output_dir, "s2s": {"model_path": model_path}}, + ) + for overrides in overrides: + if overrides: + cfg = OmegaConf.merge(cfg, overrides) + return S2SPipelineBuilder.build_pipeline(cfg) + + +# --------------------------------------------------------------------------- +# Parametrized configs — each entry is a single overrides dict +# --------------------------------------------------------------------------- + +# Text-only configs (decode_audio=False): minimal STT-path smoke checks. +_TEXT_CONFIGS = [ + pytest.param({}, id="baseline"), + pytest.param( + {"s2s": {"use_llm_cache": True, "use_perception_cache": True}}, + id="both_caches", + ), + pytest.param({"pad_audio_by_sec": 2}, id="pad_by_sec"), +] + +# Audio configs (decode_audio=True): exercises the full STT + TTS pipeline. 
+_AUDIO_CONFIGS = [
+    pytest.param({}, id="baseline"),
+    pytest.param(
+        {"s2s": {"use_llm_cache": True, "use_perception_cache": True, "system_prompt": _MOCK_SYSTEM_PROMPT},
+         "streaming": {"chunk_size_in_secs": 0.24},
+         "pad_audio_to_sec": 5},
+        id="both_caches_prompt_multiframe_pad_to_sec",
+    ),
+    pytest.param(
+        {"s2s": {"use_llm_cache": True, "top_p": 0.9, "temperature": 0.7, "repetition_penalty": 1.1},
+         "pad_silence_ratio": 0.5},
+        id="sampling_pad_silence_ratio",
+    ),
+    pytest.param(
+        {"s2s": {"use_tts_subword_cache": True, "use_tts_torch_compile": True},
+         "pad_audio_by_sec": 2},
+        id="tts_optimizations_pad_by_sec",
+    ),
+    pytest.param(
+        {"s2s": {"deterministic": True, "temperature": 0.0}},
+        id="deterministic",
+    ),
+    pytest.param(
+        {"s2s": {"profile_timing": True}},
+        id="profile_timing",
+    ),
+]
+
+# ---------------------------------------------------------------------------
+# Tests (tiny_model_artifacts fixture is provided by conftest.py)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
+@pytest.mark.parametrize("overrides", _TEXT_CONFIGS)
+def test_pipeline_no_crash_tiny_model(tiny_model_artifacts, overrides):
+    """Run the streaming pipeline with various configs and verify it doesn't crash."""
+    model_dir, audio_path, _ = tiny_model_artifacts
+    output_dir = tempfile.mkdtemp(prefix="no-crash-text-")  # NOTE(review): never removed — presumably kept for inspection; confirm
+
+    pipeline = _build_no_crash_pipeline(model_dir, audio_path, output_dir, overrides)
+    progress_bar = StepProgressBar.from_audio_filepaths(  # mirrors the pipeline's own padding settings so step counts line up
+        [audio_path],
+        chunk_size_in_secs=pipeline.chunk_size_in_secs,
+        pad_audio_to_sec=pipeline.pad_audio_to_sec,
+        pad_silence_ratio=pipeline.pad_silence_ratio,
+        pad_audio_by_sec=pipeline.pad_audio_by_sec,
+    )
+    result = pipeline.run([audio_path], progress_bar=progress_bar)
+    assert result is not None  # no-crash contract only: any exception or empty result fails
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
+@pytest.mark.parametrize("overrides", _AUDIO_CONFIGS)
+def test_pipeline_no_crash_tiny_model_decode_audio(tiny_model_artifacts, overrides):
+    """Run the streaming pipeline with decode_audio=True and verify it doesn't crash."""
+    model_dir, audio_path, speaker_ref_path = tiny_model_artifacts
+    output_dir = tempfile.mkdtemp(prefix="no-crash-audio-")  # NOTE(review): never removed — presumably kept for inspection; confirm
+
+    pipeline = _build_no_crash_pipeline(
+        model_dir, audio_path, output_dir,
+        {"s2s": {"decode_audio": True, "speaker_reference": speaker_ref_path}},  # base audio config; parametrized overrides merge on top
+        overrides,
+    )
+    progress_bar = StepProgressBar.from_audio_filepaths(
+        [audio_path],
+        chunk_size_in_secs=pipeline.chunk_size_in_secs,
+        pad_audio_to_sec=pipeline.pad_audio_to_sec,
+        pad_silence_ratio=pipeline.pad_silence_ratio,
+        pad_audio_by_sec=pipeline.pad_audio_by_sec,
+    )
+    result = pipeline.run([audio_path], progress_bar=progress_bar)
+    assert result is not None
diff --git a/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py
new file mode 100644
index 000000000000..217af74f5a01
--- /dev/null
+++ b/tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py
@@ -0,0 +1,343 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Offline vs. incremental inference parity tests for NemotronVoiceChat.
+ +``test_parity_tiny_model`` checks offline-vs-incremental parity using a +tiny model with random weights (no checkpoint needed, requires only a GPU). + +``test_parity_real_checkpoint`` does the same on a real exported checkpoint +and is skipped unless ``PARITY_CHECKPOINT_PATH`` is set:: + + PARITY_CHECKPOINT_PATH=/path/to/exported/checkpoint + PARITY_AUDIO_PATH=/path/to/test.wav # optional, defaults to force_align_test.mp3 + PARITY_SPEAKER_NAME= # optional + +Run from the NeMo repo root (use ``-s`` to see live progress):: + + # unit tests only + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py -v -s + + # include integration test + PARITY_CHECKPOINT_PATH=... \\ + CUDA_VISIBLE_DEVICES=0 pytest tests/collections/speechlm2/test_nemotron_voicechat_pipeline_parity.py -v -s +""" + +from __future__ import annotations + +import math +import os +import tempfile +import time +from typing import Any + +import pytest +import torch +from omegaconf import OmegaConf + +from nemo.utils import logging + +from nemo.collections.speechlm2.inference.factory.s2s_pipeline_builder import S2SPipelineBuilder +from nemo.collections.speechlm2.inference.model_wrappers.nemotron_voicechat_inference_wrapper import ( + FRAME_SIZE_SAMPLES, + SAMPLE_RATE, +) +from nemo.collections.speechlm2.inference.pipelines.streaming_s2s_pipeline import StreamingS2SPipeline +from nemo.collections.speechlm2.inference.streaming.framing.s2s_request_options import S2SRequestOptions + +_CONF_YAML = os.path.join( + os.path.dirname(__file__), + "../../../examples/speechlm2/nemo_inference_pipelines/conf/s2s_streaming.yaml", +) +_FORCE_ALIGN_AUDIO = os.path.join( + os.path.dirname(__file__), + "test_data", + "force_align_test.mp3", +) +_MOCK_SYSTEM_PROMPT = "This is a mock prompt for the test" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + 
+def _compare_tensors( + a: torch.Tensor | None, + b: torch.Tensor | None, +) -> dict[str, Any]: + """Prefix-aware comparison of two tensors (tokens or logits, any shape).""" + if a is None or b is None: + return {"match": None, "note": "one or both tensors missing"} + a, b = a.detach().cpu().float(), b.detach().cpu().float() + T = min(a.shape[1], b.shape[1]) + if T == 0: + return {"prefix_len": 0, "match": True} + ap, bp = a[:, :T], b[:, :T] + diff = (ap - bp).abs() + match = bool(diff.max() == 0) + result: dict[str, Any] = {"prefix_len": T, "match": match, "max_abs_diff": float(diff.max())} + if not match: + reduce = tuple(i for i in range(diff.dim()) if i != 1) + per_step = diff.amax(dim=reduce) if reduce else diff.squeeze() + nonzero = (per_step > 0).nonzero(as_tuple=False) + if nonzero.numel(): + result["first_diff_step"] = int(nonzero[0].item()) + return result + + +def _merge_incremental_debug_steps(steps: list[dict[str, Any]]) -> dict[str, Any]: + """Merge per-step debug dicts from the pipeline into a single dict.""" + if not steps: + return {} + all_text_logits = [s["text_logits"] for s in steps if s.get("text_logits") is not None] + all_asr_logits = [s["asr_logits"] for s in steps if s.get("asr_logits") is not None] + return { + "text_logits": torch.cat(all_text_logits, dim=1) if all_text_logits else None, + "asr_logits": torch.cat(all_asr_logits, dim=1) if all_asr_logits else None, + } + + +def _load_and_pad_audio( + audio_path: str, device: torch.device, dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Load audio, zero-pad to a whole number of 80 ms frames, return ``(audio, lens)``.""" + import librosa + + audio_np, _ = librosa.load(audio_path, sr=SAMPLE_RATE) + padded_len = math.ceil(len(audio_np) / FRAME_SIZE_SAMPLES) * FRAME_SIZE_SAMPLES + audio = torch.nn.functional.pad( + torch.tensor(audio_np, device=device, dtype=dtype).unsqueeze(0), + (0, max(0, padded_len - len(audio_np))), + ) + return audio, torch.tensor([audio.shape[1]], 
device=device, dtype=torch.long) + + +def run_parity_check( + pipeline: StreamingS2SPipeline, + audio_path: str, + *, + system_prompt: str | None = None, +) -> dict[str, Any]: + """Run offline and incremental inference on the same audio, return comparison. + + Only STT-level tokens and logits are compared; TTS is irrelevant for + the core parity invariant. + """ + wrapper = pipeline.s2s_model + audio, audio_lens = _load_and_pad_audio(audio_path, wrapper.device, wrapper.dtype) + + prompt_tokens = prompt_token_lens = None + if system_prompt: + tok = wrapper.tokenizer + ids = [tok.bos_id] + tok.text_to_ids(system_prompt) + [tok.eos_id] + prompt_tokens = torch.tensor(ids, device=wrapper.device, dtype=torch.long).unsqueeze(0) + prompt_token_lens = torch.tensor([len(ids)], device=wrapper.device, dtype=torch.long) + + if wrapper.speaker_name is not None: + OmegaConf.update(wrapper.model.cfg, "inference_speaker_name", wrapper.speaker_name, force_add=True) + elif wrapper.speaker_reference: + OmegaConf.update(wrapper.model.cfg, "inference_speaker_reference", wrapper.speaker_reference, force_add=True) + speaker_kw: dict[str, Any] = {} + if not wrapper.model.cfg.get("inference_speaker_name") and not wrapper.model.cfg.get("inference_speaker_reference"): + speaker_kw["speaker_audio"] = torch.randn(1, 22050, device=wrapper.device) + speaker_kw["speaker_audio_lens"] = torch.tensor([22050], device=wrapper.device, dtype=torch.long) + + # -- Offline -- + logging.info("Running offline_inference ...") + t0 = time.time() + offline = wrapper.model.offline_inference( + input_signal=audio, + input_signal_lens=audio_lens, + prompt_tokens=prompt_tokens, + prompt_token_lens=prompt_token_lens, + decode_audio=False, + return_logits=True, + **speaker_kw, + ) + logging.info(f" offline done in {time.time() - t0:.2f}s") + + # -- Incremental -- + logging.info("Running incremental inference (pipeline.run) ...") + t0 = time.time() + pipeline.collect_debug = True + pipeline_output = pipeline.run( + 
[audio_path], + options=[S2SRequestOptions(system_prompt=system_prompt)], + ) + logging.info(f" incremental done in {time.time() - t0:.2f}s") + + inc_tokens = pipeline_output.token_texts[0] if pipeline_output.token_texts else None + inc_asr_tokens = pipeline_output.token_asr_texts[0] if pipeline_output.token_asr_texts else None + inc_debug = _merge_incremental_debug_steps( + pipeline_output.debug_data[0] if pipeline_output.debug_data and pipeline_output.debug_data[0] else [] + ) + + # -- Compare -- + # offline_inference returns logits for ALL positions (including prompt), + # while the incremental path only produces logits for audio positions. + # Trim the prompt prefix from offline logits so the two are aligned. + prompt_len = prompt_tokens.shape[1] if prompt_tokens is not None else 0 + + report: dict[str, Any] = { + "token_comparison": _compare_tensors(offline.get("tokens_text"), inc_tokens), + "asr_token_comparison": _compare_tensors(offline.get("tokens_text_src"), inc_asr_tokens), + } + for key, off_key, inc_key in [ + ("text_logit_comparison", "text_logits", "text_logits"), + ("asr_logit_comparison", "asr_logits", "asr_logits"), + ]: + off_t, inc_t = offline.get(off_key), inc_debug.get(inc_key) + if off_t is not None and prompt_len > 0: + off_t = off_t[:, prompt_len:] + if off_t is not None and inc_t is not None: + report[key] = _compare_tensors(off_t, inc_t) + + return report + + +def assert_parity( + report: dict[str, Any], + *, + strict: bool = True, + atol: float = 0.0, +) -> None: + """Raise ``AssertionError`` if parity checks in *report* fail.""" + failures: list[str] = [] + for key in ("token_comparison", "asr_token_comparison"): + c = report.get(key, {}) + if c.get("match") is False: + failures.append(f"{key}: diverge at step {c.get('first_diff_step')}") + if strict: + for key in ("text_logit_comparison", "asr_logit_comparison"): + c = report.get(key, {}) + if c.get("match") is False and c.get("max_abs_diff", 0) > atol: + failures.append(f"{key}: 
max_abs_diff={c['max_abs_diff']:.2e} > atol={atol:.2e}") + assert not failures, "Parity failed:\n " + "\n ".join(failures) + + +# Parity requires deterministic, float32, no caches, greedy decoding. +_PARITY_DEFAULTS = { + "s2s": { + "engine_type": "native", + "compute_dtype": "float32", + "deterministic": True, + "decode_audio": False, + "use_perception_cache": False, + "use_perception_cudagraph": False, + "use_llm_cache": False, + "top_p": 1.0, + "repetition_penalty": 1.0, + "temperature": 1.0, + }, +} + + +def _build_parity_pipeline( + model_path: str, + audio_path: str, + output_dir: str, + *overrides: dict[str, Any], +) -> StreamingS2SPipeline: + """Build a :class:`StreamingS2SPipeline` configured for strict parity testing. + + Loads ``s2s_streaming.yaml`` as the base config, applies + ``_PARITY_DEFAULTS``, then merges each dict in *overrides* + on top (in order). + + The chunk size is set to cover the full audio in one step so that + offline and incremental paths see identical input. 
+ """ + import librosa + + audio_np, _ = librosa.load(audio_path, sr=SAMPLE_RATE) + total_frames = math.ceil(len(audio_np) / FRAME_SIZE_SAMPLES) + chunk_secs = total_frames * FRAME_SIZE_SAMPLES / SAMPLE_RATE + + cfg = OmegaConf.load(_CONF_YAML) + cfg = OmegaConf.merge( + cfg, + _PARITY_DEFAULTS, + { + "audio_file": audio_path, + "output_dir": output_dir, + "s2s": {"model_path": model_path}, + "streaming": { + "chunk_size_in_secs": chunk_secs, + "buffer_size_in_secs": max(71 * 0.08, chunk_secs), + }, + }, + ) + for overrides in overrides: + if overrides: + cfg = OmegaConf.merge(cfg, overrides) + return S2SPipelineBuilder.build_pipeline(cfg) + + +# --------------------------------------------------------------------------- +# Parity test -- tiny model (uses tiny_model_artifacts from conftest.py) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_parity_tiny_model(tiny_model_artifacts): + """Offline/incremental parity with a tiny random-weight model. + + Loads the model through the real ``S2SPipelineBuilder`` so the test + exercises the same code path as ``test_parity_real_checkpoint``. 
+ """ + model_dir, audio_path, _ = tiny_model_artifacts + output_dir = tempfile.mkdtemp(prefix="parity-tiny-") + + pipeline = _build_parity_pipeline( + model_dir, audio_path, output_dir, + {"s2s": {"system_prompt": _MOCK_SYSTEM_PROMPT}}, + ) + report = run_parity_check(pipeline, audio_path, system_prompt=_MOCK_SYSTEM_PROMPT) + assert_parity(report, strict=True, atol=0.0) + + +# --------------------------------------------------------------------------- +# Integration test -- real checkpoint (skipped when env vars are not set) +# --------------------------------------------------------------------------- + + +def _real_checkpoint_available() -> bool: + path = os.environ.get("PARITY_CHECKPOINT_PATH", "") + return bool(path) and os.path.isdir(path) + + +@pytest.mark.skipif(not _real_checkpoint_available(), reason="set PARITY_CHECKPOINT_PATH") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_parity_real_checkpoint(): + """Parity check using a real exported checkpoint. + + Configure via environment variables:: + + PARITY_CHECKPOINT_PATH=/path/to/exported/checkpoint + PARITY_AUDIO_PATH=/path/to/test.wav # optional, defaults to force_align_test.mp3 + PARITY_SPEAKER_NAME= # optional + """ + ckpt = os.environ["PARITY_CHECKPOINT_PATH"] + audio = os.environ.get("PARITY_AUDIO_PATH") or _FORCE_ALIGN_AUDIO + speaker = os.environ.get("PARITY_SPEAKER_NAME") + + parity_overrides: dict[str, Any] = {"s2s": {"system_prompt": _MOCK_SYSTEM_PROMPT}} + if speaker: + parity_overrides["s2s"]["speaker_name"] = speaker + pipeline = _build_parity_pipeline( + ckpt, audio, tempfile.mkdtemp(prefix="parity-"), + parity_overrides, + ) + report = run_parity_check(pipeline, audio, system_prompt=_MOCK_SYSTEM_PROMPT) + assert_parity(report, strict=True, atol=0.0)