Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
97ed8c8
first commit for a canary streaming prototype
andrusenkoau Mar 12, 2025
c356476
add waitk decoding policy
andrusenkoau Mar 13, 2025
292425f
add laal computation for waitk and alignatt
andrusenkoau Mar 13, 2025
593a2b0
replace args with config
andrusenkoau Mar 13, 2025
823d408
add AST task and BLEU calculation
andrusenkoau Mar 14, 2025
bcb7502
fix laal computation for empty predictions
andrusenkoau Mar 14, 2025
537634b
fix is_last_speech_chunk bug taking into account the downsampling v…
andrusenkoau Mar 18, 2025
7810b97
fix bleu calculation
andrusenkoau Mar 18, 2025
ca4c820
first commit for batched waitk
andrusenkoau Apr 1, 2025
36c374c
fix batched waitk decoding
andrusenkoau Apr 2, 2025
543ea0f
fix decoding speed for batched waitk
andrusenkoau Apr 2, 2025
0cfd8dc
add first working version of batched alignatt
andrusenkoau Apr 4, 2025
606594e
add token alignment for alignatt laal
andrusenkoau Apr 7, 2025
144aa63
add hallucination detector for alignatt
andrusenkoau Apr 8, 2025
19decad
fixes
andrusenkoau Apr 8, 2025
21eb3f8
add streaming buffer for offline model
andrusenkoau Apr 11, 2025
ea23a62
add first version of waitk for offline models
andrusenkoau Apr 11, 2025
23553ed
add fixes
andrusenkoau Apr 11, 2025
529ff97
fix current code
andrusenkoau Apr 11, 2025
3eb9948
fix offline model decoding with alignatt
andrusenkoau Apr 15, 2025
d0fba27
add waitk for alignatt
andrusenkoau Apr 15, 2025
f93908b
add chunked streaming buffer
andrusenkoau Apr 18, 2025
4072a67
fixes
andrusenkoau Apr 21, 2025
45411bc
bug fix
andrusenkoau Apr 23, 2025
3b2144e
Merge branch 'main' of github.com:andrusenkoau/NeMo into canary_strea…
andrusenkoau May 23, 2025
8ff964a
Apply isort and black reformatting
andrusenkoau May 23, 2025
e43ac8d
Merge branch 'main' of github.com:andrusenkoau/NeMo into canary_strea…
andrusenkoau Jul 21, 2025
364d645
first attempt
andrusenkoau Jul 25, 2025
704777f
integration of alignatt for aed models
andrusenkoau Jul 29, 2025
fe10a1a
add laal support
andrusenkoau Jul 29, 2025
922c306
minor fixes
andrusenkoau Jul 29, 2025
6afdc6b
before code reviewing
andrusenkoau Jul 31, 2025
1a60744
add chunk shift for accurate laal computation
andrusenkoau Aug 8, 2025
1402894
zero sink token for long audio recognition
andrusenkoau Aug 8, 2025
9db1c3c
minor fixes
andrusenkoau Aug 8, 2025
4993172
before exps with old buffer
andrusenkoau Aug 8, 2025
5559d9b
fix buffer
andrusenkoau Aug 11, 2025
904c726
add text normalization before wer calculation
andrusenkoau Aug 12, 2025
a1737b6
fix laal computation for new tokenizer
andrusenkoau Aug 12, 2025
cef6441
fix tokenizer for laal
andrusenkoau Aug 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions examples/asr/speech_to_text_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import torch
import transcribe_speech
from omegaconf import MISSING, OmegaConf, open_dict
from sacrebleu import corpus_bleu

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.parts.utils.transcribe_utils import (
Expand Down Expand Up @@ -140,19 +141,19 @@ def main(cfg: EvaluationConfig):
transcription_cfg = cfg

ground_truth_text = []
answers_text = []
predicted_text = []
invalid_manifest = False
with open(transcription_cfg.output_filename, 'r') as f:
for line in f:
data = json.loads(line)

if "pred_text" not in data:
invalid_manifest = True
break

ground_truth_text.append(data[cfg.gt_text_attr_name])

predicted_text.append(data["pred_text"])
if "answer" in data:
answers_text.append(data["answer"])

pc = PunctuationCapitalization(cfg.text_processing.punctuation_marks)
if cfg.text_processing.separate_punctuation:
Expand Down Expand Up @@ -213,6 +214,9 @@ def main(cfg: EvaluationConfig):
logging.info(f'Got {metric_name} of {metric_value}. Tolerance was {cfg.tolerance}')

logging.info(f"Dataset WER/CER {wer:.2%}/{cer:.2%}")
if answers_text:
bleu = corpus_bleu(predicted_text, [answers_text]).score
logging.info(f"Dataset BLEU {bleu:.2f}")

if cfg.use_punct_er:
dper_obj.print()
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,8 @@ def _transcribe_forward(

log_probs, encoded_len, enc_states, enc_mask = self.forward(input_signal=audio, input_signal_length=audio_lens)

# import pdb; pdb.set_trace()

if decoder_input_ids is None:
# The dataloader provided only audio + audio_lens, so we
# are constructing the prompt dynamically using TranscribeConfig.
Expand Down
33 changes: 23 additions & 10 deletions nemo/collections/asr/modules/transformer/transformer_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def __init__(
)
self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=1e-5)
self.second_sub_layer = MultiHeadAttention(
hidden_size, num_attention_heads, attn_score_dropout, attn_layer_dropout
hidden_size,
num_attention_heads,
attn_score_dropout,
attn_layer_dropout,
return_xatt_scores=True,
)
self.layer_norm_3 = nn.LayerNorm(hidden_size, eps=1e-5)
self.third_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act)
Expand All @@ -79,7 +83,7 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state
residual = decoder_query
decoder_query = self.layer_norm_1(decoder_query)
decoder_keys = self.layer_norm_1(decoder_keys)
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output, _ = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output += residual

