-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtoken_utils.py
More file actions
129 lines (107 loc) · 4.31 KB
/
token_utils.py
File metadata and controls
129 lines (107 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import pickle
from dataclasses import dataclass
from typing import Optional
from playground.chair.chair import CHAIR, WordNetLemmatizer, nltk, wordnet
# Load the pre-built CHAIR evaluator from its pickled artifact.
# NOTE(review): pickle.load can execute arbitrary code from the file — this is
# only acceptable because chair.pkl is a trusted, locally generated artifact.
with open("./playground/chair/chair.pkl", "rb") as _chair_file:
    evaluator: CHAIR = pickle.load(_chair_file)
# Validate with an explicit raise rather than `assert` (asserts are stripped
# under `python -O`, which would silently skip this check).
if type(evaluator) is not CHAIR:
    raise TypeError(
        f"chair.pkl did not contain a CHAIR instance (got {type(evaluator).__name__})"
    )
@dataclass
class AlignedTokens:
    """Character-level alignment of a token sequence against its source caption.

    All four lists/strings describe the same caption; `start`, `end`, and
    `tokens` are parallel (same length), with `None` bounds marking tokens
    that could not be located in the caption text.
    """

    # Per-token start offset into `caption`; None when the token was not found.
    start: list[Optional[int]]
    # Per-token exclusive end offset into `caption`; None when not found.
    end: list[Optional[int]]
    # The token strings, in order.
    tokens: list[str]
    # The original caption text that the offsets index into.
    caption: str
def get_tokens_position(
    input_ids, qs, tokenizer
) -> list[tuple[Optional[int], Optional[int]]]:
    """Locate each decoded token of ``input_ids`` inside the string ``qs``.

    Tokens are matched greedily left-to-right: every search resumes from the
    end of the previous successful match.  A token that cannot be found in
    ``qs`` yields ``(None, None)``; otherwise its ``(start, end)`` character
    span is recorded.  One span is returned per input token, in order.
    """
    spans: list[tuple[Optional[int], Optional[int]]] = []
    cursor = 0
    for token_id in input_ids:
        piece = tokenizer.decode(
            token_id, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        # LLaMA tokenizers render a newline as the byte-fallback token "<0x0A>".
        if piece == "<0x0A>":
            piece = "\n"
        found_at = qs.find(piece, cursor)
        if found_at < 0:
            spans.append((None, None))
        else:
            cursor = found_at + len(piece)
            spans.append((found_at, cursor))
    return spans
def has_overlap(a: tuple[int, int], b: tuple[int, int]) -> bool:
    """Return True when the half-open intervals ``a`` and ``b`` intersect.

    Intervals are ``(start, end)`` with ``end`` exclusive, so two intervals
    that merely touch (one's end equals the other's start) do NOT overlap.
    """
    a_start, a_end = a
    b_start, b_end = b
    # Disjoint iff one interval ends at or before the other begins.
    return not (a_end <= b_start or b_end <= a_start)
def get_overlap_tokens(
    token_slices: list[tuple[Optional[int], Optional[int]]],
    query_slices: list[tuple[int, int]],
) -> list[int]:
    """Indices of token spans that intersect at least one query span.

    Spans are half-open ``[start, end)`` character intervals.  Tokens whose
    bounds are ``None`` (unaligned tokens) are skipped.
    """
    overlapping: list[int] = []
    for idx, (tok_start, tok_end) in enumerate(token_slices):
        if tok_start is None or tok_end is None:
            continue
        # Half-open interval intersection test (inlined has_overlap).
        if any(
            tok_start < q_end and q_start < tok_end
            for q_start, q_end in query_slices
        ):
            overlapping.append(idx)
    return overlapping
def get_token_indices(caption: str, tokens: list[str]) -> AlignedTokens:
    """Align each token string to its character span inside ``caption``.

    Matching is greedy and left-to-right: each ``str.find`` resumes from the
    end of the previous successful match.  Tokens that cannot be found get
    ``None`` for both bounds.  The resulting lists are parallel to ``tokens``.
    """
    starts: list[Optional[int]] = []
    ends: list[Optional[int]] = []
    cursor = 0
    for tok in tokens:
        hit = caption.find(tok, cursor)
        if hit < 0:
            starts.append(None)
            ends.append(None)
        else:
            cursor = hit + len(tok)
            starts.append(hit)
            ends.append(cursor)
    assert len(starts) == len(ends) == len(tokens)
    return AlignedTokens(start=starts, end=ends, tokens=tokens, caption=caption)
def new_caption_to_words(self: CHAIR, caption: str):
    """Tokenize *caption* and extract the MSCOCO objects it mentions.

    Replacement for ``CHAIR.caption_to_words`` that additionally keeps the
    surface (pre-lemmatization) form of every word — the lines marked
    "AllPath: Added".

    Returns a 4-tuple ``(words, node_words, idxs, double_words)``:
      * ``words``        -- lemmatized words that name MSCOCO objects;
      * ``node_words``   -- canonical object name for each entry of ``words``
                            (looked up in ``self.inverse_synonym_dict``);
      * ``idxs``         -- index of each object word within the word list
                            (see NOTE(review) below about which list);
      * ``double_words`` -- the full surface-form word list after double-word
                            merging.
    """
    # Adapted from https://github.com/Maxlinn/CHAIR-metric-standalone/blob/main/chair.py
    # standard preprocessing: lowercase, tokenize, POS-tag, then lemmatize
    # (falling back to the noun POS when the tag has no WordNet equivalent).
    words = nltk.word_tokenize(caption.lower())
    tagged_sent = nltk.pos_tag(words)
    lemmas_sent = []
    wnl = WordNetLemmatizer()
    for tag in tagged_sent:
        wordnet_pos = self.get_wordnet_pos(tag[1]) or wordnet.NOUN
        lemmas_sent.append(wnl.lemmatize(tag[0], pos=wordnet_pos))
    # words = [singularize(w) for w in words]
    # Keep the surface forms alongside the lemmas; both lists stay parallel.
    origin_words = words
    words = lemmas_sent
    # replace double words
    # Merge known two-word phrases (keys of self.double_word_dict) into a
    # single entry, mirroring the merge on the surface forms so the two
    # lists remain aligned index-for-index.
    i = 0
    double_words = []
    origin_double_words = []  # AllPath: Added
    idxs = []
    while i < len(words):
        idxs.append(i)
        double_word = " ".join(words[i : i + 2])
        origin_double_word = " ".join(origin_words[i : i + 2])  # AllPath: Added
        if double_word in self.double_word_dict:
            double_words.append(self.double_word_dict[double_word])
            origin_double_words.append(origin_double_word)  # AllPath: Added
            i += 2
        else:
            double_words.append(words[i])
            origin_double_words.append(origin_words[i])  # AllPath: Added
            i += 1
    words = double_words
    double_words = origin_double_words  # AllPath: Added
    # toilet seat is not chair (sentences like "the seat of the toilet" will fire for "chair" if we do not include this line)
    if ("toilet" in words) & ("seat" in words):
        words = [word for word in words if word != "seat"]
    # get synonyms for all words in the caption
    # TODO: check what if then?
    # NOTE(review): this reassignment discards the idxs accumulated in the
    # while-loop above (that accumulation is effectively dead code), and if
    # the toilet/seat filter removed entries, these indices refer to the
    # filtered list rather than to double_words — confirm this is intended.
    idxs = [idx for idx, word in enumerate(words) if word in set(self.mscoco_objects)]
    # idxs = [idxs[idx] for idx, word in enumerate(words) if word in set(self.mscoco_objects)]
    words = [word for word in words if word in set(self.mscoco_objects)]
    node_words = []
    for word in words:
        node_words.append(self.inverse_synonym_dict[word])
    # return all the MSCOCO objects in the caption
    return words, node_words, idxs, double_words