-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdrawfig_venn.py
More file actions
173 lines (151 loc) · 7.8 KB
/
drawfig_venn.py
File metadata and controls
173 lines (151 loc) · 7.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from pylab import *
import scipy.io
from os.path import exists
import mytools
import myvenn
# ---------------------------------------------------------------------------
# Grid-search configuration and loading of the per-file results.
#
# Every result file 'gridsearch<Nperpop>_2pm_sep7_<first>-<last>.mat' is read
# (missing files are reported and skipped) and, for every sample and each of
# the four entries indexed by iMMN, three numbers are collected:
#   * NoutputInsides  - taken from A['NoutputInsidesUnique_all']
#   * NoutputOutsides - sum of entries 1..5 of A['NoutputOutsidesUnique_all']
#   * inperout        - ratio of the two (see the special cases below)
# The per-file lists are finally flattened into the *_vec lists, which hold
# one 4-element list per sample.
# ---------------------------------------------------------------------------
NsampPerGridFile = 100  # number of samples read from each result file
Nperpop = 40            # part of the result file names

# Parameter grids. In the visible part of this script only their lengths are
# used (to size simIDs).
stimAmp1 = [120,125,130,140,150]
gAMPA1 = [12.5, 15.0, 17.5, 20.0, 22.5]
gAMPA2 = [25.0, 30.0, 35.0, 40.0, 45.0]
gAMPA3 = [80.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0]
nmdaAmpaRatio = [0.5, 0.333]
gGABA = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
pvs = [0.9, 0.95]
min_gluts = [0.0]
tauNeur1s = [10.0]
tauNeur2s = [200.0, 250.0]
simIDs = list(range(0,len(stimAmp1)*len(gAMPA1)*len(gAMPA2)*len(gAMPA3)*len(nmdaAmpaRatio)*len(gGABA)*len(pvs)*len(min_gluts)*len(tauNeur1s)*len(tauNeur2s)))
iparAttrs = [x+1 for x in [3,4,5,6,7,8,9,10,11,14,15,16,17,19]]
parNames = ['stimAmp1','stimAmp2','gAMPA1', 'gAMPA2', 'gAMPA3', 'gNMDA1', 'gNMDA2', 'gNMDA3', 'gGABA', 'min_gluts', 'pvs', 'tauNeur1s', 'tauNeur2s']

if True:
    close('all')

filenames_all = []
strs_all = []
inperouts_all = []
NoutputInsides_all = []
NoutputOutsides_all = []
for igrid in range(0,900):
    # Build the sample range and the file name once; both are reused in the messages below.
    samp_range = str(igrid*NsampPerGridFile)+'-'+str(igrid*NsampPerGridFile+NsampPerGridFile)
    gridfile = 'gridsearch'+str(Nperpop)+'_2pm_sep7_'+samp_range+'.mat'
    print('Working on '+samp_range)
    if not exists(gridfile):
        print(gridfile+' does not exist')
        continue
    A = scipy.io.loadmat(gridfile)

    inperouts_this = []
    NoutputInsides_this = []
    NoutputOutsides_this = []
    for isamp in range(0,NsampPerGridFile):
        inperouts_thisMMN = []
        NoutputInsides_thisMMN = []
        NoutputOutsides_thisMMN = []
        for iMMN in range(0,4):
            Nbetweens = A['NbetweensUnique_all'][isamp][iMMN]
            NoutputInsides = A['NoutputInsidesUnique_all'][isamp][iMMN]
            NoutputOutsides = [A['NoutputOutsidesUnique_all'][isamp][iMMN][i] for i in [1,2,3,4,5]] #Exclude the first one since it's allowed to give a "deviant-like" output. Consider leaving out indices 1 and 4?
            NoutputOutsides_sum = sum(NoutputOutsides)
            if NoutputInsides == 0:
                inperout = 0
            elif NoutputOutsides_sum > 0:
                inperout = NoutputInsides/NoutputOutsides_sum
            else:
                # No outside outputs at all: avoid dividing by zero and use a value offset by 10 instead
                inperout = 10+NoutputInsides
            inperouts_thisMMN.append(inperout)
            NoutputInsides_thisMMN.append(NoutputInsides)
            NoutputOutsides_thisMMN.append(NoutputOutsides_sum)
        inperouts_this.append(inperouts_thisMMN)
        NoutputInsides_this.append(NoutputInsides_thisMMN)
        NoutputOutsides_this.append(NoutputOutsides_thisMMN)
    inperouts_all.append(inperouts_this)
    NoutputInsides_all.append(NoutputInsides_this)
    NoutputOutsides_all.append(NoutputOutsides_this)
    filenames_all = r_[filenames_all,A['filenames_all'][:]]
    strs_all = r_[strs_all,A['strs_all'][:]]

