-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathheatmap_coh.py
More file actions
90 lines (77 loc) · 2.95 KB
/
heatmap_coh.py
File metadata and controls
90 lines (77 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Plots heat maps of coherence for each channel combination across frequencies:
# one for the pre file, one for the post file, and one for the change (post minus pre).
import scipy.io as scio
import seaborn as sns
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import simple_GUI as sg
from tkinter import Tk
from tkinter.filedialog import askopenfilename
# Choose and open the pre and post files
def _choose_and_load_mat(title):
    """Ask the user for a file and load it with ``scio.loadmat``.

    Args:
        title: Caption shown on the file dialog, so the user can tell
            whether the pre or the post file is being requested.

    Returns:
        A ``(path, contents)`` tuple: the selected path and the object
        returned by ``scio.loadmat`` for that path.

    Raises:
        SystemExit: If the dialog is closed without selecting a file.
    """
    selected = askopenfilename(title=title)
    # A cancelled dialog yields an empty value; stop with a clear message
    # instead of handing an empty path to loadmat.
    if not selected:
        raise SystemExit("No file selected (" + title + "); exiting.")
    print("File: ", selected)
    return selected, scio.loadmat(selected)


# A single hidden root window is enough for both dialogs.
Tk().withdraw()
precoh_name, precoh_mat = _choose_and_load_mat("Select PRE coherence file")
postcoh_name, postcoh_mat = _choose_and_load_mat("Select POST coherence file")
# Extract the coherence spectra from both loaded files. The frequency vector
# and the number of 'cmb_labels' entries are taken from the pre file only.
pre_coh, post_coh = precoh_mat['coh_spect'], postcoh_mat['coh_spect']
num_chan = len(precoh_mat['cmb_labels'])
freq = precoh_mat['freq'][0]
# Write pre and post data into CSV files
def _write_coh_csv(csv_path, header, rows):
    """Write *header* as the first CSV row, then one CSV row per item of *rows*.

    Args:
        csv_path: Path of the CSV file to create (overwritten if present).
        header: Sequence written as the first row.
        rows: Iterable of sequences; each item becomes one CSV row.
    """
    # newline='' lets the csv module control line endings; without it each
    # row is followed by an extra blank line on Windows.
    with open(csv_path, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(header)
        csvwriter.writerows(rows)


# Both files share the frequency vector as their header row.
_write_coh_csv('precohdata.csv', freq, pre_coh)
_write_coh_csv('postcohdata.csv', freq, post_coh)
# Read both CSV files back in as pandas data frames
pre_df, post_df = (
    pd.read_csv(csv_name, index_col=False)
    for csv_name in ('precohdata.csv', 'postcohdata.csv')
)
# Change in coherence: post minus pre
delta_df = post_df - pre_df
# Tick positions and labels shared by all heat maps
x_labels = range(0, 31, 5)           # text shown at each x tick: 0, 5, ..., 30
x_ticks = np.arange(0, 61, 10)       # x tick positions: 0, 10, ..., 60
y_label = np.arange(num_chan) + 1    # row labels 1..num_chan
# Plot all three graphs (pre, post, and change) on one figure
def _draw_coh_heatmap(frame, axis, title, vmax):
    """Draw one heat map on *axis* and return the object seaborn returns.

    Args:
        frame: Data frame to plot.
        axis: Matplotlib axes to draw into.
        title: Title for this subplot.
        vmax: Upper limit of the color scale (the lower limit is always -0.2).

    Returns:
        The return value of ``sns.heatmap`` after ticks, labels, and marker
        lines have been applied.
    """
    heatmap = sns.heatmap(data=frame, ax=axis, yticklabels=y_label,
                          cbar_kws={'label': 'Coherence'}, cmap="viridis",
                          vmin=-0.2, vmax=vmax)
    heatmap.set_xticks(x_ticks)
    heatmap.set_xticklabels(x_labels)
    heatmap.set(title=title, xlabel="Frequency (HZ)", ylabel="Channel cmb #")
    # Black vertical marker lines at x positions 7 and 16 of the heat map.
    heatmap.axvline(x=7, color='black')
    heatmap.axvline(x=16, color='black')
    return heatmap


fig = plt.figure(figsize=(9, 8), constrained_layout=True)
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)
pre_plot = _draw_coh_heatmap(pre_df, ax1, "Pre Coherence Spectra", 0.8)
post_plot = _draw_coh_heatmap(post_df, ax2, "Post Coherence Spectra", 0.8)
# The change plot uses a narrower color range than the pre/post plots.
d_plot = _draw_coh_heatmap(delta_df, ax3, "Change in Coherence Spectra", 0.2)
# Build the output path with os.path.join rather than a hard-coded '\\'
# so the separator matches the platform.
plt.savefig(os.path.join(sg.path_name,
                         sg.ani_num + '_' + sg.rec_day + '_coh_heatmap.png'))
plt.show()
print('done')