if self.is_adapter_available():
Expand All @@ -95,7 +99,9 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state

residual = self_attn_output
self_attn_output = self.layer_norm_2(self_attn_output)
enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask)
enc_dec_attn_output, extra_output = self.second_sub_layer(
self_attn_output, encoder_states, encoder_states, encoder_mask
)
enc_dec_attn_output += residual

residual = enc_dec_attn_output
Expand All @@ -112,14 +118,14 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state
pack_input = self.forward_enabled_adapters(pack_input)
output_states = pack_input['x']

return output_states
return output_states, extra_output

def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
"""
Post-LayerNorm block
Order of operations: Self-Attn -> Residual -> LN -> Cross-Attn -> Residual -> LN -> FFN -> Residual -> LN
"""
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output, _ = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output += decoder_query

if self.is_adapter_available():
Expand All @@ -135,7 +141,9 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat

self_attn_output = self.layer_norm_1(self_attn_output)

enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask)
enc_dec_attn_output, extra_output = self.second_sub_layer(
self_attn_output, encoder_states, encoder_states, encoder_mask
)
enc_dec_attn_output += self_attn_output
enc_dec_attn_output = self.layer_norm_2(enc_dec_attn_output)

Expand All @@ -151,7 +159,7 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat
pack_ip = self.forward_enabled_adapters(pack_ip)
output_states = pack_ip['x']

return self.layer_norm_3(output_states)
return self.layer_norm_3(output_states), xatt_scores

