Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,7 @@ __global__ __launch_bounds__(BLOCK_SIZE) void first_matrix__bitonic_topk_kernel(
queue.done();
float* block_topk_key =
reinterpret_cast<float*>(smem_buf_bytes + smem_result_byte_offset);
int* block_topk_value =
reinterpret_cast<int*>(block_topk_key + sizeof(float) * beam);
int* block_topk_value = reinterpret_cast<int*>(block_topk_key + beam);

queue.store(block_topk_key, block_topk_value);
for (int idx = tx; idx < beam; idx += BLOCK_SIZE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@
@skipIfNoCuda
@skipIfNoCuCtcDecoder
class CUCTCDecoderTest(TempDirMixin, TorchaudioTestCase):
def _get_decoder(self, tokens=None, **kwargs):
def _get_decoder(self, tokens=None, beam_size=5, **kwargs):
from torchaudio.models.decoder import cuda_ctc_decoder

if tokens is None:
tokens = get_asset_path("decoder/tokens.txt")

return cuda_ctc_decoder(
tokens=tokens,
beam_size=5,
beam_size=beam_size,
**kwargs,
)

def _get_emissions(self):
B, T, N = 4, 15, NUM_TOKENS

emissions = torch.rand(B, T, N).cuda()
def _get_emissions(self, batch_size=4, num_frames=15, num_tokens=NUM_TOKENS):
emissions = torch.rand(batch_size, num_frames, num_tokens).cuda()
emissions = torch.nn.functional.log_softmax(emissions, -1)

return emissions
Expand All @@ -47,3 +45,20 @@ def test_shape(self):
decoder = self._get_decoder()
results = decoder(log_probs, encoder_out_lens)
self.assertEqual(len(results), log_probs.shape[0])

def test_large_beam_decode(self):
    """Run the CUDA CTC decoder with a beam of 128 over a 129-token vocabulary.

    Checks that the decoder reports the requested beam size and that it
    returns one result entry per batch item.
    """
    torch.manual_seed(0)
    vocab_size = 129
    wide_beam = 128
    # The vocabulary is deliberately larger than the beam so that the CUDA
    # decoder exercises the full top-k shared-memory result buffers.
    vocab = ["-"]
    vocab.extend(f"token_{i}" for i in range(1, vocab_size))
    emissions = self._get_emissions(batch_size=2, num_frames=4, num_tokens=vocab_size)
    batch, frames = emissions.shape[0], emissions.shape[1]
    # Every sequence in the batch uses all of its frames.
    lengths = torch.full((batch,), frames, dtype=torch.int32, device="cuda")

    decoder = self._get_decoder(tokens=vocab, beam_size=wide_beam)
    self.assertEqual(decoder.beam_size, wide_beam)
    hypotheses = decoder(emissions, lengths)
    # Wait for outstanding GPU work to finish before checking the output.
    torch.cuda.synchronize()

    self.assertEqual(len(hypotheses), batch)