-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
executable file
·201 lines (172 loc) · 7.93 KB
/
test.py
File metadata and controls
executable file
·201 lines (172 loc) · 7.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import json
import os
import re
import unittest

from util import *
from util_connectivity import *
from util_virtual_resection import *
# Shared configuration for every test below (data directories, patient
# metadata). The path is relative to the current working directory, which only
# works when the suite is launched from this file's directory; if it does not
# exist there, resolve the same relative path against this file's directory.
_TEST_DATA_PATH = '../data/TEST_DATA.json'
if not os.path.exists(_TEST_DATA_PATH):
    _TEST_DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                   '..', 'data', 'TEST_DATA.json')
with open(_TEST_DATA_PATH) as json_data_file:
    data = json.load(json_data_file)
class CorrespondNamesTest(unittest.TestCase):
'''
This unit test checks conversion between EEG labels from IEEG.org and cartoon map labels (e.g. LG64 and LG064-Ref).
'''
def test(self):
patient_id = 'Study029'
dilate_radius = 0
data_dir = os.path.expanduser(data['REAL_DATA_DIR'])
for event_type, events in data['PATIENTS'][patient_id]['Events'].items():
for event_id in events.keys():
fn = os.path.join(data_dir, patient_id, 'eeg', events[event_id]['FILE'])
eeg_channel_labels = []
# Get channels, ECoG Data, Fsx
with h5py.File(fn) as f:
evData = f['evData'].value
Fs = f['Fs'].value
for column in f['channels']:
row_data = []
for row_number in range(len(column)):
row_data.append(''.join(map(unichr, f[column[row_number]][:])))
eeg_channel_labels.append(row_data)
Fs = int(Fs[0][0])
eeg_channel_labels = eeg_channel_labels[0]
# evData = scipy.stats.zscore(evData,axis=1)
T = evData.shape[0]
break
break
cartoon_map_labels = map(lambda x: x.split(',')[4].replace('\n',''), open(os.path.expanduser(
data['PATIENTS'][patient_id]['ELECTRODE_LABELS']
),'r').readlines())
res_dict = correspond_label_names(eeg_channel_labels, cartoon_map_labels)
for k,v in sorted(res_dict.items(),key=lambda x: x[0]):
print k,v[1]
self.assertTrue(True)
class DataTest(unittest.TestCase):
    '''
    This unit test checks compatibility of datasets as defined by the TEST_DATA.json config file.

    For event '1' of every event type of patient TEST1 it checks that evData has
    shape (450000, 60), that Fs is 500 and that all channel labels decode.
    '''
    def test(self):
        data_dir = os.path.expanduser(data['DATA_DIR'])
        for event_type, events in data['PATIENTS']['TEST1']['Events'].items():
            fn = os.path.join(data_dir, 'TEST1', 'eeg', events['1']['FILE'])
            # Open read-only: h5py < 3.0 defaults to mode 'a', which silently
            # creates a missing file and requires write access to the data.
            with h5py.File(fn, 'r') as f:
                # Dataset[()] reads the whole dataset; Dataset.value is
                # deprecated and was removed in h5py 3.0.
                evData = f['evData'][()]
                Fs = f['Fs'][()]
                # Decoding every channel label dereferences each HDF5
                # reference, so a broken 'channels' table raises here.
                channels = [[''.join(map(unichr, f[ref][:])) for ref in column]
                            for column in f['channels']]
            Fs = int(Fs[0][0])
            # assertEqual (unlike a bare assert) is not stripped by python -O
            # and reports which event type failed.
            self.assertEqual(evData.shape, (450000, 60),
                             'unexpected evData shape for event type %s' % event_type)
            self.assertEqual(Fs, 500,
                             'unexpected Fs for event type %s' % event_type)
