-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmap_utils.py
More file actions
297 lines (228 loc) · 8.81 KB
/
map_utils.py
File metadata and controls
297 lines (228 loc) · 8.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import mrcfile
import numpy as np
def is_map_empty(mrc_file):
    """Report whether an MRC map contains only (near-)zero voxels.

    The file is opened with ``permissive=True``. The map counts as empty when
    ``np.allclose(data, 0)`` holds, i.e. every voxel lies within numpy's default
    tolerance of zero.

    Parameters:
    mrc_file (str): Path to the MRC file.

    Returns:
    bool: True if the map is empty, False otherwise.
    """
    with mrcfile.open(mrc_file, permissive=True) as mrc:
        density = mrc.data
        assert density is not None
        return np.allclose(density, 0)
def calc_map_ccc(input_mrc, input_pred, center=True, overlap_only=False):
    """
    Calculate the cross-correlation coefficient (CCC) and overlap fraction of two MRC maps.

    The CCC is the normalized cross-correlation
    ``sum(a * b) / sqrt(sum(a**2) * sum(b**2))`` over the voxel grids ``a`` and
    ``b``. When ``center`` is True this is a Pearson correlation. It is not
    Lin's concordance correlation coefficient.

    Parameters:
    input_mrc (str): Path to the reference MRC file.
    input_pred (str): Path to the prediction MRC file; must have the same grid shape.
    center (bool, optional): If True, subtract each map's mean before correlating.
        Defaults to True.
    overlap_only (bool, optional): If True, correlate only the overlapping voxels
        (those where the product of the two maps is positive). Defaults to False.

    Returns:
    float: The CCC. It is 0.0 when undefined: a map has zero norm after
        masking/centering, or no voxels remain.
    float: The overlap fraction in [0, 1]. This is the number of overlapping voxels
        divided by the smaller of the two maps' nonzero-voxel counts. It is 0.0
        when either map is all zeros.

    Raises:
    ValueError: If the two maps have different shapes.
    """
    # Read both maps as float64 copies. Integer MRC modes (e.g. int8/int16)
    # would otherwise overflow in the products and squares below, and float64
    # accumulation is more precise when summing over many voxels.
    with mrcfile.open(input_mrc) as mrc:
        assert mrc.data is not None  # type: ignore[union-attr]
        mrc_data = mrc.data.astype(np.float64)
    with mrcfile.open(input_pred) as mrc:
        assert mrc.data is not None  # type: ignore[union-attr]
        pred_data = mrc.data.astype(np.float64)

    # Refuse to silently broadcast grids of different shapes.
    if mrc_data.shape != pred_data.shape:
        raise ValueError(
            f"Map shapes differ: {mrc_data.shape} vs {pred_data.shape}"
        )

    # Overlap = voxels where both maps are nonzero with the same sign.
    overlap = mrc_data * pred_data > 0.0
    min_count = min(np.count_nonzero(mrc_data), np.count_nonzero(pred_data))
    overlap_percent = np.count_nonzero(overlap) / min_count if min_count else 0.0

    if overlap_only:
        mrc_data = mrc_data[overlap]
        pred_data = pred_data[overlap]

    # Centering an empty selection would produce NaN means, so skip it.
    if center and mrc_data.size:
        mrc_data = mrc_data - mrc_data.mean()
        pred_data = pred_data - pred_data.mean()

    norm = np.sqrt(np.sum(mrc_data**2) * np.sum(pred_data**2))
    ccc = float(np.sum(mrc_data * pred_data) / norm) if norm > 0 else 0.0
    return ccc, float(overlap_percent)
"""Compute FSC between two volumes, adapted from cryodrgn"""
import numpy as np
import torch # type: ignore[import-not-found]
from torch.fft import fftshift, ifftshift, fft2, fftn, ifftn # type: ignore[import-not-found]
def normalize(img, mean=0, std=None, std_n=None):
    """Return ``(img - mean) / std``.

    When ``std`` is not given, it is estimated with ``torch.std`` over the
    first ``std_n`` entries along the leading axis. If ``std_n`` is None, all
    entries are used. Restricting the estimate to a subset keeps it cheap for
    large stacks.
    """
    scale = std if std is not None else torch.std(img[:std_n, ...])
    return (img - mean) / scale
def fft2_center(img):
    """Centered 2-D FFT over the last two axes.

    The input is shifted before the transform and the spectrum is shifted after
    it, which places the zero-frequency term at the middle of the last two axes.
    """
    last_two = (-1, -2)
    shifted = fftshift(img, dim=last_two)
    return fftshift(fft2(shifted), dim=last_two)
def fftn_center(img):
    """Centered N-D FFT over every axis (shift, transform, shift)."""
    centered_input = fftshift(img)
    spectrum = fftn(centered_input)
    return fftshift(spectrum)
