diff --git a/examples/speechlm2/conf/duplex_eartts.yaml b/examples/speechlm2/conf/duplex_eartts.yaml index a2765fa36364..2aa377461f2e 100644 --- a/examples/speechlm2/conf/duplex_eartts.yaml +++ b/examples/speechlm2/conf/duplex_eartts.yaml @@ -127,7 +127,7 @@ trainer: strategy: _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true - find_unused_parameters: true + find_unused_parameters: false data: # data loader configs diff --git a/examples/speechlm2/duplex_eartts_eval.py b/examples/speechlm2/duplex_eartts_eval.py index 582a9efd1024..eb6245b7a88b 100644 --- a/examples/speechlm2/duplex_eartts_eval.py +++ b/examples/speechlm2/duplex_eartts_eval.py @@ -83,24 +83,77 @@ out_dir (str): Directory where generated audio samples will be saved. + inference_dtype (str, optional): + Target dtype used during inference. This controls the precision + of model weights and operations. + + Supported values: + - "float32" (default) + - "float16" + - "bfloat16" + + Notes: + - If set to a lower precision (e.g., float16), the model weights + and/or execution dtype will be adjusted accordingly. + - Internally mapped via `getattr(torch, inference_dtype)`. + + keep_codec_original_dtype (bool, optional): + Controls whether the audio codec module keeps its original dtype + when `inference_dtype` is not float32. + + If True (default): + - Only the TTS backbone (`model.tts_model`) is cast to the target dtype. + - The codec remains in its original precision (typically float32). + - Useful to isolate precision effects and avoid degradation from + codec quantization. + + If False: + - The entire model (including codec) is cast to `inference_dtype`. + - `model.audio_codec_run_dtype` is also set accordingly. + + debug_dtype (bool, optional): + Enables runtime inspection of tensor dtypes flowing through the model. + + If True: + - Forward hooks are attached to all leaf modules. + - During the first batch, dtype usage statistics are collected + and logged. + - Outputs include: + - Per-module-group dtype distribution + - Example module names per dtype + Usage: - python duplex_eartts_eval.py \ - --config-path=conf/ \ - --config-name=duplex_eartts.yaml \ - ++checkpoint_path=duplex_eartts_results/duplex_eartts/model.ckpt \ - ++datasets_json_path=/path/to/evalset_config.jsonl \ - ++out_dir=duplex_eartts_results/duplex_eartts/audio_samples/dummy_dataset + # Example with fp32 inference + python duplex_eartts_eval.py \ + --config-path=conf/ \ + --config-name=duplex_eartts.yaml \ + ++checkpoint_path=duplex_eartts_results/duplex_eartts/model.ckpt \ + ++datasets_json_path=/path/to/evalset_config.jsonl \ + ++out_dir=duplex_eartts_results/duplex_eartts/audio_samples/dummy_dataset + + # Example with fp16 inference and dtype debugging + python duplex_eartts_eval.py \ + --config-path=conf/ \ + --config-name=duplex_eartts.yaml \ + ++checkpoint_path=duplex_eartts_results/duplex_eartts/model.ckpt \ + ++datasets_json_path=/path/to/evalset_config.jsonl \ + ++out_dir=uplex_eartts_results/duplex_eartts/audio_samples/dummy_dataset \ + ++inference_dtype=float16 \ + ++keep_codec_original_dtype=True \ + ++debug_dtype=True """ import json import os +from functools import partial import librosa import soundfile as sf import torch from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset -from nemo.collections.audio.parts.utils.resampling import resample +from nemo.collections.audio.parts.utils.transforms import resample torch.set_float32_matmul_precision("medium") torch.backends.cudnn.allow_tf32 = True @@ -111,54 +164,190 @@ from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.core.config import hydra_runner +from nemo.utils import logging -torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) +# Use .get() to avoid crashing when running a single GPU without torchrun +if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) -def read_jsonl_batches( - file_path, - batch_size, - drop_last=False, - max_batches=None, # <-- DEBUG OPTION -): +def attach_dtype_counter(model): + """ + Attaches forward hooks to all leaf modules of a model to track the dtype + of their outputs during inference. + + This utility is designed for debugging precision behavior, especially when + using mixed precision or reduced precision (fp16 / bf16). + + Behavior: + - Registers a forward hook on each leaf module (modules with no children). + - For each forward pass, records the dtype of the module output. + - Aggregates statistics grouped by top-level module name. + - Stores a few example module class names per dtype. + + Returns: + handles (List[RemovableHandle]): + List of hook handles. These must be removed manually to avoid + memory leaks or performance degradation. + + stats (Dict[str, Dict[str, int]]): + Nested dictionary containing dtype counts per module group. + Structure: + stats[module_group][dtype] = count + + Example: + { + "tts_model": { + "torch.float16": 120, + "torch.float32": 0, + "torch.bfloat16": 0, + "other": 2 + } + } + + examples (Dict[str, Dict[str, List[str]]]): + Stores up to 3 example module class names per dtype per group. + Useful for quickly identifying which layers are running in + unexpected precision. + + Notes: + - Only inspects outputs (not inputs or parameters). + - Dtype is inferred from the first tensor found in the output. + - Non-floating dtypes are categorized as "other". + - Grouping is based on the top-level module name (prefix before first dot). + + Typical usage: + handles, stats, examples = attach_dtype_counter(model) + + # Run inference ... + + for h in handles: + h.remove() """ - Reads a JSONL file and yields batches of size batch_size. + handles = [] + + # structure: stats[module_group][dtype] = count + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + elif isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + # top-level module (before first dot) + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + + group = get_module_group(name) + + if group not in stats: + stats[group] = { + "torch.float16": 0, + "torch.bfloat16": 0, + "torch.float32": 0, + "other": 0, + } + examples[group] = { + "torch.float16": [], + "torch.bfloat16": [], + "torch.float32": [], + "other": [], + } + + stats[group][dtype] += 1 + + # store a few examples per dtype per group + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples): + """ + Cleans up monitoring hooks and logs a detailed report of the tensor precisions + (dtypes) observed during the model forward pass. - Args: - file_path (str): Path to the JSONL file - batch_size (int): Number of samples per batch - drop_last (bool): If True, drop the last incomplete batch - max_batches (int or None): If set, only yield this many batches (debug mode) + This function should be called after at least one inference iteration has + completed while hooks are attached. It removes the hooks to prevent + performance overhead and prints a structured summary of which module groups + executed in which dtypes. - Yields: - List[dict]: A batch of samples + Args: + handles (List[torch.utils.hooks.RemovableHandle]): The list of hooks + returned by `attach_dtype_counter`. + stats (Dict): Nested dictionary containing dtype counts per module group. + examples (Dict): Dictionary containing example module names for each + observed dtype. """ - batch = [] - num_batches = 0 + for h in handles: + h.remove() + + logging.info("\n=== DTYPE USAGE PER MODULE ===") + + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: + continue + + logging.info(f"\n--- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") + + logging.info("\n=== EXAMPLES ===") + for group, group_examples in examples.items(): + logging.info(f"\n--- {group} ---") + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"{dtype}: {mods}") - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - sample = json.loads(line) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") +class EvalJSONLDataset(Dataset): + """ + Standard PyTorch Dataset for reading JSONL evaluation files. + """ - batch.append(sample) + def __init__(self, file_path): + self.samples = [] + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + self.samples.append(json.loads(line)) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - if len(batch) == batch_size: - yield batch - batch = [] - num_batches += 1 - if max_batches is not None and num_batches >= max_batches: - return + def __len__(self): + return len(self.samples) - if batch and not drop_last: - yield batch + def __getitem__(self, idx): + return self.samples[idx] def collate_and_tokenize_custom( @@ -195,6 +384,7 @@ def collate_and_tokenize_custom( # Construct: text + 4x pads # We extend the list with the tokens and then the pad tokens pad_ids = [model.text_pad_id] * pad_len + if force_interruption: fname = s["audio_filepath"] no_ext = fname.split(".")[0] @@ -229,22 +419,17 @@ def collate_and_tokenize_custom( full_ids.extend(seg_ids) full_ids.extend(pad_ids) - # Convert to tensor - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long, device=model.device)) + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) else: # Standard String Handling tokenized_list.append( - torch.as_tensor( - [model.tokenizer.bos] + model.tokenizer.text_to_ids(text_data), - dtype=torch.long, - device=model.device, - ) + torch.as_tensor([model.tokenizer.bos] + model.tokenizer.text_to_ids(text_data), dtype=torch.long) ) if add_beginning_pad_tokens: pad_len = 25 - prefix = torch.full((pad_len,), model.text_pad_id, dtype=torch.long, device=model.device) + prefix = torch.full((pad_len,), model.text_pad_id, dtype=torch.long) for i in range(len(tokenized_list)): tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) @@ -257,7 +442,7 @@ def collate_and_tokenize_custom( target_num_frames = [] for i, s in enumerate(batch): - # 1. Load Context Audio (Conditioning) + # Load Context Audio audio_path = s["context_audio_filepath"] if root_path is not None: audio_path = os.path.join(root_path, audio_path) @@ -273,7 +458,7 @@ def collate_and_tokenize_custom( audio_list.append(wav) audio_lengths.append(len(wav)) - # 2. Handle Target Audio / Duration + # Handle Target Audio / Duration tdur_audio_path = s["audio_filepath"] if root_path is not None: tdur_audio_path = os.path.join(root_path, tdur_audio_path) @@ -310,7 +495,7 @@ def collate_and_tokenize_custom( for i, wav in enumerate(audio_list): padded_audio[i, : len(wav)] = wav - padded_audio = padded_audio.to(model.device) + # Keep on CPU audio_lengths = torch.tensor(audio_lengths, dtype=torch.long) # Expand text length to match expected output speech duration @@ -321,9 +506,7 @@ def collate_and_tokenize_custom( # (prevents truncation if calc was slightly off) target_len = max(target_len, L) - padded_input_ids = torch.full( - (B, target_len), fill_value=model.text_pad_id, dtype=input_ids.dtype, device=input_ids.device - ) + padded_input_ids = torch.full((B, target_len), fill_value=model.text_pad_id, dtype=input_ids.dtype) # Copy the actual tokens (which might already contain list-based padding) padded_input_ids[:, :L] = input_ids @@ -349,48 +532,87 @@ def inference(cfg): if distributed and not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") + # Dynamically determine the correct GPU for this process + if torch.cuda.is_available(): + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + target_device = torch.device(f"cuda:{local_rank}") + else: + target_device = torch.device("cpu") + torch.set_float32_matmul_precision("medium") torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True + target_dtype = getattr(torch, cfg.get("inference_dtype", "float32")) + if target_dtype != torch.float32: + torch.set_default_dtype(target_dtype) + if cfg.get("checkpoint_path", None): model = DuplexEARTTS.load_from_checkpoint( - cfg.checkpoint_path, - cfg=OmegaConf.to_container(cfg, resolve=True), + cfg.checkpoint_path, cfg=OmegaConf.to_container(cfg, resolve=True), map_location=target_device ).eval() else: raise ValueError("For evaluation, you must provide `cfg.checkpoint_path`.") - target_dtype = getattr(torch, cfg.get("inference_dtype", "float32")) - # Move and cast if target_dtype != torch.float32: - model.to(dtype=target_dtype) - - intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() - - for batch_id, batch in enumerate(read_jsonl_batches(cfg.datasets_json_path, cfg.batch_size, max_batches=None)): - inputs = collate_and_tokenize_custom( - batch, - model, - extra_duration_thrshould=1.5, - sample_rate=model.target_sample_rate, - root_path=cfg.audio_dir, - add_beginning_pad_tokens=cfg.get("add_beginning_pad_tokens", True), - add_eos=cfg.get("add_eos", True), - pad_factor_text_speech=cfg.get("pad_factor_text_speech", 10), - force_interruption=cfg.get("force_interruption", False), - ) + if cfg.get("keep_codec_original_dtype", True): + model.tts_model.to(dtype=target_dtype) + model.ensures_codec_target_dtype() # ensures that codec is in the right precision + else: + model.audio_codec_run_dtype = target_dtype + model.to(dtype=target_dtype) + + if cfg.get("debug_dtype", False): + handles, stats, examples = attach_dtype_counter(model) + + with fp32_precision(): + intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() + secs_metric = SECS("titanet_large").reset() + + # Initialize the Dataset + eval_dataset = EvalJSONLDataset(cfg.datasets_json_path) + + # Use partial to bind the model and config parameters to the collate function + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.target_sample_rate, + root_path=cfg.audio_dir, + add_beginning_pad_tokens=cfg.get("add_beginning_pad_tokens", True), + add_eos=cfg.get("add_eos", True), + pad_factor_text_speech=cfg.get("pad_factor_text_speech", 10), + force_interruption=cfg.get("force_interruption", False), + ) + + # Initialize the DataLoader + dataloader = DataLoader( + dataset=eval_dataset, + batch_size=cfg.batch_size, + collate_fn=collate_fn, + num_workers=cfg.get("num_workers", 4), + pin_memory=True, + shuffle=False, + drop_last=False, + ) + + if cfg.get("user_custom_speaker_reference", None): + wav, sr = librosa.load(cfg.model.inference_speaker_reference, sr=model.target_sample_rate, mono=True) + speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) + + # Iterate over the DataLoader + for batch_id, inputs in enumerate(dataloader): + + # Move required tensors to the GPU immediately + inputs["input_ids"] = inputs["input_ids"].to(model.device) + inputs["context_audio"] = inputs["context_audio"].to(model.device) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(model.device) + if cfg.get("user_custom_speaker_reference", None): - wav, sr = librosa.load(cfg.model.inference_speaker_reference, sr=model.target_sample_rate, mono=True) - wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0) - inputs["context_audio"] = wav.expand(inputs["input_ids"].size(0), *wav.shape[1:]) - inputs["context_audio_lengths"][:] = wav.size(-1) - inputs["context_audio"] = inputs["context_audio"].to(model.device) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(model.device).long() - - use_autocast = target_dtype != torch.float32 - autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=target_dtype) if use_autocast else nullcontext() - with torch.no_grad(), autocast_ctx: + inputs["context_audio"] = speaker_wav.expand(inputs["input_ids"].size(0), *speaker_wav.shape[1:]) + inputs["context_audio_lengths"][:] = speaker_wav.size(-1) + + with torch.no_grad(): model.set_init_inputs( speaker_audio=inputs["context_audio"], speaker_audio_lens=inputs["context_audio_lengths"], @@ -400,50 +622,70 @@ def inference(cfg): audio, audio_len = model.offline_inference( next_subword_ids=inputs["input_ids"], - formatter="custom", + task="custom", init_inputs=init_inputs, ) - audio = audio.float() - # reset audio len to the actual size removing extra long audio padding - audio_len = (torch.tensor(inputs["target_num_frames"]) * model.target_samples_per_frame).int() - - # resample audio to the asr sampling rate - metric_audio_pred = resample(audio, model.target_sample_rate, 16000) - metric_audio_pred_lens = (audio_len / model.target_sample_rate * 16000).to(torch.long) - - intelligibility.update( - name="dataset", - refs=inputs["raw_text"], - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - asr_hyps=None, - ) - - # save audio to cfg.out_dir - os.makedirs(cfg.out_dir, exist_ok=True) - - audio = audio.detach().cpu().float() - audio_len = audio_len.cpu() - - for i in range(audio.size(0)): - wav = audio[i, : audio_len[i]].numpy() - # Use original target audio filename - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - out_path = os.path.join(cfg.out_dir, base_name) - - sf.write( - out_path, - wav, - samplerate=model.target_sample_rate, + if cfg.get("debug_dtype", False) and batch_id == 0: + report_dtype_stats(handles, stats, examples) + + with fp32_precision(): + audio = audio.float() + + # reset audio len to the actual size removing extra long audio padding + audio_len = ( + torch.tensor(inputs["target_num_frames"], device=audio.device) * model.target_samples_per_frame + ).int() + + # resample audio to the asr sampling rate + metric_audio_pred = resample(audio, model.target_sample_rate, 16000) + metric_audio_pred_lens = (audio_len / model.target_sample_rate * 16000).to(torch.long) + + intelligibility.update( + name="dataset", + refs=inputs["raw_text"], + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=None, ) - print(f"Saved: {out_path}") + secs_metric.update( + name="dataset", + target_audio=resample(inputs["context_audio"], model.target_sample_rate, 16000), + target_audio_lens=(inputs["context_audio_lengths"] / model.target_sample_rate * 16000).to(torch.long), + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + # save audio to cfg.out_dir + os.makedirs(cfg.out_dir, exist_ok=True) + audio = audio.detach().cpu().float() + audio_len = audio_len.cpu() + + for i in range(audio.size(0)): + wav = audio[i, : audio_len[i]].numpy() + # Use original target audio filename + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + out_path = os.path.join(cfg.out_dir, base_name) + + sf.write( + out_path, + wav, + samplerate=model.target_sample_rate, + ) + + logging.info(f"Saved: {out_path}") + + with fp32_precision(): + logging.info("\n--- Evaluation Metrics ---") + cer_wer = intelligibility.compute() + for k, m in cer_wer.items(): + logging.info(f"Intelligibility - {k}: {m}") - cer_wer = intelligibility.compute() - for k, m in cer_wer.items(): - print(k, m) + secs_scores = secs_metric.compute() + for k, m in secs_scores.items(): + logging.info(f"SECS - {k}: {m}") if __name__ == "__main__": diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index d4e7ec482921..acec0d9fbfee 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -18,6 +18,7 @@ import random import re import warnings +from copy import deepcopy from functools import partial from itertools import repeat from pathlib import Path @@ -925,7 +926,7 @@ def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recordi def convert_cut_fn(cut: Cut) -> Cut: """Convert a single cut into the continuation format.""" - orig_agent_sup = fastcopy(cut.supervisions[0]) + orig_agent_sup = deepcopy(cut.supervisions[0]) target_audio_orig_dur = cut.target_audio.duration # Resample audios @@ -1014,6 +1015,104 @@ def filter_target_speaker_fn(cut: Cut) -> bool: return cuts, is_tarred +@data_type_parser(["s2s_duplex_reverse_role"]) +def read_s2s_duplex_reverse_role(config) -> Tuple[CutSet, bool]: + """ + Reverse the speaker roles and swap the source/target audio streams in a Duplex S2S CutSet. + + This parser takes an existing conversational dataset and inverts the perspective + by swapping the "user" and "agent" supervision labels. It also swaps the primary + `recording` (usually source audio) with the `target_audio` to fully simulate the + conversation from the opposite participant's point of view. + + Args: + config: Dictionary containing parser options: + - agent_roles (List[str], optional): List of role strings to be identified as the agent. + Defaults to ["agent", "Agent", "Assistant", "assistant"]. + - user_roles (List[str], optional): List of role strings to be identified as the user. + Defaults to ["user", "User"]. + - target_agent_name (str, optional): The canonical name to assign to former user roles. + Defaults to "agent". + - target_user_name (str, optional): The canonical name to assign to former agent roles. + Defaults to "user". + + Returns: + Tuple[CutSet, bool]: Converted cuts with swapped roles and audio streams, + along with a flag indicating if the data was tarred. + """ + cuts, is_tarred = read_cutset_from_config(config) + + # Roles coming from config + agent_roles = config.get("agent_roles", ["agent", "Agent", "Assistant", "assistant"]) + user_roles = config.get("user_roles", ["user", "User"]) + + # Normalize for robust matching + agent_roles_set = {r.lower() for r in agent_roles} + user_roles_set = {r.lower() for r in user_roles} + + # Canonical names you want after swapping + target_agent_name = config.get("target_agent_name", "agent") + target_user_name = config.get("target_user_name", "user") + + def swap_speaker(role: str) -> str: + """Swap a given role based on the configured user/agent sets.""" + if role is None: + return role + + role_l = role.lower() + + # user -> agent + if role_l in user_roles_set: + return target_agent_name + + # agent -> user + if role_l in agent_roles_set: + return target_user_name + + # untouched roles (e.g., narrator, system, etc.) + return role + + def convert_cut_fn(cut: Cut) -> Cut: + """Convert a single cut by swapping supervisions and audio streams.""" + new_cut = deepcopy(cut) + + # swap supervisions + if getattr(new_cut, "supervisions", None): + new_sups = [] + for s in new_cut.supervisions: + s2 = deepcopy(s) + s2.speaker = swap_speaker(getattr(s2, "speaker", None)) + new_sups.append(s2) + new_cut.supervisions = new_sups + + # swap audio streams + old_recording = new_cut.recording + old_target_audio = new_cut.target_audio + old_rec_id = old_recording.id + old_tar_id = old_target_audio.id + + new_cut.recording = old_target_audio + new_cut.target_audio = old_recording + + # keep duration consistent + if hasattr(new_cut, "duration"): + new_cut.duration = new_cut.recording.duration + + # Debug assertions + assert new_cut.target_audio.id == old_rec_id, f"{new_cut.id}: recording swap failed" + assert new_cut.recording.id == old_tar_id, f"{new_cut.id}: target_audio swap failed" + + # Optional stronger assertions (object identity) + assert new_cut.recording is old_target_audio, f"{new_cut.id}: recording object not swapped" + assert new_cut.target_audio is old_recording, f"{new_cut.id}: target_audio object not swapped" + + new_cut.task = "s2s_duplex_reverse_role" + return new_cut + + cuts = cuts.map(convert_cut_fn) + return cuts, is_tarred + + @data_type_parser(["lhotse_as_conversation"]) def read_lhotse_as_conversation(config) -> tuple[CutSet, bool]: """ diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index e1e982f71a63..bab8bbf8810f 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -131,7 +131,9 @@ def get_codec_silence_frame_last_one(self): audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): - sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len) + sil_codes, sil_codes_lens = self.audio_codec.encode( + audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len + ) return sil_codes[0, -1] def get_codec_silence_frame(self): @@ -142,7 +144,9 @@ def get_codec_silence_frame(self): audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): - sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len) # [1, T, C] + sil_codes, _ = self.audio_codec.encode( + audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len + ) # [1, T, C] sil_codes = sil_codes[0] # [T, C] # Convert each frame (C tokens) into a tuple @@ -328,7 +332,9 @@ def prepare_inputs(self, batch: dict): target_audio, target_audio_lens, self.target_samples_per_frame, 1 ) with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): - target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens) + target_codes, target_codes_lens = self.audio_codec.encode( + target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_lens + ) with fp32_precision(): target_len = target_codes.shape[1] @@ -546,10 +552,27 @@ def training_step(self, batch: dict, batch_idx: int): self.log_dict(ans, on_step=True) return ans + def ensures_codec_target_dtype(self) -> None: + """ + Ensures the audio codec is instantiated with the target dtype. + + This method checks whether `self.audio_codec` exists and whether its + parameters match `self.audio_codec_run_dtype`. If the codec is missing + or is running with the wrong dtype (e.g., due to PTL auto-downcasting), + the codec is reloaded by calling `setup_audio_codec()`. + + Intended to be called at runtime boundaries such as: + - `on_train_epoch_start` + - `on_validation_epoch_start` + """ + if hasattr(self, "audio_codec") and next(self.audio_codec.parameters()).dtype == self.audio_codec_run_dtype: + self.audio_codec.eval() + return # already correct precision → no-op + + setup_audio_codec(self) + def on_train_epoch_start(self) -> None: - ensures_codec_target_dtype( - self - ) # potentially reloads the audio codec to make sure it's in target codec precision + self.ensures_codec_target_dtype() # potentially reloads the audio codec to make sure it's in target codec precision def on_train_epoch_end(self) -> None: # log model stats to debug gradient weights issues @@ -600,9 +623,7 @@ def on_validation_epoch_start(self) -> None: if torch.distributed.is_initialized(): self.trainer.strategy.model.require_backward_grad_sync = False - ensures_codec_target_dtype( - self - ) # potentially reloads the audio codec to make sure it's in target codec precision + self.ensures_codec_target_dtype() # potentially reloads the audio codec to make sure it's in target codec precision self.results_logger = ResultsLogger(self.validation_save_path).reset() self.asr_bleu = ASRBLEU(self.cfg.scoring_asr).reset() @@ -1013,7 +1034,9 @@ def set_init_inputs(self, speaker_audio=None, speaker_audio_lens=None, system_pr [target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device ) with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): - code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len) + code, _ = self.audio_codec.encode( + target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_len + ) # get context hidden if self.cfg.tts_config.context_hidden_size is not None: @@ -1205,7 +1228,7 @@ def decode_one_audio_step(self, gen_audio_codes_history, number_prev_tokens=None - audio_pred_cur_step: Latest decoded waveform chunk, shape (B, wav_to_token_ratio). - audio_len: Lengths (number of samples), shape (B,). """ - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): if number_prev_tokens: gen_audio_codes_history = gen_audio_codes_history[:, -number_prev_tokens:] @@ -1365,7 +1388,9 @@ def offline_inference( logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") if not incremental_audio_decoding: - gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) + gen_audio_codes_lens = torch.tensor( + [gen_audio_codes.shape[1]] * gen_audio_codes.shape[0], dtype=torch.long + ).to(self.device) # decode audio. Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety gen_audio_codes = replace_control_speech_codes( gen_audio_codes, self._control_codes, self.codec_silence_tokens @@ -1634,30 +1659,6 @@ def replace_control_speech_codes( return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) -def ensures_codec_target_dtype(model): - """ - Ensures the audio codec is instantiated with the target dtype. - - This function checks whether `model.audio_codec` exists and whether its - parameters match `model.audio_codec_run_dtype`. If the codec is missing - or is running with the wrong dtype (e.g., due to PTL auto-downcasting), - the codec is reloaded by calling `setup_audio_codec()`. - - Intended to be called at runtime boundaries such as: - - `on_train_epoch_start` - - `on_validation_epoch_start` - - Args: - model: Model instance of DuplexEARTTS - - """ - if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: - model.audio_codec.eval() - return # already correct precision → no-op - - setup_audio_codec(model) - - def setup_audio_codec(model): """ Instantiates the RVQ audio codec and injects codec embeddings into the TTS model. @@ -1683,6 +1684,7 @@ def setup_audio_codec(model): p.requires_grad = False model.audio_codec.eval() + model.audio_codec.to(model.device) # force codec to run in the same device as the main model assert callable(model.tts_model.set_rvq_embs) diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 879416915cc7..9ea1b238620a 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -93,123 +93,26 @@ def forward(self, x: Tensor) -> Tensor: # ============================================================================== -# Triton-accelerated and Fallback Functions +# Core Mathematical and Masking Functions # ============================================================================== -TRITON_IMPORTED = False -try: - import triton - import triton.language as tl - - TRITON_IMPORTED = True -except ImportError: - TRITON_IMPORTED = False - -USE_TRITON = TRITON_IMPORTED and torch.cuda.is_available() - -if USE_TRITON: - logging.info("Triton available & CUDA detected. Using Triton kernel for batch_matmul.") - - @triton.jit - def batch_matmul_kernel( - x_ptr, - w_ptr, - y_ptr, - result_ptr, - b, - d_in, - d_out, - n, - BLOCK_SIZE_DIN: tl.constexpr, - BLOCK_SIZE_DOUT: tl.constexpr, - ): - batch_id = tl.program_id(axis=0) - dout_block_id = tl.program_id(axis=1) - - if batch_id >= b: - return - - idx = tl.load(y_ptr + batch_id) - - x_offset = x_ptr + batch_id * d_in - w_offset = w_ptr + idx * d_out * d_in - - dout_offsets = dout_block_id * BLOCK_SIZE_DOUT + tl.arange(0, BLOCK_SIZE_DOUT) - dout_mask = dout_offsets < d_out - - result_block = tl.zeros([BLOCK_SIZE_DOUT], dtype=tl.float32) - - for din_start in range(0, d_in, BLOCK_SIZE_DIN): - din_offsets = din_start + tl.arange(0, BLOCK_SIZE_DIN) - din_mask = din_offsets < d_in - - x_i = tl.load(x_offset + din_offsets, mask=din_mask, other=0.0) - w_i_block = tl.load( - w_offset + dout_offsets[:, None] * d_in + din_offsets[None, :], - mask=(dout_mask[:, None] & din_mask[None, :]), - other=0.0, - ) - - result_block += tl.sum(w_i_block * x_i[None, :], axis=1) - - result_offset = result_ptr + batch_id * d_out + dout_offsets - tl.store(result_offset, result_block, mask=dout_mask) - - def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int = 64): - assert x.is_contiguous() and w.is_contiguous() and y.is_contiguous() - - b, d_in = x.shape - n, d_out, _ = w.shape - result = torch.empty(b, d_out, device=x.device, dtype=torch.float32) - - batch_matmul_kernel[lambda meta: (b, triton.cdiv(d_out, meta["BLOCK_SIZE_DOUT"]))]( - x.float(), - w.float(), - y, - result, - b, - d_in, - d_out, - n, - BLOCK_SIZE_DIN=BLOCK_SIZE_DIN, - BLOCK_SIZE_DOUT=BLOCK_SIZE_DOUT, - ) - - return result.to(dtype=x.dtype) - - batch_matmul = batch_matmul_triton - -else: - logging.info("Using PyTorch fallback (Triton unavailable or no CUDA).") - - # Fallback to PyTorch implementation if Triton is not available - def batch_matmul_pytorch(x: Tensor, w: Tensor, y: Tensor, *args, **kwargs) -> Tensor: - """ - Performs a batched matrix multiplication using PyTorch's native functions. - - This function serves as a fallback when Triton is not available. It achieves - the same result by gathering the appropriate weight matrices and using `torch.bmm`. - - Args: - x (Tensor): The input tensor of shape `[batch_size, d_in]`. - w (Tensor): The weight tensor of shape `[num_weights, d_out, d_in]`. - y (Tensor): The index tensor of shape `[batch_size]`. - - Returns: - Tensor: The result of the multiplication, shape `[batch_size, d_out]`. - """ - # w[y] gathers the weight matrices for each item in the batch. - # x.unsqueeze(2) reshapes x to [batch_size, d_in, 1] for bmm. - # The result is squeezed to remove the trailing dimension of size 1. - return torch.bmm(w[y], x.unsqueeze(2)).squeeze(2) - - batch_matmul = batch_matmul_pytorch +def batch_matmul(x: Tensor, w: Tensor, y: Tensor, *args, **kwargs) -> Tensor: + """ + Performs a batched matrix multiplication using PyTorch's native functions. + Args: + x (Tensor): The input tensor of shape `[batch_size, d_in]`. + w (Tensor): The weight tensor of shape `[num_weights, d_out, d_in]`. + y (Tensor): The index tensor of shape `[batch_size]`. -# ============================================================================== -# Core Mathematical and Masking Functions -# ============================================================================== + Returns: + Tensor: The result of the multiplication, shape `[batch_size, d_out]`. + """ + # w[y] gathers the weight matrices for each item in the batch. + # x.unsqueeze(2) reshapes x to [batch_size, d_in, 1] for bmm. + # The result is squeezed to remove the trailing dimension of size 1. + return torch.bmm(w[y], x.unsqueeze(2)).squeeze(2) def gumbel_like(tensor: Tensor, eps: float = 1e-8) -> Tensor: @@ -581,6 +484,8 @@ def infer(self, x: Tensor, guidance_scale: float = 0.0, top_p_or_k: float | int ).view_as(logits) ) + logits = logits.to(x.dtype) + # Sample a mixture component using the Gumbel-Max trick with fp32_precision(): mixture_indices = (F.log_softmax(logits, dim=-1) + gumbel_like(logits)).argmax(-1) @@ -606,10 +511,10 @@ def infer(self, x: Tensor, guidance_scale: float = 0.0, top_p_or_k: float | int mu_res = self.proj_else(x) else: - mu_res = torch.zeros((b, t, d), device=x.device) + mu_res = torch.zeros((b, t, d), device=x.device, dtype=x.dtype) logs = self.proj_logs(x).clamp_min(self.min_log_std) - return mu * torch.exp(logs) + mu_res, logs + return mu * torch.exp(logs.float()).to(logs.dtype) + mu_res, logs def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """ @@ -950,7 +855,6 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te # 1. Convert subword IDs to character IDs char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) - # char_mask = sequence_mask(char_lengths).float() char_mask = sequence_mask(char_lengths) # 2. Get character embeddings and pass them through the backbone @@ -1015,7 +919,7 @@ def forward(self, audio_emb, text_emb): h = gate.to(dtype) * audio_h + (1 - gate).to(dtype) * text_h h = res.to(dtype) * h - h = self.final_norm(h.float()).to(dtype) + h = self.final_norm(h).to(dtype) return h @@ -1149,7 +1053,7 @@ def depthsum_embedding(self, code: Tensor) -> Tensor: _, v, h = self.rvq_embs.size() device = code.device - ret = torch.zeros((b, t, h), device=device) + ret = torch.zeros((b, t, h), device=device, dtype=self.rvq_embs.dtype) embs = F.pad(self.rvq_embs, [0, 0, 0, 1]) for i in range(d): emb = embs[i] @@ -1203,7 +1107,7 @@ def _prepare_conditioning( asr_speech_tokens_emb: Tensor | None, ) -> Tensor: """Computes the final conditioning tensor by combining all sources.""" - cond = torch.zeros((1, 1, self.hidden_size), device=uncond_dec_flag.device) + cond = torch.zeros((1, 1, self.hidden_size), device=uncond_dec_flag.device, dtype=self.rvq_embs.dtype) if self.embed_context is not None and context_hidden_state is not None: cond = cond + self.embed_context(context_hidden_state) diff --git a/tests/collections/common/test_lhotse_dataloading_duplex.py b/tests/collections/common/test_lhotse_dataloading_duplex.py new file mode 100644 index 000000000000..fedeb83e9fdd --- /dev/null +++ b/tests/collections/common/test_lhotse_dataloading_duplex.py @@ -0,0 +1,294 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from itertools import islice +from pathlib import Path + +import lhotse +import numpy as np +import pytest +import soundfile as sf +import torch +from lhotse import CutSet, SupervisionSegment +from omegaconf import OmegaConf + +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config + + +class Identity(torch.utils.data.Dataset): + """Dummy dataset class to return the raw CutSet for testing dataloader output.""" + + def __getitem__(self, cuts: lhotse.CutSet) -> lhotse.CutSet: + return cuts + + +def create_wav_file(path: Path, duration: float, sample_rate: int = 16000): + """Helper to create a valid, silent WAV file on disk to bypass memory bytes serialization.""" + samples = np.zeros((1, int(duration * sample_rate)), dtype=np.float32) + sf.write(str(path), samples.T, sample_rate, format='WAV') + + +@pytest.fixture(scope="session") +def cutset_shar_s2s_overlap_path(tmp_path_factory) -> Path: + """5 utterances representing conversational overlap data as Lhotse Shar.""" + tmp_dir = tmp_path_factory.mktemp("overlap_audio") + cuts = [] + + for i in range(5): + main_path = tmp_dir / f"ov_main_{i}.wav" + create_wav_file(main_path, duration=5.0) + + c = lhotse.MonoCut( + id=f"ov_cut_{i}", + start=0.0, + duration=5.0, + channel=0, + recording=lhotse.Recording.from_file(main_path, recording_id=f"ov_main_{i}"), + ) + # Add custom overlapping segments + c.supervisions = [] + c.custom = { + "agent_segments": [{"start": 0.5, "end": 2.0, "text": "agent speaking"}], + "user_segments": [{"start": 1.0, "end": 3.0, "text": "user speaking"}], + } + cuts.append(c) + + cuts = CutSet.from_cuts(cuts) + p = tmp_path_factory.mktemp("overlap_shar") + cuts.to_shar(p, fields={"recording": "wav"}, shard_size=5) + return p + + +@pytest.fixture(scope="session") +def cutset_shar_magpietts_path(tmp_path_factory) -> Path: + """5 utterances representing MagpieTTS data with target and context audio.""" + tmp_dir = tmp_path_factory.mktemp("magpie_audio") + cuts = [] + + for i in range(5): + main_path = tmp_dir / f"mag_main_{i}.wav" + tgt_path = tmp_dir / f"mag_target_{i}.wav" + ctx_path = tmp_dir / f"mag_context_{i}.wav" + + create_wav_file(main_path, duration=2.0) + create_wav_file(tgt_path, duration=2.0) + create_wav_file(ctx_path, duration=1.0) + + c = lhotse.MonoCut( + id=f"mag_cut_{i}", + start=0.0, + duration=2.0, + channel=0, + recording=lhotse.Recording.from_file(main_path, recording_id=f"mag_main_{i}"), + ) + + c.custom = { + "target_audio": lhotse.Recording.from_file(tgt_path, recording_id=f"mag_target_{i}"), + "context_audio": lhotse.Recording.from_file(ctx_path, recording_id=f"mag_context_{i}"), + } + + c.supervisions = [ + SupervisionSegment( + id=f"sup_{i}", + recording_id=c.recording.id, + start=0.0, + duration=2.0, + text="hello", + speaker="agent", + custom={"cer": 0.01, "context_speaker_similarity": 0.9, "validation_status": "pass"}, + ) + ] + cuts.append(c) + + cuts = CutSet.from_cuts(cuts) + p = tmp_path_factory.mktemp("magpie_shar") + cuts.to_shar(p, fields={"recording": "wav"}, shard_size=5) + return p + + +@pytest.fixture(scope="session") +def regular_duplex_s2s_format(tmp_path_factory) -> Path: + """5 utterances representing duplex conversational data for role reversal.""" + tmp_dir = tmp_path_factory.mktemp("reverse_role_audio") + cuts = [] + + for i in range(5): + main_path = tmp_dir / f"rr_main_{i}.wav" + tgt_path = tmp_dir / f"rr_target_{i}.wav" + + create_wav_file(main_path, duration=3.0) + create_wav_file(tgt_path, duration=3.0) + + c = lhotse.MonoCut( + id=f"rr_cut_{i}", + start=0.0, + duration=3.0, + channel=0, + recording=lhotse.Recording.from_file(main_path, recording_id=f"rr_main_{i}"), + ) + + # Store an alternative target recording in the custom field + c.custom = {"target_audio": lhotse.Recording.from_file(tgt_path, recording_id=f"rr_target_{i}")} + + c.supervisions = [ + SupervisionSegment( + id=f"sup_{i}_1", recording_id=c.recording.id, start=0.0, duration=1.0, speaker="user", text="hello" + ), + SupervisionSegment( + id=f"sup_{i}_2", recording_id=c.recording.id, start=1.5, duration=1.0, speaker="agent", text="hi" + ), + ] + cuts.append(c) + + cuts = CutSet.from_cuts(cuts) + p = tmp_path_factory.mktemp("reverse_role_shar") + cuts.to_shar(p, fields={"recording": "wav"}, shard_size=5) + return p + + +def test_data_input_cfg_s2s_overlap(cutset_shar_s2s_overlap_path): + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "s2s_duplex_overlap_as_s2s_duplex", + "shar_path": str(cutset_shar_s2s_overlap_path), + "weight": 1.0, + "move_agent_text_back_by": 0.1, + "filter_samples_starting_with_agent": False, + "tags": { + "dataset_name": "OverlapData", + }, + }, + ], + "sample_rate": 16000, + "shuffle": True, + "num_workers": 0, + "batch_size": 2, + "seed": 0, + "shard_seed": 0, + } + ) + + dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity()) + + # Verify dataloader and transformations + batches = [batch for batch in islice(dl, 1)] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, lhotse.CutSet) + assert all(c.custom["dataset_name"] == "OverlapData" for c in b) + + for cut in b: + assert cut.task == "s2s_duplex_overlap_as_s2s_duplex" + assert len(cut.supervisions) == 2 + + # Verify chronological sorting and offsets applied correctly + sups = sorted(cut.supervisions, key=lambda s: s.start) + assert sups[0].speaker == "agent" # agent starts at 0.5 - 0.1 = 0.4 + assert sups[0].start == pytest.approx(0.4) + assert sups[1].speaker == "user" # user starts at 1.0 + + +def test_data_input_cfg_magpietts(cutset_shar_magpietts_path): + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "lhotse_magpietts_data_as_continuation", + "shar_path": str(cutset_shar_magpietts_path), + "weight": 1.0, + "sample_rate": 22050, + "add_extra_end_silence": False, + "tags": { + "dataset_name": "MagpieData", + }, + }, + ], + "sample_rate": 22050, + "shuffle": True, + "num_workers": 0, + "batch_size": 2, + "seed": 0, + "shard_seed": 0, + } + ) + + dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity()) + + batches = [batch for batch in islice(dl, 1)] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, lhotse.CutSet) + assert all(c.custom["dataset_name"] == "MagpieData" for c in b) + + for cut in b: + assert cut.task == "lhotse_magpietts_data_as_continuation" + assert hasattr(cut, "target_audio") + assert hasattr(cut, "context_audio") + assert hasattr(cut, "recording") + assert len(cut.supervisions) == 2 + + # Verify synthetic user/agent split behavior + assert cut.supervisions[0].speaker == "user" + assert cut.supervisions[0].duration == pytest.approx(0.08) + assert cut.supervisions[1].speaker == "agent" + + +def test_data_input_cfg_reverse_role(regular_duplex_s2s_format): + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "s2s_duplex_reverse_role", + "shar_path": str(regular_duplex_s2s_format), + "weight": 1.0, + "target_agent_name": "swapped_agent", + "target_user_name": "swapped_user", + "tags": { + "dataset_name": "ReverseRoleData", + }, + }, + ], + "sample_rate": 16000, + "shuffle": True, + "num_workers": 0, + "batch_size": 2, + "seed": 0, + "shard_seed": 0, + } + ) + + dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity()) + + batches = [batch for batch in islice(dl, 1)] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, lhotse.CutSet) + assert all(c.custom["dataset_name"] == "ReverseRoleData" for c in b) + + for cut in b: + assert cut.task == "s2s_duplex_reverse_role" + + # Verify the roles have been inverted according to configuration overrides + sups = sorted(cut.supervisions, key=lambda s: s.start) + assert sups[0].speaker == "swapped_agent" # Originally "user" + assert sups[1].speaker == "swapped_user" # Originally "agent" + + # Ensure the recording streams were swapped + assert cut.recording.id.startswith("rr_target") + assert cut.target_audio.id.startswith("rr_main")