diff --git a/decode_beam.py b/decode_beam.py index c2a71ca..48f67c0 100644 --- a/decode_beam.py +++ b/decode_beam.py @@ -64,16 +64,20 @@ def eval(self, alpha=1.0): decoder = DecoderRNN() -def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None): +def beam_decode(target_tensor, + decoder_hiddens, + encoder_outputs=None, + topk=1, + beam_width=10, + SOS_token=0, + EOS_token=10): ''' :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence + :param topk: how many sentences to generate :return: decoded_batch ''' - - beam_width = 10 - topk = 1 # how many sentence do you want to generate decoded_batch = [] # decoding goes sentence by sentence