-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoder_implementations.py
More file actions
58 lines (43 loc) · 2.38 KB
/
decoder_implementations.py
File metadata and controls
58 lines (43 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# SPDX-License-Identifier: Apache-2.0
# © (2023) ETH Zurich and other contributors, see AUTHORS.txt for details
from typing import Dict
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from entropy_based_sampling.decoder_base import BaseDynamicEnsembleDecoder
class BartDynamicEnsembleDecoder(BaseDynamicEnsembleDecoder):
    """Dynamic ensemble decoder specialized for BART and MBART models.

    Fills in the three model-specific hooks of ``BaseDynamicEnsembleDecoder``:
    encoding the source sequence, reading the batch size off the encoder
    outputs, and running a single decoder forward pass from a saved
    generation state.
    """

    def __init__(self, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer,
                 decoding_hyperparams: Dict, score_reduce_type: str, **kwargs):
        """Construct the decoder; all setup is delegated to the base class."""
        super().__init__(model, tokenizer, decoding_hyperparams,
                         score_reduce_type, **kwargs)

    def compute_encoder_outputs(self, input_ids: torch.Tensor,
                                attention_mask: torch.Tensor) -> BaseModelOutput:
        """Encode the source tokens once so decoding steps can reuse them."""
        return self.encoder(input_ids, attention_mask=attention_mask,
                            return_dict=True)

    def get_encoder_outputs_batch_size(self, encoder_outputs: BaseModelOutput) -> int:
        """Return the batch size (leading dim of the encoder hidden states)."""
        return encoder_outputs.last_hidden_state.size(0)

    def outputs_from_state(self, state: Dict) -> Seq2SeqLMOutput:
        """Run one decoder forward pass from a saved generation state.

        ``state['past']`` is expected to hold either just the encoder
        outputs (length 1 — no decoder cache produced yet) or the pair
        ``(encoder_outputs, past_key_values)``.
        NOTE(review): passes ``past=`` to ``prepare_inputs_for_generation``,
        which matches older ``transformers`` releases — confirm against the
        pinned dependency version.
        """
        past = state['past']
        if len(past) == 1:
            # First decoding step: the state carries only encoder outputs.
            encoder_outputs, past_key_values = past, None
        else:
            encoder_outputs, past_key_values = past[0], past[1]
        model = state['model']
        model_inputs = model.prepare_inputs_for_generation(
            state['input_ids'],
            past=past_key_values,
            attention_mask=state['attention_mask'],
            use_cache=True,
            encoder_outputs=encoder_outputs,
        )
        return model(**model_inputs)