forked from uber-research/PPLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_adapted.py
More file actions
102 lines (83 loc) · 4.02 KB
/
predict_adapted.py
File metadata and controls
102 lines (83 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from transformers import MarianTokenizer, GenerationConfig, LogitsProcessorList
from perturbable_marianmt_model import PerturbableMarianMTModel
from bag_of_words_processor import get_bag_of_words_vectors
from perturb_past import PerturbationArgs
from multiprocessing import Pool
from forced_prefix_logits_processor import ForcedPrefixLogitsProcessor
def _get_bags_of_words(hyperparameters, device, tokenizer):
positive_bags_of_words = hyperparameters.pop("bag_of_words", None)
if type(positive_bags_of_words) is str:
positive_bags_of_words_paths = positive_bags_of_words.split(";")
positive_bags_of_words = None
else:
positive_bags_of_words_paths = None
negative_bags_of_words = hyperparameters.pop("negative_bag_of_words", None)
if type(negative_bags_of_words) is str:
negative_bags_of_words_paths = negative_bags_of_words.split(";")
negative_bags_of_words = None
else:
negative_bags_of_words_paths = None
# set up perturbation args
positive_bow = get_bag_of_words_vectors(
tokenizer,
bag_of_words=positive_bags_of_words,
bag_of_words_paths=positive_bags_of_words_paths,
device=device
) if positive_bags_of_words is not None or positive_bags_of_words_paths is not None else None
negative_bow = get_bag_of_words_vectors(
tokenizer,
bag_of_words=negative_bags_of_words,
bag_of_words_paths=negative_bags_of_words_paths,
device=device
) if negative_bags_of_words is not None or negative_bags_of_words_paths is not None else None
return positive_bow,negative_bow
def chunk(l, n):
    """Yield consecutive slices of *l*, each at most *n* items long."""
    for start in range(0, len(l), n):
        stop = start + n
        yield l[start:stop]
def load_model(hyperparameters, device="cpu"):
    """Load the pretrained MarianMT model named by
    ``hyperparameters["translation_model"]`` plus its tokenizer.

    The model is moved to ``device`` and put in eval mode.

    Returns:
        (model, tokenizer) tuple.
    """
    model_name = hyperparameters["translation_model"]
    # output_hidden_states is required so the perturbation machinery can
    # access intermediate activations.
    translation_model = PerturbableMarianMTModel.from_pretrained(
        model_name,
        output_hidden_states=True,
    )
    translation_model.to(device)
    translation_model.eval()
    return translation_model, MarianTokenizer.from_pretrained(model_name)
def _make_adapted_predictions(inputs):
    """Worker entry point: translate a list of source texts with BoW-perturbed
    generation.

    ``inputs`` is a single tuple ``(source_texts, hyperparameters, device)`` so
    the function can be used directly with ``Pool.map``. The model/tokenizer
    are loaded inside the worker (required for multiprocessing). Returns the
    decoded translations as a flat list of strings.
    """
    source_texts, hyperparameters, device = inputs
    print(f"Using device {device}")
    model, tokenizer = load_model(hyperparameters, device)
    # Pops "bag_of_words"/"negative_bag_of_words" out of hyperparameters so
    # they are not forwarded to PerturbationArgs below.
    positive_bow, negative_bow = _get_bags_of_words(hyperparameters, device, tokenizer)
    args = PerturbationArgs(
        positive_bag_of_words=positive_bow,
        negative_bag_of_words=negative_bow,
        **hyperparameters
    )
    # NOTE(review): these pops run AFTER **hyperparameters was expanded into
    # PerturbationArgs, so "length"/"warmup_steps" also reach PerturbationArgs.
    # Verify against PerturbationArgs' signature that this is intentional.
    max_length = hyperparameters.pop("length", 100)
    warmup_steps = hyperparameters.pop("warmup_steps", 0)
    if warmup_steps > 0:
        # Force the first warmup_steps decoded tokens to be pad tokens.
        logits_processor = LogitsProcessorList([ForcedPrefixLogitsProcessor([tokenizer.pad_token_id] * warmup_steps)])
    else:
        logits_processor = None
    # Greedy decoding (single beam, no sampling).
    generation_config = GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=max_length-1)
    predictions = []
    # Each element of source_texts is translated in its own generate() call.
    for texts in source_texts:
        encoded_texts = tokenizer(texts, padding=True, return_tensors="pt")
        input_ids = encoded_texts.input_ids.to(device) # [batch_size, max_seq_len]
        attention_mask = encoded_texts.attention_mask.to(device) # [batch_size, max_seq_len]
        results = model.generate(args, input_ids, attention_mask=attention_mask, generation_config=generation_config, logits_processor=logits_processor)
        decoded_results = tokenizer.batch_decode(results, skip_special_tokens=True)
        predictions.extend(decoded_results)
        print(decoded_results)
    return predictions
def make_adapted_predictions(source_texts, hyperparameters, batch_size=50, worker_count=4, device="cpu"):
    """Translate ``source_texts``, optionally fanning the work out across
    ``worker_count`` processes in batches of ``batch_size``.

    With ``worker_count == 1`` everything runs inline in this process.
    Returns a flat list of decoded translations.
    """
    if worker_count == 1:
        # Single-process fast path: no batching, run directly.
        return _make_adapted_predictions((source_texts, hyperparameters, device))
    jobs = [
        (batch, hyperparameters, device)
        for batch in chunk(source_texts, batch_size)
    ]
    with Pool(processes=worker_count) as pool:
        per_batch = pool.map(_make_adapted_predictions, jobs)
        print("FINAL RESULTS:", per_batch)
        # Flatten the per-batch result lists into a single list.
        return [prediction for batch in per_batch for prediction in batch]