Skip to content

Commit 609e0ec

Browse files
emotionorCopilot
andauthored
Yn patch0411 (#339)
* [feat] add params `save_sdf`: save conformers to sdf * [style] black --skip-string-normalization * [docs] add the annotations * [feat] save_sdf : optional [if_not_exists, always, never] * [docs] update the annotations * [feat] update save_sdf judge order * [feat] remove `save_sdf`, add `conf_cache_level` for saving conformers to sdf files. `conf_cache_level` optional [0, 1, 2] * Apply suggestions from code review Revise several grammatical errors by copilot. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 7092bf4 commit 609e0ec

13 files changed

Lines changed: 228 additions & 71 deletions

File tree

unimol_tools/setup.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
setup(
77
name="unimol_tools",
88
version="0.1.3.post1",
9-
description=("unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."),
9+
description=(
10+
"unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."
11+
),
1012
long_description=open('README.md').read(),
1113
long_description_content_type='text/markdown',
1214
author="DP Technology",
@@ -20,16 +22,18 @@
2022
"dist",
2123
],
2224
),
23-
install_requires=["numpy<2.0.0,>=1.22.4",
24-
"pandas<2.0.0",
25-
"torch",
26-
"joblib",
27-
"rdkit",
28-
"pyyaml",
29-
"addict",
30-
"scikit-learn",
31-
"numba",
32-
"tqdm"],
25+
install_requires=[
26+
"numpy<2.0.0,>=1.22.4",
27+
"pandas<2.0.0",
28+
"torch",
29+
"joblib",
30+
"rdkit",
31+
"pyyaml",
32+
"addict",
33+
"scikit-learn",
34+
"numba",
35+
"tqdm",
36+
],
3337
python_requires=">=3.6",
3438
include_package_data=True,
3539
classifiers=[
@@ -43,4 +47,4 @@
4347
"Programming Language :: Python :: 3.10",
4448
"Topic :: Scientific/Engineering :: Artificial Intelligence",
4549
],
46-
)
50+
)

