-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmain.py
More file actions
289 lines (246 loc) · 11.5 KB
/
main.py
File metadata and controls
289 lines (246 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import os
import shutil
from data_processing.Unify_Map import Unify_Map
from data_processing.Resize_Map import Resize_Map
from predict.infer_diffusion import infer_diffem
from modeling.pdb_utils import swap_cif_occupancy_bfactor
from modeling.map_utils import segment_map
from modeling.fit_structure_chain import fit_structure_chain
from modeling.assemble_structure import assemble_structure
from ops.argparser import argparser
from ops.domain_utils import prepare_domain_input
from ops.fasta2pool import fasta2pool
from ops.fasta_searchdb import fasta_searchdb
from ops.fasta_utils import read_fasta, refine_fasta_input
from ops.io_utils import read_structure_txt
from ops.map_utils import increase_map_density
from ops.os_operation import mkdir, extract_compressed_file, copy_directory, unzip_gz
from ops.pdb_utils import cif2pdb, clean_pdb_template
def init_save_path(origin_map_path):
    """Create ./Predict_Result/<map_name> and return (save_path, map_name).

    The map name is the input file name with a trailing .mrc/.map
    extension removed and parentheses stripped.
    """
    save_path = os.path.join(os.getcwd(), "Predict_Result")
    mkdir(save_path)
    map_name = os.path.split(origin_map_path)[1]
    # Strip only a trailing extension: the previous str.replace() removed
    # ".mrc"/".map" anywhere in the name and could mangle names such as
    # "emd.mrc_v2.map".
    for ext in (".mrc", ".map"):
        if map_name.endswith(ext):
            map_name = map_name[: -len(ext)]
            break
    # Parentheses break downstream shell commands / path handling.
    map_name = map_name.replace("(", "").replace(")", "")
    save_path = os.path.join(save_path, map_name)
    mkdir(save_path)
    return save_path, map_name
def set_up_envrionment(params):
    """Validate inputs, prepare the output directory, and pre-process the map.

    NOTE: the (misspelled) identifier is kept unchanged because it is
    called by the __main__ block of this file.

    Returns:
        (save_path, new_map_path): absolute output directory and the
        unified/resized/segmented map path fed to downstream stages.
    """
    # Maps above 20A are outside the supported resolution range.
    if params["resolution"] > 20:
        print(
            "maps with %.2f resolution is not supported! We only support maps with resolution 0-20A!"
            % params["resolution"]
        )
        exit()
    gpu_id = params["gpu"]
    if gpu_id is not None:
        # Environment variables must be strings; accept int gpu ids too.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    cur_map_path = os.path.abspath(params["F"])
    if cur_map_path.endswith(".gz"):
        cur_map_path = unzip_gz(cur_map_path)
    if params["output"] is None:
        save_path, map_name = init_save_path(cur_map_path)
    else:
        save_path = params["output"]
        map_name = "input_diffmodeler"  # to avoid server same name bugs
        mkdir(save_path)
    try:
        # Best-effort: byte-compile VESPER ahead of time to speed up later runs.
        print("pre-compile VESPER to accelerate!")
        running_dir = os.path.dirname(os.path.abspath(__file__))
        os.system(f"cd {running_dir}; python -O -m compileall VESPER_CUDA")
    except Exception:
        # Narrowed from a bare except so SystemExit/KeyboardInterrupt still propagate.
        print("pre-compile VESPER failed! No impact to main scripts!")
    save_path = os.path.abspath(save_path)
    # Unify voxel size/axis order, then resample onto the working grid.
    cur_map_path = Unify_Map(
        cur_map_path, os.path.join(save_path, map_name + "_unified.mrc")
    )
    cur_map_path = Resize_Map(cur_map_path, os.path.join(save_path, map_name + ".mrc"))
    if params["contour"] < 0:
        # change contour level to 0 and increase all the density
        cur_map_path = increase_map_density(
            cur_map_path,
            os.path.join(save_path, map_name + "_increase.mrc"),
            params["contour"],
        )
        params["contour"] = 0
    # Segment the region above contour 0 to shrink the working map.
    new_map_path = os.path.join(save_path, map_name + "_segment.mrc")
    segment_map(cur_map_path, new_map_path, contour=0)
    return save_path, new_map_path
def diffusion_trace_map(save_path, cur_map_path, params):
    """Run the diffusion tracing stage and return a segmented traced map.

    Very high resolution maps (< 2A) skip diffusion and are segmented
    directly; otherwise the diffusion model output is segmented.
    """
    resolution = params["resolution"]
    if resolution < 2:
        # Skip the expensive inference pass for very high resolution maps.
        print("skip diffusion with very high resolution map %f" % resolution)
        traced_map = cur_map_path
    else:
        diffusion_dir = os.path.join(save_path, "infer_diffusion")
        traced_map = infer_diffem(cur_map_path, diffusion_dir, params)
    print("Diffusion process finished! Traced map saved here %s" % traced_map)
    # segment this difftrace map to save time
    segmented_map = os.path.join(save_path, "diffusion.mrc")
    segment_map(traced_map, segmented_map, contour=0)
    return segmented_map
def construct_single_chain_candidate(params, save_path):
    """Collect single-chain template structures and build the fitting dict.

    params["P"] is either a compressed archive of chain structure files or
    a directory containing them; params["M"] is the text configure file
    mapping chains to templates.

    Returns:
        dict: fitting dict parsed by read_structure_txt.
    """
    single_chain_pdb_input = os.path.abspath(params["P"])
    single_chain_pdb_dir = os.path.join(save_path, "single_chain_pdb")
    if not os.path.isdir(single_chain_pdb_input):
        # Input is an archive: extract it into the working directory.
        os.makedirs(single_chain_pdb_dir, exist_ok=True)
        single_chain_pdb_dir = extract_compressed_file(
            single_chain_pdb_input, single_chain_pdb_dir
        )
    else:
        # Input is a directory: replace any stale working copy with a fresh one.
        if os.path.exists(single_chain_pdb_dir):
            shutil.rmtree(single_chain_pdb_dir)
        single_chain_pdb_dir = copy_directory(
            single_chain_pdb_input, single_chain_pdb_dir
        )
    # Convert every .cif file in the directory to .pdb for downstream tools.
    for file in os.listdir(single_chain_pdb_dir):
        if file.endswith(".cif"):
            cur_cif_path = os.path.join(single_chain_pdb_dir, file)
            cur_pdb_path = os.path.join(
                single_chain_pdb_dir, file.replace(".cif", ".pdb")
            )
            cif2pdb(cur_cif_path, cur_pdb_path)
    fitting_dict = read_structure_txt(
        single_chain_pdb_dir, os.path.abspath(params["M"])
    )
    return fitting_dict
def fix_cif_for_coot(input_cif, output_cif):
    """
    Reads a CIF file and ensures:
    - '_atom_site.auth_asym_id' is copied from '_atom_site.label_asym_id'.
    - '_atom_site.auth_seq_id' is copied from '_atom_site.label_seq_id'.
    - All atom entries align properly in the CIF format for Coot.

    Files that already declare both auth columns keep their original column
    count (ATOM lines are still re-tokenized and re-joined with single spaces).
    """
    with open(input_cif, "r") as infile:
        lines = infile.readlines()

    # Pre-scan the COMPLETE header set so the duplicate check also sees auth
    # columns declared *after* label_seq_id (their usual mmCIF position).
    # The previous in-loop check only saw headers read so far and appended
    # duplicate auth columns for already-complete files, corrupting them.
    full_headers = {
        ln.strip() for ln in lines if ln.strip().startswith("_atom_site.")
    }

    new_lines = []
    in_atom_site = False
    headers = []  # accumulated _atom_site header names, in final output order
    modified_headers = False
    for line in lines:
        stripped_line = line.strip()
        # Detect start of a loop_; any previous _atom_site loop ends here.
        if stripped_line.startswith("loop_"):
            in_atom_site = False  # Reset detection
        if stripped_line.startswith("_atom_site."):
            in_atom_site = True
            headers.append(stripped_line)
            # Splice the two auth columns right after label_seq_id so the
            # header order matches where the values are inserted below.
            if (
                in_atom_site
                and not modified_headers
                and "_atom_site.label_seq_id" in stripped_line
            ):
                if "_atom_site.auth_asym_id" not in full_headers:
                    headers.append("_atom_site.auth_asym_id")
                if "_atom_site.auth_seq_id" not in full_headers:
                    headers.append("_atom_site.auth_seq_id")
                modified_headers = True
                continue  # Skip writing this line since we'll rewrite the headers later
        # Ensure atom data has correct column count
        if in_atom_site and stripped_line.startswith("ATOM"):
            parts = stripped_line.split()
            if len(parts) == len(headers) - 2:  # Missing two columns
                chain_idx = headers.index("_atom_site.label_asym_id")
                seq_idx = headers.index("_atom_site.label_seq_id")
                label_asym_id = parts[chain_idx]  # Chain ID
                label_seq_id = parts[seq_idx]  # Residue number
                insert_idx = seq_idx + 1
                parts.insert(insert_idx, label_asym_id)  # _atom_site.auth_asym_id
                parts.insert(insert_idx + 1, label_seq_id)  # _atom_site.auth_seq_id
            elif len(parts) != len(headers):  # If still inconsistent, print warning
                print(f"WARNING: Inconsistent CIF loop at line:\n {stripped_line}")
            new_lines.append(" ".join(parts) + "\n")
        else:
            new_lines.append(line)

    # Write the (possibly extended) header block once, at the position of the
    # first surviving _atom_site header line; skip the remaining header lines.
    with open(output_cif, "w") as outfile:
        for line in new_lines:
            if line.strip().startswith("_atom_site."):
                if headers:
                    outfile.write("\n".join(headers) + "\n")
                    headers = []  # Clear headers so they aren't written again
                continue
            outfile.write(line)
    print(f"Reformatted CIF file saved as: {output_cif}")
if __name__ == "__main__":
    # DiffModeler pipeline driver:
    # setup -> template pool -> cleaning -> (domains) -> diffusion -> fitting
    # -> assembly -> CIF post-processing.
    params = argparser()
    save_path, cur_map_path = set_up_envrionment(params)
    # Run from the script directory so relative resource paths resolve.
    running_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(running_dir)
    # Resolve configured model/database paths relative to the script location.
    if not os.path.isabs(params["model"]["path"]):
        params["model"]["path"] = os.path.join(running_dir, params["model"]["path"])
    if not os.path.isabs(params["db_exp_path"]):
        params["db_exp_path"] = os.path.join(running_dir, params["db_exp_path"])
    if not os.path.isabs(params["db_path"]):
        params["db_path"] = os.path.join(running_dir, params["db_path"])
    # Build the single-chain template pool according to the input mode.
    if params["mode"] == 0:
        # mode 0: templates given directly via a text configure file
        fitting_dict = construct_single_chain_candidate(params, save_path)
    elif params["mode"] == 1:
        # mode 1: templates predicted from the input fasta sequences
        fitting_dict = fasta2pool(params, save_path)
    elif params["mode"] == 2:
        # mode 2: templates retrieved by sequence search against the database
        fitting_dict = fasta_searchdb(params, save_path)
        if params["seq_search"]:
            print("sequence search finished!")
            exit()
    elif params["mode"] == 3:
        # support fasta+template mixed mode. Will first check if template is provided, will use the provided one first. Then for remained results, do db search
        fitting_dict = construct_single_chain_candidate(params, save_path)
        params["fasta_path"] = os.path.abspath(params["fasta_path"])
        chain_dict = read_fasta(params["fasta_path"])
        refined_fasta_path = os.path.join(save_path, "refined_input.fasta")
        fitting_dict = refine_fasta_input(chain_dict, fitting_dict, refined_fasta_path)
        print("updated fitting_dict:", fitting_dict)
        params["P"] = refined_fasta_path
        additional_fitting_dict = fasta_searchdb(params, save_path)
        # merge the db-search results into the template-provided dict
        for key in additional_fitting_dict:
            fitting_dict[key] = additional_fitting_dict[key]
        print("final fitting dict:", fitting_dict)
    else:
        print("mode %d is not supported!" % params["mode"])
        exit()
    if len(fitting_dict) == 0:
        # Fixed typo in user-facing message ("candiate" -> "candidate").
        print("Empty Template candidate, DiffModeler can not run!!!")
        exit()
    # clean fitting dict to avoid strange pdb cause entire program fail
    final_template_dir = os.path.join(save_path, "final_template_input")
    final_fitting_dict = clean_pdb_template(fitting_dict, final_template_dir)
    if params["domain"]:
        # Optionally split templates into domains (SWORD) before fitting.
        domain_template_dir = os.path.join(save_path, "domain_template_input")
        final_fitting_dict = prepare_domain_input(
            final_fitting_dict, domain_template_dir, num_cpu=params["SWORD_thread"]
        )
        print("Domain split finished!", final_fitting_dict)
    # diffusion inference
    diff_trace_map = diffusion_trace_map(save_path, cur_map_path, params)
    # VESPER single-chain fitting process
    fitting_dir = os.path.join(save_path, "structure_modeling")
    fit_structure_chain(diff_trace_map, final_fitting_dict, fitting_dir, params)
    # VESPER assembling
    modeling_dir = os.path.join(save_path, "structure_assembling")
    source_cif = assemble_structure(
        diff_trace_map, final_fitting_dict, fitting_dir, modeling_dir, params
    )
    output_cif = os.path.join(save_path, "DiffModeler_alpha.cif")
    shutil.copy(source_cif, output_cif)
    # convert the cif format so Coot can load it (adds auth_* columns)
    input_cif = output_cif
    output_cif = os.path.join(save_path, "DiffModeler.cif")
    fix_cif_for_coot(input_cif, output_cif)
    # generate a cif file to save the fitting score to b-factor field for easier visualization
    # for server visualization on server
    score_specific_path = os.path.join(save_path, "DiffModeler_fitscore.cif")
    swap_cif_occupancy_bfactor(output_cif, score_specific_path)
    print(f"Please check DiffModeler's output structure in {output_cif}")