-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
193 lines (165 loc) · 6.59 KB
/
utils.py
File metadata and controls
193 lines (165 loc) · 6.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import re
import string
import pickle
import itertools
from zipfile import ZipFile
import matplotlib.pyplot as plt
import numpy as np
from nltk import tokenize
from sklearn.metrics import confusion_matrix
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
# Use a tokenizer that retains punctuation.
# Note: document_iterator replaces ASCII punctuation with spaces before it
# calls this tokenizer, so mostly non-ASCII punctuation survives to this point.
tokenizer = tokenize.WordPunctTokenizer()
# Sentences data: archive of case text files. It is opened at import time
# (path relative to the current working directory) and never explicitly
# closed in this module; get_all_cases and document_iterator read from it.
zfile = ZipFile('data/sentences.zip')
# Vocab Index Dictionary: token -> value looked up by map_corpus_to_int
# (presumably an integer id; confirm against the script that built it).
# NOTE: pickle.load can execute arbitrary code; only load a trusted file here.
with open('data/processed_data/word_index.p', 'rb') as p:
    word_index = pickle.load(p)
################################################################
# Load data functions
################################################################
def get_all_cases(caseids, archive=None):
    """
    Find the majority-opinion archive member for each requested case id.

    Only members whose name contains 'contentMajOp' are considered. Such
    members must look like '<dir>/<dir>/<caseid>_<x>_<y>'; any other shape
    raises ValueError while unpacking.

    :param caseids: iterable of case ids, compared against the id parsed from
                    each member name (i.e. strings)
    :param archive: ZipFile to search; defaults to the module-level ``zfile``
    :return: list aligned with ``caseids`` holding the matching member name,
             or None where nothing matched. A repeated id fills only its first
             position, and a later matching member overwrites an earlier one.
    """
    if archive is None:
        archive = zfile
    caseids = list(caseids)
    # id -> first position (same result as list.index), built once so each
    # archive member is matched in O(1) instead of scanning caseids twice.
    first_pos = {}
    for pos, caseid in enumerate(caseids):
        first_pos.setdefault(caseid, pos)

    cases = [None] * len(caseids)
    for item in archive.namelist():
        if 'contentMajOp' not in item:
            continue
        _, _, fname = item.split('/')
        caseid, _, _ = fname.split('_')
        pos = first_pos.get(caseid)
        if pos is not None:
            cases[pos] = item
    return cases
def build_corpus(cases, labels=None, geniss=None, topic_filter=None):
    """
    Load the text of each case and keep labels/ids aligned with it.

    Cases whose entry is None (no majority opinion found) are dropped together
    with their label. When ``topic_filter`` is truthy, only cases whose
    ``geniss`` entry equals it are kept.

    :param cases: sequence of archive member names (entries may be None)
    :param labels: optional sequence of labels aligned with ``cases``
    :param geniss: optional sequence of topic codes aligned with ``cases``;
                   required when ``topic_filter`` is given
    :param topic_filter: topic value to keep; falsy means no filtering
    :return: ``(documents, labels, caseids)`` when ``labels`` is given,
             otherwise ``(documents, caseids)``; documents are space-joined
             token strings
    :raises ValueError: if the inputs are misaligned, if ``topic_filter`` is
             given without ``geniss``, or if some documents could not be loaded
             while labels were supplied
    """
    cases = list(cases)
    n_cases = len(cases)
    if labels is not None and len(labels) != n_cases:
        raise ValueError('labels has {} entries but cases has {}'.format(len(labels), n_cases))
    if geniss is not None and len(geniss) != n_cases:
        raise ValueError('geniss has {} entries but cases has {}'.format(len(geniss), n_cases))
    if topic_filter and geniss is None:
        raise ValueError('topic_filter requires geniss')

    # One mask filters cases and labels together, so they cannot drift apart
    # (positional, so pandas/numpy inputs behave like lists).
    topics = list(geniss) if geniss is not None else [None] * n_cases
    keep = [case is not None and (not topic_filter or topic == topic_filter)
            for case, topic in zip(cases, topics)]
    cases = list(itertools.compress(cases, keep))
    if labels is not None:
        labels = list(itertools.compress(labels, keep))

    print('Num Cases: {}, Num Labels: {}'.format(
        len(cases), len(labels) if labels is not None else 0))
    caseids = [item.split('/')[2].split('_')[0] for item in cases]
    documents = [' '.join(doc) for doc in document_iterator(cases)]

    if labels is None:
        return documents, caseids
    # document_iterator skips non-'contentMajOp' members, which would
    # misalign documents and labels.
    if len(documents) != len(labels):
        raise ValueError('loaded {} documents for {} labels'.format(len(documents), len(labels)))
    return documents, labels, caseids
def get_corpus(case):
    """
    Load one archive member as a single-document corpus.

    :param case: archive member name; the case id is the text before the
                 first '_' in its third '/'-separated component
    :return: ``(corpus, caseid)``. ``corpus`` is a list of space-joined token
             strings, with one entry, or none if document_iterator skips the
             member.
    """
    file_name = case.split('/')[2]
    case_id = file_name.split('_')[0]
    documents = []
    for tokens in document_iterator([case]):
        documents.append(' '.join(tokens))
    return documents, case_id