unimol_tools/unimol_tools/data/conformer.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,17 @@ def single_process(self, smiles):
116116
:raises ValueError: If the conformer generation method is unrecognized.
117117
"""
118118
if self.method == 'rdkit_random':
119-
atoms, coordinates = inner_smi2coords(
119+
atoms, coordinates, mol = inner_smi2coords(
120120
smiles, seed=self.seed, mode=self.mode, remove_hs=self.remove_hs
121121
)
122-
return coords2unimol(
122+
feat = coords2unimol(
123123
atoms,
124124
coordinates,
125125
self.dictionary,
126126
self.max_atoms,
127127
remove_hs=self.remove_hs,
128128
)
129+
return feat, mol
129130
else:
130131
raise ValueError(
131132
'Unknown conformer generation method: {}'.format(self.method)
@@ -146,16 +147,36 @@ def transform_raw(self, atoms_list, coordinates_list):
146147
)
147148
return inputs
148149

150+
def transform_mols(self, mols_list):
    """
    Build model inputs from molecules that already carry a conformer.

    For every molecule the atom symbols and the positions of its conformer
    (cast to float32) are extracted and handed to ``coords2unimol`` together
    with this generator's dictionary, ``max_atoms`` and ``remove_hs`` settings.

    :param mols_list: iterable of molecule objects exposing ``GetAtoms()`` and
        ``GetConformer()`` (presumably RDKit molecules -- confirm with caller).
    :return: list with one ``coords2unimol`` result per molecule, in input order.
    """
    return [
        coords2unimol(
            np.array([atom.GetSymbol() for atom in molecule.GetAtoms()]),
            molecule.GetConformer().GetPositions().astype(np.float32),
            self.dictionary,
            self.max_atoms,
            remove_hs=self.remove_hs,
        )
        for molecule in mols_list
    ]
165+
149166
def transform(self, smiles_list):
150167
logger.info('Start generating conformers...')
151168
if self.multi_process:
152169
pool = Pool(processes=min(8, os.cpu_count()))
153-
inputs = [
170+
results = [
154171
item for item in tqdm(pool.imap(self.single_process, smiles_list))
155172
]
156173
pool.close()
157174
else:
158-
inputs = [self.single_process(smiles) for smiles in tqdm(smiles_list)]
175+
results = [self.single_process(smiles) for smiles in tqdm(smiles_list)]
176+
177+
inputs, mols = zip(*results)
178+
inputs = list(inputs)
179+
mols = list(mols)
159180

160181
failed_conf = [(item['src_coord'] == 0.0).all() for item in inputs]
161182
logger.info(
@@ -192,7 +213,7 @@ def transform(self, smiles_list):
192213
[smiles_list[index] for index in failed_conf_3d_indices]
193214
)
194215
)
195-
return inputs
216+
return inputs, mols
196217

197218

198219
def inner_smi2coords(smi, seed=42, mode='fast', remove_hs=True, return_mol=False):
@@ -253,9 +274,9 @@ def inner_smi2coords(smi, seed=42, mode='fast', remove_hs=True, return_mol=False
253274
assert len(atoms_no_h) == len(
254275
coordinates_no_h
255276
), "coordinates shape is not align with {}".format(smi)
256-
return atoms_no_h, coordinates_no_h
277+
return atoms_no_h, coordinates_no_h, mol
257278
else:
258-
return atoms, coordinates
279+
return atoms, coordinates, mol
259280

260281

261282
def inner_coords(atoms, coordinates, remove_hs=True):
@@ -391,7 +412,8 @@ def single_process(self, smiles):
391412
remove_hs=self.remove_hs,
392413
return_mol=True,
393414
)
394-
return mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs)
415+
feat = mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs)
416+
return feat, mol
395417
else:
396418
raise ValueError(
397419
'Unknown conformer generation method: {}'.format(self.method)
@@ -405,16 +427,26 @@ def transform_raw(self, atoms_list, coordinates_list):
405427
inputs.append(mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs))
406428
return inputs
407429

430+
def transform_mols(self, mols_list):
    """
    Featurize ready-made molecule objects, one feature entry per molecule.

    Each molecule is passed unchanged to ``mol2unimolv2`` along with this
    instance's ``max_atoms`` and ``remove_hs`` settings.

    :param mols_list: iterable of molecule objects accepted by ``mol2unimolv2``.
    :return: list of ``mol2unimolv2`` results, in input order.
    """
    return [
        mol2unimolv2(molecule, self.max_atoms, remove_hs=self.remove_hs)
        for molecule in mols_list
    ]
435+
408436
def transform(self, smiles_list):
409437
logger.info('Start generating conformers...')
410438
if self.multi_process:
411439
pool = Pool(processes=min(8, os.cpu_count()))
412-
inputs = [
440+
results = [
413441
item for item in tqdm(pool.imap(self.single_process, smiles_list))
414442
]
415443
pool.close()
416444
else:
417-
inputs = [self.single_process(smiles) for smiles in tqdm(smiles_list)]
445+
results = [self.single_process(smiles) for smiles in tqdm(smiles_list)]
446+
447+
inputs, mols = zip(*results)
448+
inputs = list(inputs)
449+
mols = list(mols)
418450

419451
failed_conf = [(item['src_coord'] == 0.0).all() for item in inputs]
420452
logger.info(
@@ -452,7 +484,7 @@ def transform(self, smiles_list):
452484
)
453485
)
454486

455-
return inputs
487+
return inputs, mols
456488

457489

458490
def create_mol_from_atoms_and_coords(atoms, coordinates):

unimol_tools/unimol_tools/data/datahub.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
from __future__ import absolute_import, division, print_function
66

7+
import os
78
import numpy as np
9+
from rdkit.Chem import PandasTools
810

911
from ..utils import logger
1012
from .conformer import ConformerGen, UniMolV2Feature
@@ -29,13 +31,14 @@ def __init__(self, data=None, is_train=True, save_path=None, **params):
2931
:param save_path: (str) Path to save any necessary files, like scalers.
3032
:param params: Additional parameters for data preprocessing and model configuration.
3133
"""
32-
self.data = data
34+
self.raw_data = data
3335
self.is_train = is_train
3436
self.save_path = save_path
3537
self.task = params.get('task', None)
3638
self.target_cols = params.get('target_cols', None)
3739
self.multiclass_cnt = params.get('multiclass_cnt', None)
3840
self.ss_method = params.get('target_normalize', 'none')
41+
self.conf_cache_level = params.get('conf_cache_level', 1)
3942
self._init_data(**params)
4043
self._init_split(**params)
4144

