-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
127 lines (106 loc) · 3.84 KB
/
utils.py
File metadata and controls
127 lines (106 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import pandas as pd
from tqdm import tqdm
import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
from constants1 import *
def infer_sentences(model, sentences, start):
    """Run inference on a chunk of sentences, keyed by their original position.

    Args:
        model (POSTagger): model used for inference; only its
            ``inference(sentence)`` method is called.
        sentences (Iterable): sentences to infer by a single process. Any
            iterable works (list, tuple, generator); each element is passed
            to ``model.inference`` unchanged.
        start (int): index of the first sentence of ``sentences`` in the
            original list of sentences.

    Returns:
        dict: maps ``start + i`` to the model's prediction for the i-th
        sentence of ``sentences``. Empty when ``sentences`` is empty.
    """
    return {
        start + offset: model.inference(sentence)
        for offset, sentence in enumerate(sentences)
    }
def compute_prob(model, sentences, tags, start):
    """Score each (sentence, tag sequence) pair, keyed by original position.

    Args:
        model (POSTagger): model used for scoring; only its
            ``sequence_probability(sentence, tags)`` method is called.
        sentences (Iterable): sentences to score; any iterable works.
        tags (Sequence): tag sequences, indexable by position, where
            ``tags[i]`` belongs to the i-th sentence.
        start (int): index of the first sentence of ``sentences`` in the
            original list of sentences.

    Returns:
        dict: maps ``start + i`` to the value returned by
        ``model.sequence_probability`` for the i-th sentence/tag pair.

    Raises:
        IndexError: if ``tags`` is a list/tuple with fewer entries than
            ``sentences``.
    """
    res = {}
    for offset, sentence in enumerate(sentences):
        # Index ``tags`` explicitly instead of zipping so that a tag list
        # that is too short still fails loudly rather than being truncated.
        res[start + offset] = model.sequence_probability(sentence, tags[offset])
    return res
#from https://stackoverflow.com/questions/6294179/how-to-find-all-occurrences-of-an-element-in-a-list
def indices(lst, element):
    """Return every position at which ``element`` occurs in ``lst``.

    Args:
        lst: object exposing ``index(value, start)`` (e.g. list, tuple, str).
        element: value to look for.

    Returns:
        list[int]: matching positions in ascending order; empty if none.
    """
    positions = []
    search_from = 0
    while True:
        try:
            hit = lst.index(element, search_from)
        except ValueError:
            # ``index`` raises once no further occurrence exists: we are done.
            return positions
        positions.append(hit)
        # Resume the search just past the occurrence we have recorded.
        search_from = hit + 1
def load_data(sentence_file, tag_file=None):
    """Load words (and optionally tags) from CSV and group them by document.

    The sentence CSV must have a ``word`` column; the tag CSV, if given, must
    have a ``tag`` column whose rows line up with the sentence CSV. Every row
    whose word equals ``'-DOCSTART-'`` opens a new sequence that runs up to
    (but not including) the next such row, or to the end of the file. Rows
    that precede the first ``'-DOCSTART-'`` are not part of any sequence.

    Each word is stripped of surrounding whitespace. It is lower-cased when
    the module-level ``CAPITALIZATION`` flag is falsy, and the
    ``'-DOCSTART-'`` marker is always lower-cased. When the module-level
    ``STOP_WORD`` flag is truthy, ``'<STOP>'`` is appended to every sequence
    (and to every tag sequence).

    Args:
        sentence_file (str): path of the CSV containing the ``word`` column.
        tag_file (str, optional): path of the CSV containing the ``tag``
            column. Omit it to load unlabeled (test) data.

    Returns:
        list[list[str]]: the word sequences, when ``tag_file`` is falsy.
        tuple: ``(sentences, tags)`` with one tag list per word sequence,
        when ``tag_file`` is given.
    """
    # Open through a context manager so the handle is closed as soon as
    # pandas is done with it, instead of being left to the garbage collector.
    with open(sentence_file) as sentence_handle:
        df_sentences = pd.read_csv(sentence_handle)
    word_column = df_sentences['word']
    doc_start_indexes = df_sentences.index[word_column == '-DOCSTART-'].tolist()
    # Plain lists make the per-token access below much cheaper than repeated
    # Series lookups. This relies on the default 0..n-1 row index that
    # read_csv produces here, so row labels and list positions coincide.
    words = word_column.tolist()

    tag_values = None
    if tag_file:
        with open(tag_file) as tag_handle:
            tag_values = pd.read_csv(tag_handle)['tag'].tolist()

    # Pair each document start with the next one; the final sequence runs to
    # the end of the file.
    boundaries = doc_start_indexes + [len(words)]

    sentences = []  # each sentence is a list of word strings
    tags = []
    for begin, end in tqdm(zip(boundaries, boundaries[1:]),
                           total=len(doc_start_indexes)):
        sent = []
        for raw_word in words[begin:end]:
            word = raw_word.strip()
            if not CAPITALIZATION or word == '-DOCSTART-':
                word = word.lower()
            sent.append(word)
        if STOP_WORD:
            sent.append('<STOP>')
        sentences.append(sent)

        if tag_values is not None:
            # Index row by row (not a slice) so a tag file shorter than the
            # sentence file raises instead of silently yielding fewer tags.
            tag = [tag_values[row] for row in range(begin, end)]
            if STOP_WORD:
                tag.append('<STOP>')
            tags.append(tag)

    if tag_values is not None:
        return sentences, tags
    return sentences
def confusion_matrix(tag2idx, idx2tag, pred, gt, fname):
    """Build the tag confusion matrix and save it as a heatmap image.

    Rows correspond to ground-truth tags and columns to predicted tags; cell
    ``[r][c]`` counts the tokens whose true tag has index ``r`` and whose
    predicted tag has index ``c``. Predictions and ground truth are each
    flattened across sequences and compared position by position.

    Args:
        tag2idx (dict): tag to index dictionary
        idx2tag (dict): index to tag dictionary; must contain every index
            in ``range(len(tag2idx))``
        pred (list[list[str]]): list of predicted tag sequences
        gt (list[list[str]]): list of ground-truth tag sequences, in the
            same order as ``pred``
        fname (str): filename to save confusion matrix

    Raises:
        KeyError: if a tag is missing from ``tag2idx``.
        IndexError: if ``gt`` holds fewer tags in total than ``pred``.
    """
    num_tags = len(tag2idx)
    matrix = np.zeros((num_tags, num_tags))

    flat_pred = [tag for sequence in pred for tag in sequence]
    flat_y = [tag for sequence in gt for tag in sequence]

    for i, predicted_tag in enumerate(flat_pred):
        idx_pred = tag2idx[predicted_tag]
        idx_y = tag2idx[flat_y[i]]
        matrix[idx_y][idx_pred] += 1

    labels = [idx2tag[i] for i in range(num_tags)]
    df_cm = pd.DataFrame(matrix, index=labels, columns=labels)

    fig = plt.figure(figsize=(10, 7))
    try:
        sn.heatmap(df_cm, annot=False)
        fig.savefig(fname)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # close it here so repeated calls do not accumulate open figures.
        plt.close(fig)