def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
if self.pre_ln:
Expand Down Expand Up @@ -251,9 +259,14 @@ def forward(
else:
cached_mems_list = memory_states.unsqueeze(0)

xatt_scores_list = []

for i, layer in enumerate(self.layers):
decoder_states = layer(decoder_states, decoder_attn_mask, memory_states, encoder_states, encoder_attn_mask)
decoder_states, extra_output = layer(
decoder_states, decoder_attn_mask, memory_states, encoder_states, encoder_attn_mask
)
memory_states = self._get_memory_states(decoder_states, decoder_mems_list, i + 1)
xatt_scores_list.append(extra_output['xatt_scores'])
if return_mems:
if return_mems_as_list:
cached_mems_list.append(memory_states)
Expand All @@ -270,9 +283,9 @@ def forward(
cached_mems_list = torch.cat((cached_mems_list, memory_states.unsqueeze(0)), dim=0)

if return_mems:
return cached_mems_list
return cached_mems_list, xatt_scores_list
else:
return memory_states
return memory_states, xatt_scores_list

def input_example(self, max_batch=1, max_dim=256):
"""
Expand Down
36 changes: 33 additions & 3 deletions nemo/collections/asr/modules/transformer/transformer_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
from nemo.collections.common.parts import NEG_INF, mask_padded_tokens

from nemo.utils import logging

__all__ = [
"GreedySequenceGenerator",
"TopKSequenceGenerator",
Expand Down Expand Up @@ -128,6 +130,7 @@ def _one_step_forward(
decoder_mems_list=None,
pos=0,
return_scores: bool = True,
return_xatt_scores: bool = False,
):
"""
One step of autoregressive output generation.
Expand All @@ -148,21 +151,25 @@ def _one_step_forward(
decoder_input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()

if encoder_hidden_states is not None:
decoder_mems_list = self.decoder.forward(
decoder_mems_list, xatt_scores_list = self.decoder.forward(
decoder_hidden_states,
decoder_input_mask,
encoder_hidden_states,
encoder_input_mask,
decoder_mems_list,
return_mems=True,
)
# import pdb; pdb.set_trace()
else:
decoder_mems_list = self.decoder.forward(
decoder_mems_list, _ = self.decoder.forward(
decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True
)
with self.classifier.with_log_softmax_enabled(return_scores) as clf:
logits = clf.forward(hidden_states=decoder_mems_list[-1][:, -1:])
return logits, decoder_mems_list
if return_xatt_scores:
return logits, decoder_mems_list, xatt_scores_list
else:
return logits, decoder_mems_list

def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None):
"""
Expand Down Expand Up @@ -202,6 +209,7 @@ def _forward(
is_sampling = self.temperature is not None and self.n_samples > 1

tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
tgt_len = tgt.size(-1)
if is_sampling:
tgt = torch.repeat_interleave(tgt, self.n_samples, dim=0)
encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, self.n_samples, dim=0)
Expand All @@ -228,8 +236,15 @@ def _forward(
if i == 0:
input_ids = tgt
else:
i += tgt_len - 1
input_ids = tgt[:, -1:]

# logging.warning(f"Step {i}")
# logging.warning(f"tgt: {tgt}")
# logging.warning(f"input_ids: {input_ids}")
# if i == 14:
# raise ValueError("Stop here")

logits, decoder_mems_list = self._one_step_forward(
input_ids,
encoder_hidden_states,
Expand Down Expand Up @@ -413,6 +428,15 @@ def _forward(
scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)
scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1)

# logging.warning(f"Step {0}")
# logging.warning(f"decoder_input_ids: {decoder_input_ids}")
# logging.warning(f"tgt: {tgt}")
# logging.warning(f"prefixes[:, -1:]: {prefixes[:, -1:]}")
# logging.warning("**********"*100)
# logging.warning(f"encoder_hidden_states.shape {encoder_hidden_states.shape}")
# logging.warning(f"encoder_hidden_states[0,35] {encoder_hidden_states[0,:20] }")
# raise ValueError("Stop here")

# repeat init target prefixes and cached memory states beam_size times
prefixes = torch.cat((tgt.repeat(1, self.beam_size).view(-1, tgt.shape[1]), prefixes), dim=1)
for j in range(len(decoder_mems_list)):
Expand All @@ -439,6 +463,10 @@ def _forward(
tgt_len = tgt.size(-1)
for i in range(tgt_len, max_generation_length + tgt_len):

# import pdb; pdb.set_trace()
# logging.warning(f"Step {i}")
# logging.warning(f"prefixes[:, -1:]: {prefixes[:, -1:]}")
# raise ValueError("Stop here")
# mask all finished hypotheses to exclude them from beam
pad_mask = pad_profile.repeat(1, self.beam_size)

Expand All @@ -448,6 +476,8 @@ def _forward(
)
scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)

# logging.warning(f"prefixes_i: {prefixes_i}")

# for all prefixes ending with <eos> or <pad> replace generated
# continuations with <pad>
prefixes_i = self.pad * pad_mask + prefixes_i * (1 - pad_mask)
Expand Down
38 changes: 36 additions & 2 deletions nemo/collections/asr/modules/transformer/transformer_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,28 @@ def forward(self, input_ids, token_type_ids=None, start_pos=0):
f"Input sequence is longer than maximum allowed sequence length for positional encoding. "
f"Got {seq_length} and {self.max_sequence_length}"
)

# prepare position embedding for asynchronous decoding (canary streaming)
if torch.is_tensor(start_pos):
shift_pos = start_pos.unsqueeze(-1)
start_pos = 0
else:
shift_pos = None

position_ids = torch.arange(
start=start_pos, end=start_pos + seq_length, dtype=torch.long, device=input_ids.device
)
position_ids = position_ids.unsqueeze(0).repeat(input_ids.size(0), 1)

# import pdb; pdb.set_trace()

if torch.is_tensor(shift_pos):
# shift_pos is a tensor, so we need to add it to the position_ids
# and make sure that the resulting position_ids are within the
# range of the positional embedding
position_ids = position_ids + shift_pos
# position_ids = torch.clamp(position_ids, 0, self.max_sequence_length - 1)

token_embeddings = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = token_embeddings + position_embeddings
Expand All @@ -140,7 +157,14 @@ class MultiHeadAttention(nn.Module):
whole layer, but before layer normalization
"""

def __init__(self, hidden_size, num_attention_heads, attn_score_dropout=0.0, attn_layer_dropout=0.0):
def __init__(
self,
hidden_size,
num_attention_heads,
attn_score_dropout=0.0,
attn_layer_dropout=0.0,
return_xatt_scores=False,
):
super().__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
Expand All @@ -160,6 +184,8 @@ def __init__(self, hidden_size, num_attention_heads, attn_score_dropout=0.0, att
self.attn_dropout = nn.Dropout(attn_score_dropout)
self.layer_dropout = nn.Dropout(attn_layer_dropout)

self.return_xatt_scores = return_xatt_scores

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attn_head_size)
x = x.view(*new_x_shape)
Expand All @@ -179,6 +205,9 @@ def forward(self, queries, keys, values, attention_mask):

# for numerical stability we pre-divide query and key by sqrt(sqrt(d))
attention_scores = torch.matmul(query, key.transpose(-1, -2))

# import pdb; pdb.set_trace()

if attention_mask is not None:
attention_scores = attention_scores + attention_mask.to(attention_scores.dtype)
attention_probs = torch.softmax(attention_scores, dim=-1)
Expand All @@ -193,7 +222,12 @@ def forward(self, queries, keys, values, attention_mask):
# output projection
output_states = self.out_projection(context)
output_states = self.layer_dropout(output_states)
return output_states

extra_output = {}
if self.return_xatt_scores:
extra_output['xatt_scores'] = attention_probs

return output_states, extra_output


class PositionWiseFF(nn.Module):
Expand Down
Loading