# utils.py
import os
import string

import numpy as np
import torch
from torch.nn import functional as F

BASE_DIR = '..'
# Maps a short model name to its Hugging Face identifier, the directory name used
# for saving results, and the character its tokenizer prepends to word-initial
# tokens ('▁' for SentencePiece-style tokenizers, 'Ġ' for BPE-style ones).
model_to_path_dict = {
    'Llama-2-13b-chat-hf': {
        'hf_name': "meta-llama/Llama-2-13b-chat-hf",
        'save_dir_name': "llama2-chat",
        'initial_char': '▁',
    },
    'gemma-7b': {
        'hf_name': "google/gemma-7b",
        'save_dir_name': 'gemma-7b',
        'initial_char': None,
    },
    'Llama3-8b-instruct': {
        'hf_name': "meta-llama/Meta-Llama-3-8B-Instruct",
        'save_dir_name': 'llama3-8b-instruct',
        'initial_char': 'Ġ',
    },
    'mistral-7b-instruct': {
        'hf_name': "mistralai/Mistral-7B-Instruct-v0.3",
        'save_dir_name': 'mistral-7b-instruct',
        'initial_char': '▁',
    },
    'Llama3.2-3b-instruct': {
        'hf_name': "meta-llama/Llama-3.2-3B-Instruct",
        'save_dir_name': 'llama3_2-3b-instruct',
        'initial_char': 'Ġ',
    },
    'Llama3-8b': {
        'hf_name': "meta-llama/Meta-Llama-3-8B",
        'save_dir_name': 'llama3-8b',
        'initial_char': 'Ġ',
    },
    'gemma-2-9b-it': {
        'hf_name': "google/gemma-2-9b-it",
        'save_dir_name': 'gemma-2-9b-it',
        'initial_char': '▁',
    },
    'Llama3.2-3b-instruct_finetuned': {
        'hf_name': os.path.join(BASE_DIR, 'training', 'results', 'checkpoint-375'),
        'save_dir_name': 'Llama3.2-3b-instruct_finetuned',
        'initial_char': 'Ġ',
    },
}
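
# Example usage (hypothetical; assumes the `transformers` library is installed and
# you have access to the gated model weights on the Hugging Face Hub):
#   from transformers import AutoTokenizer
#   cfg = model_to_path_dict['Llama3-8b-instruct']
#   tokenizer = AutoTokenizer.from_pretrained(cfg['hf_name'])
#   tokenizer.tokenize(' hello world')  # word-initial tokens start with cfg['initial_char'] ('Ġ')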
def check_word_count(original_txt, tokenized_txt, initial_char='▁'):
    # Sanity check: the number of word-initial tokens must match the number of
    # whitespace-separated words in the original transcript.
    count = 0
    for t in tokenized_txt:
        if initial_char in t:
            count += 1
    assert count == len(original_txt.split(' ')), 'Word count in original transcript and tokenized txt must be the same'
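
# Example (hypothetical values; '▁'-prefixed tokens follow the SentencePiece convention):
#   check_word_count('the cat sat', ['▁the', '▁cat', '▁s', 'at'])  # 3 word-initial tokens == 3 words, passes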
def calculate_cross_entropy(tokens, logits, base2=False):
    # torch's cross_entropy uses the natural log, i.e. -log(softmax(logit of target token))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU if no GPU
    ce = F.cross_entropy(logits.to(device), tokens.to(device), reduction='none')
    if base2:
        ce = ce / torch.log(torch.tensor(2.0, device=device))  # convert nats to bits
    ce = ce.to('cpu')
    return ce
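
# Example (a minimal sketch with random logits; shapes follow F.cross_entropy's
# (N, C) logits / (N,) targets convention):
#   vocab_size = 10
#   logits = torch.randn(5, vocab_size)
#   tokens = torch.randint(0, vocab_size, (5,))
#   surprisal_bits = calculate_cross_entropy(tokens, logits, base2=True)  # per-token surprisal in bits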
def normalize_entropy(tokens_lower, entropy_lower, mean_token_entropy, tokenizer=None, verbose=True):
    # Normalize each token's entropy by the mean entropy of that token type.
    # `tokenizer` is optional and only used for the verbose diagnostic message.
    assert tokens_lower.shape == entropy_lower.shape
    normalized_entropy = np.zeros(entropy_lower.shape[0]) - 1
    for i, uncorrected_entropy in enumerate(entropy_lower):
        curr_token = int(tokens_lower[i])
        if curr_token in mean_token_entropy:
            mean_curr_token_entropy = mean_token_entropy[curr_token]
            normalized_entropy[i] = uncorrected_entropy / mean_curr_token_entropy
            if normalized_entropy[i] == np.inf or np.isnan(normalized_entropy[i]):
                print(i)  # this shouldn't happen unless the mean entropy is 0 or NaN
                normalized_entropy[i] = 1
        else:
            # fill in 1, i.e. treat this occurrence as having the mean entropy
            # of all occurrences of this token
            normalized_entropy[i] = 1
            if verbose:
                decoded = tokenizer.decode(curr_token) if tokenizer is not None else '?'
                print('%d - \'%s\' not in mean_token_entropy dict' % (curr_token, decoded))
    assert sum(normalized_entropy == -1) == 0
    return normalized_entropy
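
# Example (hypothetical inputs; mean_token_entropy maps token id -> mean entropy
# of that token type across a corpus):
#   tokens_lower = np.array([5, 7])
#   entropy_lower = np.array([2.0, 3.0])
#   normalize_entropy(tokens_lower, entropy_lower, {5: 4.0, 7: 3.0}, verbose=False)
#   # -> array([0.5, 1.0])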
def segmentation_to_word_list(human_output):
    '''
    Input: list of strings, each string is an event.
    Output: list of strings, each string is a word, with a '\n' appended to the
    word if the human segmented immediately after that point.
    '''
    human_output_newline = [s[:-1] if s[-1] == ' ' else s for s in human_output]  # strip trailing whitespace
    human_output_newline = [s[1:] if s[0] == ' ' else s for s in human_output_newline]  # strip leading whitespace
    human_output_newline = [s + '\n' for s in human_output_newline]  # mark each event boundary with '\n'
    human_output_split = ' '.join(human_output_newline).split(' ')
    return human_output_split
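
# Example: two events, the second starting at 'then':
#   segmentation_to_word_list(['the cat sat ', 'then it left'])
#   # -> ['the', 'cat', 'sat\n', 'then', 'it', 'left\n']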
def get_segmentation_indices(tokenized_txt, segmented_word_list, original_txt, punctuated=False, initial_char='▁'):
    '''
    Task: find where the human segmentation boundaries fall in the tokenized text.
    Want: indices of tokens where a segmentation happened immediately after that token.
    Input:
        tokenized_txt: list generated by tokenizer.tokenize
        segmented_word_list: list of words; output from segmentation_to_word_list
        initial_char: character used by tokenizers in front of word initials; 'Ġ' for llama-3, '▁' for llama-2
    Output:
        segmentation_indices_in_tokens: list of integers; humans segmented right after each of these
        tokens in tokenized_txt (i.e. each of these token indices still belongs to the previous event)
    '''
    words_in_tokenized_txt = []  # tokenized text reassembled into whole words
    curr_word = ''
    segmentation_indices_in_tokens = []  # indices of tokens; humans segmented right after each of these tokens
    # characters to strip when checking against the punctuated transcript
    strip_chars = string.punctuation + '-–'
    strip_chars = strip_chars.translate(str.maketrans('', '', '\''))  # don't strip ' (apostrophe)
    for i, curr_token in enumerate(tokenized_txt):
        if initial_char in curr_token:  # this must be the start of a new word
            curr_word = curr_token
        else:
            curr_word += curr_token
        if i + 1 < len(tokenized_txt):
            # condition 1: this word is finished (the next token starts a new word)
            # condition 2: gemma quirk: with two consecutive spaces, the spaces are tokenized
            # as a '▁▁' token but no '▁' is prepended to the next word
            if initial_char in tokenized_txt[i + 1] or (initial_char == '▁' and curr_token == '▁▁'):
                words_in_tokenized_txt.append(curr_word)
                corresponding_word = segmented_word_list[len(words_in_tokenized_txt) - 1]
                if not punctuated:
                    if curr_word.translate(str.maketrans('', '', initial_char)) != corresponding_word.translate(str.maketrans('', '', '\n')):
                        print(curr_word, corresponding_word)
                        print(segmented_word_list[:len(words_in_tokenized_txt) - 1])
                    assert curr_word.translate(str.maketrans('', '', initial_char)) == corresponding_word.translate(str.maketrans('', '', '\n'))
                else:
                    corresponding_word_strip = corresponding_word.translate(str.maketrans('', '', '\n'))
                    corresponding_word_strip = corresponding_word_strip.translate(str.maketrans('', '', strip_chars)).lower()
                    if curr_word.translate(str.maketrans('', '', initial_char)).lower().translate(str.maketrans('', '', strip_chars)) != corresponding_word_strip:
                        print(curr_word, corresponding_word)
                        print(curr_word.translate(str.maketrans('', '', initial_char)), corresponding_word_strip)
                    assert curr_word.translate(str.maketrans('', '', initial_char)).lower().translate(str.maketrans('', '', strip_chars)) == corresponding_word_strip
                if '\n' in corresponding_word:
                    segmentation_indices_in_tokens.append(i)
        else:
            # last token: close out the final word with the same checks
            words_in_tokenized_txt.append(curr_word)
            corresponding_word = segmented_word_list[len(words_in_tokenized_txt) - 1]
            if not punctuated:
                if curr_word.translate(str.maketrans('', '', initial_char)) != corresponding_word.translate(str.maketrans('', '', '\n')):
                    print(curr_word, corresponding_word)
                assert curr_word.translate(str.maketrans('', '', initial_char)) == corresponding_word.translate(str.maketrans('', '', '\n'))
            else:
                corresponding_word_strip = corresponding_word.translate(str.maketrans('', '', '\n'))
                corresponding_word_strip = corresponding_word_strip.translate(str.maketrans('', '', strip_chars)).lower()
                if curr_word.translate(str.maketrans('', '', initial_char)).lower().translate(str.maketrans('', '', strip_chars)) != corresponding_word_strip:
                    print(curr_word, corresponding_word)
                assert curr_word.translate(str.maketrans('', '', initial_char)).lower().translate(str.maketrans('', '', strip_chars)) == corresponding_word_strip
            if '\n' in corresponding_word:
                segmentation_indices_in_tokens.append(i)
    len_original_txt = len(original_txt.split(' '))
    if original_txt.split(' ')[0] == '':
        # for stories with a leading space, tokenizers won't tokenize the leading space
        # separately, but .split(' ') produces a leading empty string
        len_original_txt -= 1
    if len(words_in_tokenized_txt) != len_original_txt:
        print(len(words_in_tokenized_txt), len_original_txt)
        for i in range(len(words_in_tokenized_txt)):
            tokenized_txt_word = words_in_tokenized_txt[i]
            original_word = original_txt.split(' ')[i]
            if tokenized_txt_word.translate(str.maketrans('', '', initial_char)) != original_word.translate(str.maketrans('', '', '\n')):
                print(i, tokenized_txt_word, original_word)
        print(words_in_tokenized_txt[-1], original_txt.split(' ')[-1])
    assert len(words_in_tokenized_txt) == len_original_txt
    return segmentation_indices_in_tokens
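
# Example (a minimal sketch; tokens follow the '▁' word-initial convention):
#   tokenized = ['▁the', '▁cat', '▁s', 'at', '▁then', '▁it', '▁left']
#   words = segmentation_to_word_list(['the cat sat ', 'then it left'])
#   get_segmentation_indices(tokenized, words, 'the cat sat then it left')
#   # -> [3, 6]  ('at' closes 'sat\n', '▁left' closes 'left\n')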
def find_subtensor_indices(a, b):
    # Slide a window over b to find where tensor a occurs as a contiguous sub-tensor.
    for i in range(b.size(0) - a.size(0) + 1):
        if torch.equal(b[i:i + a.size(0)], a):
            start_index = i
            end_index = i + a.size(0)
            return start_index, end_index
    return None  # a is not a sub-tensor of b
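
# Example:
#   a = torch.tensor([3, 4])
#   b = torch.tensor([1, 2, 3, 4, 5])
#   find_subtensor_indices(a, b)  # -> (2, 4), i.e. b[2:4] == a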
def find_event(tokenizer, story_tokens, event_txt, start_idx):
    '''Find the end of an event in the token sequence by decoding progressively longer
    spans from start_idx until the decoded string matches event_txt. Returns the
    exclusive end index, i.e. tokenizer.decode(story_tokens[start_idx:end]) == event_txt.'''
    search_tokens = story_tokens[start_idx:]
    i = 1
    decoded = ''
    while i < search_tokens.shape[0] and decoded != event_txt:
        decoded = tokenizer.decode(search_tokens[:i])
        i += 1
    return i + start_idx - 1
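
# Example (hypothetical; assumes `tok` is a Hugging Face tokenizer and `story_tokens`
# a 1-D tensor of token ids whose decoding starts with the event text):
#   end_idx = find_event(tok, story_tokens, 'Once upon a time,', start_idx=0)
#   # tok.decode(story_tokens[0:end_idx]) == 'Once upon a time,'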