def load_ngrams(corpus, ngram_range, tfidf=True, max_features=10000,
                min_df=10, max_df=0.95):
    """
    Convert a corpus into a bag-of-n-grams matrix.

    :param corpus: list of documents (strings)
    :param ngram_range: (min_n, max_n) tuple of n-gram sizes to extract
    :param tfidf: if True, return a TF-IDF weighted matrix; otherwise raw counts
    :param max_features: maximum vocabulary size
    :param min_df: drop terms with document frequency below this
                   (int = document count, float = fraction of documents)
    :param max_df: drop terms with document frequency above this
                   (float = fraction, int = document count)
    :return: sparse document-term matrix X, fitted vectorizer
    """
    vectorizer_cls = TfidfVectorizer if tfidf else CountVectorizer
    # Both vectorizer types share identical preprocessing options; configuring
    # them in one place keeps the TF-IDF and count features consistent.
    vectorizer = vectorizer_cls(ngram_range=ngram_range,
                                stop_words='english',
                                strip_accents='ascii',
                                max_df=max_df,
                                min_df=min_df,
                                max_features=max_features)
    X = vectorizer.fit_transform(corpus)
    return X, vectorizer
################################################################
# Visualization functions
################################################################
def plot_confusion_matrix(val_pred, val_truth, classes=(-1, 1),
                          normalize=True,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix (adapted from the sklearn docs example).

    ``val_pred`` is passed as the first argument of ``confusion_matrix``, so
    rows of the matrix are predicted labels and columns are true labels,
    matching the axis titles set below.

    :param val_pred: predicted labels
    :param val_truth: ground-truth labels
    :param classes: tick labels for rows/columns; should follow the label
                    order sklearn uses (sorted unique labels)
    :param normalize: if True, scale each true-label column to sum to 1
    :param title: plot title
    :param cmap: matplotlib colormap for the cells
    """
    cm = confusion_matrix(val_pred, val_truth)
    if normalize:
        # Divide each column (one true class) by its total. Columns with no
        # samples stay 0 instead of becoming NaN from a 0/0 division.
        col_totals = cm.sum(axis=0, keepdims=True)
        cm = np.divide(cm, col_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=col_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Use white text on the upper half of the value range. A fixed 0.5 cut-off
    # only made sense for normalized values and turned every raw count white.
    thresh = cm.max() / 2.0 if cm.size else 0.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('Predicted label')
    plt.xlabel('True label')
################################################################
# Helper functions
################################################################
def one_hot_labels(labels):
    """
    Convert a 1-D sequence of class labels to one-hot vectors.

    Column k corresponds to the k-th distinct label in sorted order, so labels
    need not be 0..k-1. For example, -1/1 or 1/2 work and no longer collide or
    index out of bounds. The width is the number of distinct labels, as before.

    :param labels: 1-D sequence of sortable labels
    :return: float array of shape (len(labels), number of distinct labels)
    """
    classes, col_idx = np.unique(np.ravel(labels), return_inverse=True)
    # Flatten defensively: some NumPy versions shape the inverse like the input.
    col_idx = col_idx.reshape(-1)
    Y_onehot = np.zeros((len(col_idx), len(classes)))
    Y_onehot[np.arange(len(col_idx)), col_idx] = 1
    return Y_onehot
def trim_and_pad(sequence, max_seq_length):
    """
    Force a sequence to exactly ``max_seq_length`` elements.

    Trim off the end if the sequence is at least ``max_seq_length`` long (the
    slice keeps the input's type). Otherwise pad with 0 on the right and return
    a NumPy array. An empty input becomes an all-zero *integer* array; NumPy
    would otherwise infer float64 for it and turn index sequences into floats.

    :param sequence: 1-D sequence (list or array), e.g. token ids
    :param max_seq_length: target length, must be >= 0
    :return: sequence of length ``max_seq_length``
    :raises ValueError: if ``max_seq_length`` is negative
    """
    if max_seq_length < 0:
        raise ValueError('max_seq_length must be non-negative, got {}'.format(max_seq_length))
    if len(sequence) >= max_seq_length:
        return sequence[:max_seq_length]
    pad = np.zeros(max_seq_length - len(sequence), dtype=int)
    if len(sequence) == 0:
        return pad
    return np.concatenate((sequence, pad))
def map_corpus_to_int(corpus, vocab=None):
    """
    Map each document to an integer array of vocabulary ids.

    Documents are split on single spaces (the form produced by joining tokens
    with ' '); words missing from the vocabulary are skipped.

    :param corpus: iterable of space-joined document strings
    :param vocab: word -> integer id mapping; defaults to the module-level
                  ``word_index``
    :return: list of 1-D int arrays, one per document. A document with no known
             words yields an empty *int* array rather than float64.
    """
    if vocab is None:
        vocab = word_index
    return [np.array([vocab[word] for word in doc.split(' ') if word in vocab],
                     dtype=int)
            for doc in corpus]
def document_iterator(items):
    """
    Iterate through major opinion documents, yielding one token list each.

    Members whose name does not contain 'contentMajOp' are skipped, so fewer
    lists may be yielded than there are items. Each document is read from the
    module-level ``zfile`` archive. It is decoded (UTF-8 by default) and
    lower-cased, and ASCII punctuation and newlines are replaced with spaces
    before tokenizing.

    :param items: iterable of archive member names
    :yield: list of token strings for each kept document
    """
    # Compile once per call rather than once per document.
    punct_re = re.compile('[%s]|\n' % re.escape(string.punctuation))
    for item in items:
        if 'contentMajOp' not in item:
            continue
        # Close the archive member handle as soon as it has been read.
        with zfile.open(item) as handle:
            txt = handle.read().decode().lower()
        txt = punct_re.sub(' ', txt)
        tokens = tokenizer.tokenize(txt)
        # Keep only tokens longer than 1 character. Also drop EVERY occurrence
        # of any word that appears among the first 5 tokens. This does not just
        # remove the first 5 positions, and it is kept so output matches
        # earlier runs.
        # NOTE(review): confirm whether dropping only tokens[:5] was intended.
        leading = set(tokens[:5])
        yield [token for token in tokens if len(token) > 1 and token not in leading]