def ifftn_center(img):
    """Inverse centered N-D FFT (ifftshift, ifftn, ifftshift) over every axis.

    Accepts a torch tensor or a numpy array. A numpy input is rebuilt as a
    complex tensor from its real and imaginary parts first. The result is a
    complex tensor.
    """
    if isinstance(img, np.ndarray):
        # A complex ndarray cannot simply be typecast with torch.Tensor(array);
        # assemble the complex tensor from its two real components instead.
        real_part = torch.Tensor(img.real)
        imag_part = torch.Tensor(img.imag)
        img = torch.complex(real_part, imag_part)
    return ifftshift(ifftn(ifftshift(img)))
def ht2_center(img):
    """Centered 2-D Hartley transform, computed as Re(F) - Im(F) of ``fft2_center(img)``."""
    spectrum = fft2_center(img)
    real_part, imag_part = spectrum.real, spectrum.imag
    return real_part - imag_part
def htn_center(img):
    """Centered N-D Hartley transform: Re(F) - Im(F) of the shifted N-D FFT."""
    spectrum = fftn(fftshift(img))
    spectrum = fftshift(spectrum)
    return spectrum.real - spectrum.imag
def iht2_center(img):
    """Inverse centered 2-D Hartley transform.

    Applies ``fft2_center`` and divides by the number of pixels in the last two
    axes, then returns Re - Im of the scaled spectrum.
    """
    spectrum = fft2_center(img)
    n_pixels = spectrum.shape[-1] * spectrum.shape[-2]
    spectrum = spectrum / n_pixels
    return spectrum.real - spectrum.imag
def ihtn_center(img):
    """Inverse centered N-D Hartley transform.

    Applies the shifted N-D FFT and divides by the total element count, then
    returns Re - Im. Because the same ``fftshift`` is used on both sides, it
    inverts ``htn_center`` exactly for even-sized axes (where fftshift is its
    own inverse).
    """
    spectrum = fftshift(fftn(fftshift(img)))
    n_total = torch.prod(torch.tensor(spectrum.shape, device=spectrum.device))
    spectrum /= n_total
    return spectrum.real - spectrum.imag
def symmetrize_ht(ht):
    """Extend (a stack of) DxD images to (D+1)x(D+1) with a wrapped border.

    The extra last row copies the first row, the extra last column copies the
    first column, and the new corner copies the [0, 0] element. A 2-D input is
    treated as a stack of one. The leading axis is dropped again whenever the
    stack holds a single image. D must be even (asserted).
    """
    stack = ht[np.newaxis, ...] if ht.ndim == 2 else ht
    assert stack.ndim == 3
    n_imgs, size = stack.shape[0], stack.shape[-1]

    out = torch.empty((n_imgs, size + 1, size + 1), dtype=stack.dtype, device=stack.device)
    out[:, :-1, :-1] = stack
    assert size % 2 == 0

    # Wrap: new row <- row 0, new column <- column 0, new corner <- [0, 0].
    out[:, -1, :] = out[:, 0, :]
    out[:, :, -1] = out[:, :, 0]
    out[:, -1, -1] = out[:, 0, 0]

    return out[0, ...] if n_imgs == 1 else out