# Flatten the per-file lists into one list with an entry per sample.
# extend() appends in place instead of rebuilding the whole list on every pass.
inperouts_vec = []
NoutputInsides_vec = []
NoutputOutsides_vec = []
for i in range(0,len(inperouts_all)):
    print('Concatenating '+str(i)+'/'+str(len(inperouts_all)))
    inperouts_vec.extend(inperouts_all[i])
    NoutputInsides_vec.extend(NoutputInsides_all[i])
    NoutputOutsides_vec.extend(NoutputOutsides_all[i])

# Samples passing a looser criterion on the first three entries only (ratio > 0.4 and at least 7 inside outputs)
goodones = [i for i in range(0,len(inperouts_vec))
            if inperouts_vec[i][0] > 0.4 and inperouts_vec[i][1] > 0.4 and inperouts_vec[i][2] > 0.4
            and NoutputInsides_vec[i][0] >= 7 and NoutputInsides_vec[i][1] >= 7 and NoutputInsides_vec[i][2] >= 7]
cols_good = mytools.colorsredtolila(len(goodones),0.7)
def mystr(x):
    """Return str(x) with floating-point round-off noise trimmed.

    Intended for Python ints and floats. Two patterns are cleaned up:

    * A run of five zeros in the fractional part: the text is cut after the
      first zero of the run, e.g. 0.30000000000000004 -> '0.30'.
    * A run of five nines: the digit before the run is incremented and the
      run is dropped, e.g. 0.7999999999999999 -> '0.8'. If the run starts
      right after the decimal point the carry goes into the integer part
      (19.999999999 -> '20'). If the run lies in the integer part it is
      replaced by zeros (1299999 -> '1300000').

    The string is returned unchanged when no safe clean-up applies: no such
    run, exponent notation, a run of nines starting at index 0 or 1, or a
    carry that would have to go past the first character or into a sign.
    """
    s = str(x)
    # Cutting digits out of exponent notation would silently drop the exponent.
    if 'e' in s or 'E' in s:
        return s
    idot = s.find('.')

    # Only zeros after the decimal point are noise; zeros in the integer part
    # are significant (100000.0 must not become '10').
    izeros = s.find('00000', idot+1) if idot != -1 else -1
    if izeros != -1:
        return s[0:izeros+1]

    i = s.find('99999')
    if i < 2:
        # No run of nines (-1), or no digit in front of it to carry into.
        return s

    if s[i-1] == '.':
        # All decimals are nines: carry into the integer part, skipping over
        # any nines directly in front of the decimal point.
        myi = i-2
        while myi >= 0 and s[myi] == '9':
            myi = myi-1
        if myi < 0 or not s[myi].isdigit():
            # Nothing to carry into (e.g. '99.99999' or '-9.99999').
            return s
        n9s = i-2-myi
        return s[0:myi]+str(int(s[myi])+1)+'0'*(n9s)

    if not s[i-1].isdigit():
        return s
    # s[i-1] cannot be '9' (the run would then start earlier), so the
    # incremented digit is always a single character.
    head = s[0:i-1]+str(int(s[i-1])+1)
    if idot == -1:
        return head+'0'*(len(s)-i)
    if idot > i:
        return head+'0'*(idot-i)
    return head
