44
55from __future__ import absolute_import , division , print_function
66
7+ import os
78import numpy as np
9+ from rdkit .Chem import PandasTools
810
911from ..utils import logger
1012from .conformer import ConformerGen , UniMolV2Feature
@@ -29,13 +31,14 @@ def __init__(self, data=None, is_train=True, save_path=None, **params):
2931 :param save_path: (str) Path to save any necessary files, like scalers.
3032 :param params: Additional parameters for data preprocessing and model configuration.
3133 """
32- self .data = data
34+ self .raw_data = data
3335 self .is_train = is_train
3436 self .save_path = save_path
3537 self .task = params .get ('task' , None )
3638 self .target_cols = params .get ('target_cols' , None )
3739 self .multiclass_cnt = params .get ('multiclass_cnt' , None )
3840 self .ss_method = params .get ('target_normalize' , 'none' )
41+ self .conf_cache_level = params .get ('conf_cache_level' , 1 )
3942 self ._init_data (** params )
4043 self ._init_split (** params )
4144
@@ -50,7 +53,7 @@ def _init_data(self, **params):
5053 :param params: Additional parameters for data processing.
5154 :raises ValueError: If the task type is unknown.
5255 """
53- self .data = MolDataReader ().read_data (self .data , self .is_train , ** params )
56+ self .data = MolDataReader ().read_data (self .raw_data , self .is_train , ** params )
5457 self .data ['target_scaler' ] = TargetScaler (
5558 self .ss_method , self .task , self .save_path
5659 )
@@ -93,24 +96,35 @@ def _init_data(self, **params):
9396 raise ValueError ('Unknown task: {}' .format (self .task ))
9497
9598 if params .get ('model_name' , None ) == 'unimolv1' :
96- if 'atoms' in self .data and 'coordinates' in self .data :
99+ if 'mols' in self .data :
100+ no_h_list = ConformerGen (** params ).transform_mols (self .data ['mols' ])
101+ mols = None
102+ elif 'atoms' in self .data and 'coordinates' in self .data :
97103 no_h_list = ConformerGen (** params ).transform_raw (
98104 self .data ['atoms' ], self .data ['coordinates' ]
99105 )
106+ mols = None
100107 else :
101108 smiles_list = self .data ["smiles" ]
102- no_h_list = ConformerGen (** params ).transform (smiles_list )
109+ no_h_list , mols = ConformerGen (** params ).transform (smiles_list )
103110 elif params .get ('model_name' , None ) == 'unimolv2' :
104- if 'atoms' in self .data and 'coordinates' in self .data :
111+ if 'mols' in self .data :
112+ no_h_list = UniMolV2Feature (** params ).transform_mols (self .data ['mols' ])
113+ mols = None
114+ elif 'atoms' in self .data and 'coordinates' in self .data :
105115 no_h_list = UniMolV2Feature (** params ).transform_raw (
106116 self .data ['atoms' ], self .data ['coordinates' ]
107117 )
118+ mols = None
108119 else :
109120 smiles_list = self .data ["smiles" ]
110- no_h_list = UniMolV2Feature (** params ).transform (smiles_list )
121+ no_h_list , mols = UniMolV2Feature (** params ).transform (smiles_list )
111122
112123 self .data ['unimol_input' ] = no_h_list
113124
125+ if mols is not None :
126+ self .save_mol2sdf (self .data ['raw_data' ], mols , params )
127+
114128 def _init_split (self , ** params ):
115129
116130 self .split_method = params .get ('split_method' , '5fold_random' )
@@ -135,3 +149,50 @@ def _init_split(self, **params):
135149 nfolds [te_idx ] = enu
136150 self .data ['split_nfolds' ] = split_nfolds
137151 return split_nfolds
152+
def save_mol2sdf(self, data, mols, params):
    """
    Save the generated conformers to an SDF file (best effort).

    The file name is derived from ``self.raw_data``: the base name of the
    path (extension stripped) when it is a string, or ``unimol_conformers``
    when it is a list / numpy array. The file is placed in
    ``params['sdf_save_path']``, falling back to ``self.save_path``; when
    neither is set nothing is saved.

    ``self.conf_cache_level`` decides whether the file is written:

    * ``0`` -- never save.
    * ``1`` -- save only when the target file does not exist yet.
    * ``2`` -- save even when the target file already exists.
    * any other value -- log a warning and do not save.

    :param data: DataFrame containing the raw data. It is copied, the
        original is not modified.
    :param mols: List of RDKit molecule objects, stored as the ``ROMol``
        column of the copy (so its length has to match ``data``).
    :param params: (dict) Extra parameters; only ``sdf_save_path`` is read
        and the dict is left unmodified.
    :return: None. Failures while creating the directory or writing the
        file are logged as warnings instead of being raised.
    """
    if isinstance(self.raw_data, str):
        base_name = os.path.splitext(os.path.basename(self.raw_data))[0]
    elif isinstance(self.raw_data, (list, np.ndarray)):
        # No source file to name the output after, so use a default name.
        base_name = 'unimol_conformers'
    else:
        logger.warning('Warning: raw_data is not a path or list, cannot save sdf.')
        return

    # Resolve the output directory locally instead of writing the fallback
    # back into the caller's params dict.
    sdf_dir = params.get('sdf_save_path')
    if sdf_dir is None:
        sdf_dir = self.save_path
    if sdf_dir is None:
        return
    save_path = os.path.join(sdf_dir, f"{base_name}.sdf")

    level = self.conf_cache_level
    if level == 0:
        logger.warning("conf_cache_level is 0, do not save conformers.")
        return
    if level not in (1, 2):
        # Reject unknown levels up front; otherwise they would be treated
        # like level 1 whenever the target file happens not to exist.
        logger.warning(f"Unknown conf_cache_level: {level}, do not saving conformers.")
        return
    if level == 1 and os.path.exists(save_path):
        logger.warning(f"conf_cache_level is 1, but {save_path} exists, so do not save conformers.")
        return
    logger.info(f"conf_cache_level is {level}, saving conformers to {save_path}.")

    sdf_result = data.copy()
    sdf_result['ROMol'] = mols
    try:
        # dirname is '' when the file goes to the current directory;
        # os.makedirs('') would raise, so only create a real directory.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        PandasTools.WriteSDF(
            sdf_result,
            save_path,
            properties=list(sdf_result.columns),
            idName='RowID',
        )
    except Exception as e:
        # Saving conformers is an optional cache; never fail data loading
        # because of it.
        logger.warning(f"Failed to write sdf file: {e}")
    else:
        logger.info(f"Successfully saved sdf file to {save_path}")
0 commit comments