class ResectionZoneTest(unittest.TestCase):
'''
This unit test computes and prints electrodes in the resection zone for a given patient ID.
'''
def test(self):
patient_id = 'Study029'
data_dir = os.path.expanduser(data['REAL_DATA_DIR'])
labels = map(lambda x: x.split(',')[4].replace('\n',''), open(os.path.expanduser(
data['PATIENTS'][patient_id]['ELECTRODE_LABELS']
),'r').readlines())
# Load ignored node labels
ignored_node_labels = data['PATIENTS'][patient_id]['IGNORE_ELECTRODES']
for ignored_node_label in ignored_node_labels:
if(ignored_node_label not in labels):
labels.append(ignored_node_label)
for event_type, events in data['PATIENTS'][patient_id]['Events'].items():
for event_id in events.keys():
fn = os.path.join(data_dir, patient_id, 'eeg', events[event_id]['FILE'])
channels = []
# Get channels, ECoG Data, Fsx
with h5py.File(fn) as f:
evData = f['evData'].value
Fs = f['Fs'].value
for column in f['channels']:
row_data = []
for row_number in range(len(column)):
row_data.append(''.join(map(unichr, f[column[row_number]][:])))
channels.append(row_data)
Fs = int(Fs[0][0])
channels = channels[0]
# evData = scipy.stats.zscore(evData,axis=1)
T = evData.shape[0]
# Correspond lable names
labels_dict = correspond_label_names(channels, labels)
# Load electrodes to ignore
ignored_node_idx = map(lambda x: labels_dict[x][0], ignored_node_labels)
for ii,node_id in enumerate(ignored_node_idx):
print 'Ignoring node label: %s because label %s is in IGNORE_ELECTRODES'%(channels[node_id],ignored_node_labels[ii])
channels = list(np.delete(np.array(channels),ignored_node_idx))
# Recorrespond label names
labels_dict = correspond_label_names(channels, labels)
break
break
dilate_radius = -5
print 'Printing resected electrodes with erosion of 5% of network'
print get_resected_electrodes(patient_id, dilate_radius, data, labels_dict)
dilate_radius = 0
print 'Printing resected electrodes with no dilation'
print get_resected_electrodes(patient_id, dilate_radius, data, labels_dict)
dilate_radius = 5
print 'Printing resected electrodes with dilation of 5% of network'
print get_resected_electrodes(patient_id, dilate_radius, data, labels_dict)
self.assertTrue(True)
class ConnectivityTest(unittest.TestCase):
    """
    Smoke test for building the adjacency multiband connectivity matrices of
    patient TEST1 via compute_multiband_connectivity; it passes as long as
    that call does not raise.
    """

    def test(self):
        # Arguments: patient id, 1, and the shared TEST_DATA configuration.
        test_patient = 'TEST1'
        compute_multiband_connectivity(test_patient, 1, data)
        self.assertTrue(True)
class VirtualResectionTest(unittest.TestCase):
    '''
    This unit test computes c_res(t) based on multiband connectivity matrices.

    It is a smoke test: it passes as long as virtual_resection does not raise.
    '''
    def test(self):
        # Run virtual resection on all files of patient TEST1. The return value
        # was previously bound to an unused local, so it is not kept here.
        virtual_resection('TEST1', 0, data)
class NullVirtualResectionTest(unittest.TestCase):
'''
This unit test computes c_null(t) based on multiband connectivity matrices and checks correctness of output.
'''
def test(self):
comp_dir = os.path.join(os.path.expanduser(data['COMP_DIR']),'TEST1','aim3')
unique_idx = []
for fn in os.listdir(comp_dir):
try:
print fn
match = re.match(r'[A-Za-z0-9]+.([A-Za-z]+).([0-9]+).cres.([0-9a-zA-Z-]+).npz',fn)
unique_idx.append((match.group(3),match.group(1),match.group(2)))
except AttributeError:
continue
# Run virtual resection on all files
for unique_id, event_type, event_id in unique_idx:
null_virtual_resection('TEST1', unique_id, event_type, event_id, 0,data)
self.assertTrue(True)
class PlotVirtualResectionTest(unittest.TestCase):
'''
This unit test generates the plot figures.
'''
def test(self):
comp_dir = os.path.join(os.path.expanduser(data['COMP_DIR']),'TEST1','aim3')
unique_idx = []
for fn in os.listdir(comp_dir):
try:
match = re.match(r'[A-Za-z0-9]+.([A-Za-z]+).([0-9]+).cres.([0-9a-zA-Z-]+).npz',fn)
unique_idx.append((match.group(3),match.group(1),match.group(2)))
except AttributeError:
continue
# Plot virtual resection results on all files
for unique_id, event_type, event_id in unique_idx:
plot_experiment('TEST1', unique_id, data=data)
self.assertTrue(True)
if __name__ == "__main__":
    # Executed as a script: run this module's TestCases through unittest's
    # command-line runner (arguments may select a subset of tests).
    unittest.main()