def changeNperpopAndScale(command_str, NperpopNew, AMPA_scale = 1.0, NMDA_scale = 1.0, GABA_scale = 1.0):
    """Return command_str with a new population size and rescaled fields.

    command_str is split on single spaces. Field 2 must equal the
    module-level Nperpop; otherwise an error is printed and '' is returned.
    Field 2 is then replaced by NperpopNew. Fields 4 and 5 are multiplied by
    AMPA_scale, fields 6 and 7 by NMDA_scale and field 8 by GABA_scale. Each
    product is formatted with mystr() before the fields are joined again
    with single spaces.
    """
    fields = command_str.split(' ')
    if int(fields[2]) != Nperpop:
        print('Error: splitted[2] != Nperpop')
        return ''
    fields[2] = str(NperpopNew)
    # (field index, scale factor) pairs, processed in ascending field order
    scale_by_field = ((4, AMPA_scale), (5, AMPA_scale),
                      (6, NMDA_scale), (7, NMDA_scale),
                      (8, GABA_scale))
    for ifield, scale in scale_by_field:
        fields[ifield] = mystr(float(fields[ifield])*scale)
    return ' '.join(fields)
# ---------------------------------------------------------------------------
# Venn diagram of which protocols each sample detects.
#
# A protocol (index 0..3, named in `labels`) counts as detected for sample i
# when inperouts_vec[i][k] >= 1 and NoutputInsides_vec[i][k] >= 32.
# Each key of region_counts is a 4-character string with one '0'/'1' flag per
# protocol, in label order. Every sample is counted in exactly one region,
# so the regions are exclusive.
# ---------------------------------------------------------------------------
region_counts = {
    "0000": 0,
    "1000": 0,
    "0100": 0,
    "0010": 0,
    "0001": 0,
    "1100": 0,
    "1010": 0,
    "1001": 0,
    "0110": 0,
    "0101": 0,
    "0011": 0,
    "1110": 0,
    "1101": 0,
    "1011": 0,
    "0111": 0,
    "1111": 0
}
labels = ["Frequency deviant", "Omission", "Duration deviant", "Inverse duration deviant"]

# Single pass over the samples: build each sample's region key and count it.
for i in range(0,len(inperouts_vec)):
    region_key = ''.join('1' if (inperouts_vec[i][iMMN] >= 1 and NoutputInsides_vec[i][iMMN] >= 32) else '0'
                         for iMMN in range(0,4))
    region_counts[region_key] += 1
print(str(region_counts))

# NOTE(review): the original called vennfromchatgpt.draw_venn_from_counts, but no
# such name is imported in this file; myvenn is imported and otherwise unused, so
# it is presumably the same module under its current name - confirm.
vennfig,vennax = myvenn.draw_venn_from_counts(region_counts, labels)

# Samples for which all four protocols are detected (same criterion as region "1111").
all_detected = [i for i in range(0,len(inperouts_vec))
                if all(inperouts_vec[i][iMMN] >= 1 and NoutputInsides_vec[i][iMMN] >= 32 for iMMN in range(0,4))]
for i in all_detected:
    print(strs_all[i])
print('#parameters with correct detection of all protocols: '+str(len(all_detected)))
for i in all_detected:
    print(filenames_all[i])
print('#parameters with correct detection of all protocols: '+str(len(all_detected)))

print('Saving fig_venn.pdf')
# Presumably vennax is a matplotlib Axes; the list would then be [left, bottom, width, height] - confirm.
vennax.set_position([0.1,0.45,0.8,0.5])
pos = vennax.get_position()
# Label 'A' placed relative to the top-left corner of the repositioned axes
vennfig.text(pos.x0 + 0.03, pos.y1 - 0.155, 'A', fontsize=11)
vennfig.savefig('fig_venn.pdf')