def _fsc_shell_curve(vol1, vol2):
    """Correlate two equal, even-sized cubic float32 volumes in integer-radius Fourier shells.

    Returns ``(x, fsc)``, both of length D // 2:
    - ``x[i] = i / D`` is the spatial frequency in cycles/pixel.
    - ``fsc[0]`` is fixed to 1.0.
    - ``fsc[i]`` (i >= 1) correlates the Fourier voxels whose radius r satisfies
      ``i - 1 <= r < i``.
    """
    D = vol1.shape[0]
    f1 = fftn_center(torch.from_numpy(vol1)).numpy()
    f2 = fftn_center(torch.from_numpy(vol2)).numpy()

    k = np.arange(-D // 2, D // 2)
    k0, k1, k2 = np.meshgrid(k, k, k, indexing="ij")
    radius = np.sqrt(k0**2 + k1**2 + k2**2)

    # Voxel with i - 1 <= r < i belongs to shell i, i.e. shell = floor(r) + 1.
    # One bincount pass per sum replaces a full-volume mask per shell.
    shell = (np.floor(radius).astype(np.int64) + 1).ravel()
    cross = np.bincount(shell, weights=(np.conj(f1) * f2).real.ravel())
    power1 = np.bincount(shell, weights=(np.abs(f1) ** 2).ravel())
    power2 = np.bincount(shell, weights=(np.abs(f2) ** 2).ravel())

    n_shells = D // 2
    with np.errstate(divide="ignore", invalid="ignore"):
        shell_fsc = cross[1:n_shells] / np.sqrt(power1[1:n_shells] * power2[1:n_shells])
    fsc = np.concatenate(([1.0], shell_fsc))
    x = np.arange(n_shells) / D
    return x, fsc


def _fsc_resolution_cutoff(x, fsc, threshold, Apix):
    """Resolution (``Apix / x``) where ``fsc`` first drops below ``threshold``, or 0.0 if it never does."""
    below = np.flatnonzero(fsc < threshold)
    if below.size == 0:
        return 0.0
    return 1 / x[below[0]] * Apix


def _save_fsc_plot(x, fsc, cutoff_05, cutoff_0143, Apix, plot_filename):
    """Plot the FSC curve with 0.5/0.143 thresholds and a resolution top axis, then save it."""
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(x[1:], fsc[1:], "b-", linewidth=2, label="FSC")
        plt.axhline(y=0.5, color="r", linestyle="--", alpha=0.7, label="FSC = 0.5")  # type: ignore[arg-type]
        plt.axhline(
            y=0.143,  # type: ignore
            color="orange",
            linestyle="--",
            alpha=0.7,
            label="FSC = 0.143",
        )

        # Mark each resolution cutoff at its frequency: cycles/pixel = Apix / resolution.
        for cutoff, color, label_y in ((cutoff_05, "r", 0.52), (cutoff_0143, "orange", 0.15)):
            if cutoff > 0:
                cutoff_freq = float(1 / (cutoff / Apix))
                plt.axvline(x=cutoff_freq, color=color, linestyle=":", alpha=0.7)  # type: ignore[arg-type]
                plt.text(
                    cutoff_freq,
                    label_y,
                    f"{cutoff:.1f}Å",
                    rotation=90,
                    verticalalignment="bottom",
                    color=color,
                )

        plt.xlabel("Spatial Frequency")
        plt.ylabel("Fourier Shell Correlation")
        plt.title("FSC Curve")
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.ylim(-0.1, 1.1)
        plt.xlim(0, x[-1])

        # Secondary top axis labelling the same frequencies as resolution.
        # x is in cycles/pixel, so resolution = Apix / frequency, consistent
        # with the cutoff values above.
        ax1 = plt.gca()
        ax2 = ax1.twiny()  # type: ignore[attr-defined]
        freq_ticks = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
        freq_ticks = freq_ticks[freq_ticks <= x[-1]]
        res_ticks = Apix / freq_ticks
        ax2.set_xlim(ax1.get_xlim())  # type: ignore[attr-defined]
        ax2.set_xticks(freq_ticks)
        ax2.set_xticklabels([f"{r:.1f}" for r in res_ticks])
        ax2.set_xlabel("Resolution (Å)")

        plt.tight_layout()
        plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)  # always release the figure, even if drawing fails
    print(f"FSC plot saved to: {plot_filename}")


def calculate_fsc(vol1_f, vol2_f, output_f, Apix=1.0, plot=True):
    """Compute the Fourier shell correlation (FSC) between two MRC volumes.

    Both maps are zero-padded (at the high end of each axis) to an even-sized
    cube, transformed with ``fftn_center``, and correlated in integer-radius
    shells.

    Parameters:
    vol1_f, vol2_f (str): Paths to the two MRC volumes; shapes must match (asserted).
    output_f (str or None): If truthy, rows of ``(frequency, fsc)`` are written here
        with ``np.savetxt``, and the plot is saved as ``<output_f stem>_plot.png``.
    Apix (float): Pixel size, used to convert frequency to resolution (Apix / x).
    plot (bool): Save an FSC plot. This only happens when ``output_f`` is given,
        since the plot file name is derived from it.

    Returns:
    x (np.ndarray): Spatial frequency of each point in cycles/pixel (length D // 2).
    fsc (np.ndarray): FSC values. ``fsc[0]`` is 1.0 and ``fsc[i]`` covers Fourier
        radius ``[i - 1, i)``.
    cutoff_05 (float): Resolution where FSC first drops below 0.5, or 0.0 if never.
    cutoff_0143 (float): Resolution where FSC first drops below 0.143, or 0.0 if never.
    """
    import mrcfile

    # astype(np.float32) copies the data and yields native byte order, which
    # torch.from_numpy requires (big-endian MRC data would otherwise fail).
    with mrcfile.open(vol1_f, permissive=True) as v1:
        assert v1.data is not None  # type: ignore[union-attr]
        vol1 = v1.data.astype(np.float32)
    with mrcfile.open(vol2_f, permissive=True) as v2:
        assert v2.data is not None  # type: ignore[union-attr]
        vol2 = v2.data.astype(np.float32)
    assert vol1.shape == vol2.shape

    # Zero-pad to a cube whose side is the largest dimension, rounded up to even.
    side = max(vol1.shape)
    side += side % 2
    pad_width = [(0, side - s) for s in vol1.shape]
    vol1 = np.pad(vol1, pad_width, mode="constant")
    vol2 = np.pad(vol2, pad_width, mode="constant")

    x, fsc = _fsc_shell_curve(vol1, vol2)

    if output_f:
        np.savetxt(output_f, np.stack((x, fsc), 1))

    cutoff_05 = _fsc_resolution_cutoff(x, fsc, 0.5, Apix)
    cutoff_0143 = _fsc_resolution_cutoff(x, fsc, 0.143, Apix)

    if plot and output_f:
        import os

        plot_filename = os.path.splitext(output_f)[0] + "_plot.png"
        _save_fsc_plot(x, fsc, cutoff_05, cutoff_0143, Apix, plot_filename)

    return x, fsc, cutoff_05, cutoff_0143