@@ -50,7 +53,7 @@ def _init_data(self, **params):
5053
:param params: Additional parameters for data processing.
5154
:raises ValueError: If the task type is unknown.
5255
"""
53-
self.data = MolDataReader().read_data(self.data, self.is_train, **params)
56+
self.data = MolDataReader().read_data(self.raw_data, self.is_train, **params)
5457
self.data['target_scaler'] = TargetScaler(
5558
self.ss_method, self.task, self.save_path
5659
)
@@ -93,24 +96,35 @@ def _init_data(self, **params):
9396
raise ValueError('Unknown task: {}'.format(self.task))
9497

9598
if params.get('model_name', None) == 'unimolv1':
96-
if 'atoms' in self.data and 'coordinates' in self.data:
99+
if 'mols' in self.data:
100+
no_h_list = ConformerGen(**params).transform_mols(self.data['mols'])
101+
mols = None
102+
elif 'atoms' in self.data and 'coordinates' in self.data:
97103
no_h_list = ConformerGen(**params).transform_raw(
98104
self.data['atoms'], self.data['coordinates']
99105
)
106+
mols = None
100107
else:
101108
smiles_list = self.data["smiles"]
102-
no_h_list = ConformerGen(**params).transform(smiles_list)
109+
no_h_list, mols = ConformerGen(**params).transform(smiles_list)
103110
elif params.get('model_name', None) == 'unimolv2':
104-
if 'atoms' in self.data and 'coordinates' in self.data:
111+
if 'mols' in self.data:
112+
no_h_list = UniMolV2Feature(**params).transform_mols(self.data['mols'])
113+
mols = None
114+
elif 'atoms' in self.data and 'coordinates' in self.data:
105115
no_h_list = UniMolV2Feature(**params).transform_raw(
106116
self.data['atoms'], self.data['coordinates']
107117
)
118+
mols = None
108119
else:
109120
smiles_list = self.data["smiles"]
110-
no_h_list = UniMolV2Feature(**params).transform(smiles_list)
121+
no_h_list, mols = UniMolV2Feature(**params).transform(smiles_list)
111122

112123
self.data['unimol_input'] = no_h_list
113124

125+
if mols is not None:
126+
self.save_mol2sdf(self.data['raw_data'], mols, params)
127+
114128
def _init_split(self, **params):
115129

116130
self.split_method = params.get('split_method', '5fold_random')
@@ -135,3 +149,50 @@ def _init_split(self, **params):
135149
nfolds[te_idx] = enu
136150
self.data['split_nfolds'] = split_nfolds
137151
return split_nfolds
152+
153+
def save_mol2sdf(self, data, mols, params):
    """
    Save the generated conformers to an SDF file (best effort).

    The file is written to ``<sdf_save_path>/<base_name>.sdf`` where
    ``base_name`` is the stem of ``self.raw_data`` when it is a path, or
    ``'unimol_conformers'`` when it is a list / ndarray. Whether the file is
    written is controlled by ``self.conf_cache_level``:

    * ``0`` -- never save.
    * ``1`` -- save only when the target file does not exist yet.
    * ``2`` -- always save, overwriting an existing file.

    Any other value is rejected with a warning and nothing is written.
    Failures while building or writing the SDF are logged and swallowed so
    that caching can never abort data initialization.

    :param data: DataFrame containing the raw data; a copy receives the
        ``ROMol`` column, the original is left untouched.
    :param mols: List of RDKit molecule objects, one per row of ``data``.
    :param params: dict of extra options; only ``'sdf_save_path'`` is read.
        When it is missing or None, ``self.save_path`` is used instead.
        The dict is not modified.
    :return: None
    """
    level = self.conf_cache_level

    # Decide on the cache level first: it needs no filesystem access and
    # avoids unrelated warnings when saving is disabled anyway.
    if level == 0:
        logger.warning("conf_cache_level is 0, do not save conformers.")
        return
    if level not in (1, 2):
        # Reject unknown levels unconditionally instead of silently saving
        # whenever the target file happens not to exist.
        logger.warning(f"Unknown conf_cache_level: {level}, do not saving conformers.")
        return

    if isinstance(self.raw_data, str):
        base_name = os.path.splitext(os.path.basename(self.raw_data))[0]
    elif isinstance(self.raw_data, (list, np.ndarray)):
        # If the raw_data is a list of smiles, we can use a default name.
        base_name = 'unimol_conformers'
    else:
        logger.warning('Warning: raw_data is not a path or list, cannot save sdf.')
        return

    # Resolve the output directory without mutating the caller's params.
    sdf_dir = params.get('sdf_save_path')
    if sdf_dir is None:
        sdf_dir = self.save_path
    if sdf_dir is None:
        return

    save_path = os.path.join(sdf_dir, f"{base_name}.sdf")
    if level == 1 and os.path.exists(save_path):
        logger.warning(f"conf_cache_level is 1, but {save_path} exists, so do not save conformers.")
        return
    logger.info(f"conf_cache_level is {level}, saving conformers to {save_path}.")

    try:
        sdf_result = data.copy()
        sdf_result['ROMol'] = mols
        target_dir = os.path.dirname(save_path)
        if target_dir:
            # An empty dirname means the current directory; makedirs('') raises.
            os.makedirs(target_dir, exist_ok=True)
        PandasTools.WriteSDF(
            sdf_result,
            save_path,
            properties=list(sdf_result.columns),
            idName='RowID',
        )
    except Exception as e:
        # Deliberately broad: writing the cache is optional and must not
        # break dataset construction.
        logger.warning(f"Failed to write sdf file: {e}")
    else:
        logger.info(f"Successfully saved sdf file to {save_path}")

unimol_tools/unimol_tools/data/datareader.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import pandas as pd
1212
from rdkit import Chem
13+
from rdkit.Chem import PandasTools
1314
from rdkit.Chem.Scaffolds import MurckoScaffold
1415

1516
from ..utils import logger
@@ -49,7 +50,13 @@ def read_data(self, data=None, is_train=True, **params):
4950
if isinstance(data, str):
5051
# load from file
5152
self.data_path = data
52-
data = pd.read_csv(self.data_path)
53+
if data.endswith('.sdf'):
54+
# load sdf file
55+
data = PandasTools.LoadSDF(data)
56+
elif data.endswith('.csv'):
57+
data = pd.read_csv(self.data_path)
58+
else:
59+
raise ValueError('Unknown file type: {}'.format(data))
5360
elif isinstance(data, dict):
5461
# load from dict
5562
if 'target' in data:
@@ -137,6 +144,9 @@ def read_data(self, data=None, is_train=True, **params):
137144
dd['atoms'] = data['atoms'].tolist()
138145
dd['coordinates'] = data['coordinates'].tolist()
139146

147+
if 'ROMol' in data.columns:
148+
dd['mols'] = data['ROMol'].tolist()
149+
140150
return dd
141151

142152
def check_smiles(self, smi, is_train, smi_strict):

unimol_tools/unimol_tools/data/datascaler.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,20 @@
99
import joblib
1010
import numpy as np
1111
from scipy.stats import kurtosis, skew
12-
from sklearn.preprocessing import (FunctionTransformer, MaxAbsScaler,
13-
MinMaxScaler, Normalizer, PowerTransformer,
14-
QuantileTransformer, RobustScaler,
15-
StandardScaler)
12+
from sklearn.preprocessing import (
13+
FunctionTransformer,
14+
MaxAbsScaler,
15+
MinMaxScaler,
16+
Normalizer,
17+
PowerTransformer,
18+
QuantileTransformer,
19+
RobustScaler,
20+
StandardScaler,
21+
)
1622

1723
from ..utils import logger
1824

25+
1926
class TargetScaler(object):
2027
'''
2128
A class to scale the target.
@@ -80,7 +87,9 @@ def fit(self, target, dump_dir):
8087
elif self.ss_method == 'auto':
8188
if self.task == 'regression':
8289
if self.is_skewed(target):
83-
self.scaler = FunctionTransformer(func=np.log1p, inverse_func=np.expm1)
90+
self.scaler = FunctionTransformer(
91+
func=np.log1p, inverse_func=np.expm1
92+
)
8493
logger.info('Auto select robust transformer.')
8594
else:
8695
self.scaler = StandardScaler()
@@ -90,7 +99,9 @@ def fit(self, target, dump_dir):
9099
target = np.ma.masked_invalid(target) # mask NaN value
91100
for i in range(target.shape[1]):
92101
if self.is_skewed(target[:, i]):
93-
self.scaler.append(FunctionTransformer(func=np.log1p, inverse_func=np.expm1))
102+
self.scaler.append(
103+
FunctionTransformer(func=np.log1p, inverse_func=np.expm1)
104+
)
94105
logger.info('Auto select robust transformer.')
95106
else:
96107
self.scaler.append(StandardScaler())

unimol_tools/unimol_tools/models/nnmodel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from torch.utils.data import Dataset
1515

1616
from ..utils import logger
17-
from .loss import (FocalLossWithLogits, GHMC_Loss, MAEwithNan,
18-
myCrossEntropyLoss)
17+
from .loss import FocalLossWithLogits, GHMC_Loss, MAEwithNan, myCrossEntropyLoss
1918
from .unimol import UniMolModel
2019
from .unimolv2 import UniMolV2Model
2120

unimol_tools/unimol_tools/models/unimol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import absolute_import, division, print_function
66

77
import os
8+
89
# import argparse
910
import pathlib
1011

0 commit comments

Comments
 (0)