diff --git a/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu b/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu index 92edcc83c1..155a25fa9b 100644 --- a/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu +++ b/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu @@ -294,8 +294,7 @@ __global__ __launch_bounds__(BLOCK_SIZE) void first_matrix__bitonic_topk_kernel( queue.done(); float* block_topk_key = reinterpret_cast<float*>(smem_buf_bytes + smem_result_byte_offset); - int* block_topk_value = - reinterpret_cast<int*>(block_topk_key + sizeof(float) * beam); + int* block_topk_value = reinterpret_cast<int*>(block_topk_key + beam); queue.store(block_topk_key, block_topk_value); for (int idx = tx; idx < beam; idx += BLOCK_SIZE) { diff --git a/test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py b/test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py index 88f0216f33..1d22fd56b9 100644 --- a/test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py +++ b/test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py @@ -13,7 +13,7 @@ @skipIfNoCuda @skipIfNoCuCtcDecoder class CUCTCDecoderTest(TempDirMixin, TorchaudioTestCase): - def _get_decoder(self, tokens=None, **kwargs): + def _get_decoder(self, tokens=None, beam_size=5, **kwargs): from torchaudio.models.decoder import cuda_ctc_decoder if tokens is None: @@ -21,14 +21,12 @@ def _get_decoder(self, tokens=None, **kwargs): return cuda_ctc_decoder( tokens=tokens, - beam_size=5, + beam_size=beam_size, **kwargs, ) - def _get_emissions(self): - B, T, N = 4, 15, NUM_TOKENS - - emissions = torch.rand(B, T, N).cuda() + def _get_emissions(self, batch_size=4, num_frames=15, num_tokens=NUM_TOKENS): + emissions = torch.rand(batch_size, num_frames, num_tokens).cuda() emissions = torch.nn.functional.log_softmax(emissions, -1) return emissions @@ -47,3 +45,20 @@ def test_shape(self): decoder = self._get_decoder() results = decoder(log_probs, encoder_out_lens) 
self.assertEqual(len(results), log_probs.shape[0]) + + def test_large_beam_decode(self): + torch.manual_seed(0) + num_tokens = 129 + beam_size = 128 + # Keep the vocabulary larger than beam_size so the CUDA decoder + # exercises the full top-k shared-memory result buffers. + tokens = ["-"] + [f"token_{idx}" for idx in range(1, num_tokens)] + log_probs = self._get_emissions(batch_size=2, num_frames=4, num_tokens=num_tokens) + encoder_out_lens = torch.full((log_probs.shape[0],), log_probs.shape[1], dtype=torch.int32, device="cuda") + + decoder = self._get_decoder(tokens=tokens, beam_size=beam_size) + self.assertEqual(decoder.beam_size, beam_size) + results = decoder(log_probs, encoder_out_lens) + torch.cuda.synchronize() + + self.assertEqual(len(results), log_probs.shape[0])