From 85a56758681b3c8fd864af18cc8aaba855f7c077 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Fri, 20 Jan 2023 16:33:30 +0800 Subject: [PATCH] add arguments for beam_search func --- decode_beam.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/decode_beam.py b/decode_beam.py index c2a71ca..48f67c0 100644 --- a/decode_beam.py +++ b/decode_beam.py @@ -64,16 +64,20 @@ def eval(self, alpha=1.0): decoder = DecoderRNN() -def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None): +def beam_decode(target_tensor, + decoder_hiddens, + encoder_outputs=None, + topk=1, + beam_width=10, + SOS_token=0, + EOS_token=10): ''' :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence + :param topk: how many sentence do you want to generate :return: decoded_batch ''' - - beam_width = 10 - topk = 1 # how many sentence do you want to generate decoded_batch = [] # decoding goes sentence by sentence