-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
100 lines (85 loc) · 4.15 KB
/
models.py
File metadata and controls
100 lines (85 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# this file contains various models used in the training and inference scripts
import torch
from transformers import AutoModelForSequenceClassification
class HuggingfaceModel(torch.nn.Module):
    """Thin wrapper around a Hugging Face model that exposes only the logits.

    The wrapped model is called with the input dict unpacked as keyword
    arguments, and the ``"logits"`` entry of its output mapping is returned.
    """

    def __init__(self, model_huggingface):
        super().__init__()
        # Kept under this exact attribute name: the ensemble classes below
        # reach into ``.model_huggingface`` after loading a state dict.
        self.model_huggingface = model_huggingface

    def forward(self, x):
        # x is a dict of model inputs (e.g. input_ids, attention_mask).
        result = self.model_huggingface(**x)
        return result["logits"]
class BartModelForEnsemble(torch.nn.Module):
    """BART backbone restored from a fine-tuned classification checkpoint.

    Rebuilds the fine-tuned classifier (``facebook/bart-base``, 3 labels),
    loads ``model_state_dict`` into it, then keeps only the bare BART model
    without the classification head.  ``forward`` returns the last hidden
    state at every EOS-token position of the input.
    """

    def __init__(self, model_state_dict, tokenizer=None):
        super().__init__()
        backbone = AutoModelForSequenceClassification.from_pretrained(
            "facebook/bart-base", num_labels=3
        )
        # The checkpoint may have been trained with extra special tokens,
        # so the embedding matrix must match the tokenizer's vocab size.
        if tokenizer is not None:
            backbone.resize_token_embeddings(len(tokenizer))
        wrapper = HuggingfaceModel(backbone)
        wrapper.load_state_dict(model_state_dict)
        # Drop the classification head; keep only the underlying BART model.
        self.model = wrapper.model_huggingface.model

    def forward(self, x):
        outputs = self.model(**x)
        # Positions where the input token is EOS — BART-style pooling uses
        # the EOS hidden state as the sequence representation.
        eos_mask = x["input_ids"].eq(self.model.config.eos_token_id)
        return outputs["last_hidden_state"][eos_mask]
class BertModelForEnsemble(torch.nn.Module):
    """BERT backbone restored from a fine-tuned classification checkpoint.

    Rebuilds the fine-tuned classifier (``bert-base-uncased``, 2 labels),
    loads ``model_state_dict`` into it, then keeps only the bare BERT
    encoder.  ``forward`` returns the last hidden state of the [CLS] token.
    """

    def __init__(self, model_state_dict, tokenizer=None):
        super().__init__()
        backbone = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=2
        )
        # Match the embedding matrix to any extra tokens the checkpoint used.
        if tokenizer is not None:
            backbone.resize_token_embeddings(len(tokenizer))
        wrapper = HuggingfaceModel(backbone)
        wrapper.load_state_dict(model_state_dict)
        # Keep only the encoder; the classification head is discarded.
        self.model = wrapper.model_huggingface.bert

    def forward(self, x):
        outputs = self.model(**x)
        # [CLS] is the first token — its hidden state is the pooled feature.
        return outputs["last_hidden_state"][:, 0]
class BertweetModelForEnsemble(torch.nn.Module):
    """BERTweet backbone restored from a fine-tuned classification checkpoint.

    Rebuilds the fine-tuned classifier (``vinai/bertweet-base``, 2 labels),
    loads ``model_state_dict`` into it, then keeps only the bare RoBERTa
    encoder.  ``forward`` returns the last hidden state of the first token.
    """

    def __init__(self, model_state_dict, tokenizer=None):
        super().__init__()
        backbone = AutoModelForSequenceClassification.from_pretrained(
            "vinai/bertweet-base", num_labels=2
        )
        # Match the embedding matrix to any extra tokens the checkpoint used.
        if tokenizer is not None:
            backbone.resize_token_embeddings(len(tokenizer))
        wrapper = HuggingfaceModel(backbone)
        wrapper.load_state_dict(model_state_dict)
        # BERTweet is RoBERTa-based, hence the ``.roberta`` submodule.
        self.model = wrapper.model_huggingface.roberta

    def forward(self, x):
        outputs = self.model(**x)
        # First token (<s>) hidden state serves as the sequence feature.
        return outputs["last_hidden_state"][:, 0]
class XLNetModelForEnsemble(torch.nn.Module):
    """XLNet backbone restored from a fine-tuned classification checkpoint.

    Rebuilds the fine-tuned classifier (``xlnet-base-cased``, 2 labels),
    loads ``model_state_dict`` into it, then keeps only the bare XLNet
    transformer.  ``forward`` returns the last hidden state of the final
    token position (XLNet places its summary token at the sequence end).
    """

    def __init__(self, model_state_dict, tokenizer=None):
        super().__init__()
        backbone = AutoModelForSequenceClassification.from_pretrained(
            "xlnet-base-cased", num_labels=2
        )
        # Match the embedding matrix to any extra tokens the checkpoint used.
        if tokenizer is not None:
            backbone.resize_token_embeddings(len(tokenizer))
        wrapper = HuggingfaceModel(backbone)
        wrapper.load_state_dict(model_state_dict)
        # XLNet exposes its encoder under ``.transformer``.
        self.model = wrapper.model_huggingface.transformer

    def forward(self, x):
        outputs = self.model(**x)
        # Last position, not [CLS]: XLNet pools on the final token.
        return outputs["last_hidden_state"][:, -1]
class EnsembleModel(torch.nn.Module):
    """Late-fusion ensemble: concatenate per-model features, then classify.

    Each sub-model ``i`` receives its own input ``x[i]`` and must produce a
    ``(batch, size_hidden_state)`` tensor; the concatenated features are
    mapped to 2 output logits by a single linear fusion layer.
    """

    def __init__(self, list_models, freeze_models=False, size_hidden_state=2):
        """
        Args:
            list_models: sub-models, one per input stream, in the same order
                as the inputs passed to ``forward``.
            freeze_models: if True, exclude all sub-model parameters from
                training so only the fusion layer remains trainable.
            size_hidden_state: feature width each sub-model emits.
        """
        super().__init__()
        self.list_models = torch.nn.ModuleList(list_models)
        self.layer_linear = torch.nn.Linear(
            in_features=len(list_models) * size_hidden_state,
            out_features=2,
        )
        if freeze_models:
            # requires_grad_(False) on the container freezes every
            # sub-model parameter in one pass (replaces the nested loop).
            self.list_models.requires_grad_(False)

    def forward(self, x):
        """Run each sub-model on its own input and fuse the features.

        Args:
            x: sequence of per-model inputs, aligned with ``list_models``.

        Returns:
            ``(batch, 2)`` tensor of fused classification logits.
        """
        # Pair each sub-model with its input directly instead of the
        # unidiomatic range(len(...)) index loop.
        list_logits = [model(inputs) for model, inputs in zip(self.list_models, x)]
        # torch.cat's canonical keyword is ``dim`` (``axis`` is a NumPy alias).
        features = torch.cat(list_logits, dim=1)
        return self.layer_linear(features)