55
66from chebai_graph .preprocessing .property_encoder import OneHotEncoder , PropertyEncoder
77
8+ from . import properties as pr
89from .constants import *
9- from .properties import AtomProperty , BondProperty
1010
1111
12- class AugmentedBondProperty (BondProperty , ABC ):
13- MAIN_KEY = "edges"
14-
15- def get_property_value (self , augmented_mol : Dict ) -> List :
16- if self .MAIN_KEY not in augmented_mol :
17- raise KeyError (
18- f"Key `{ self .MAIN_KEY } ` should be present in augmented molecule dict"
19- )
20-
21- missing_keys = EDGE_LEVELS - augmented_mol [self .MAIN_KEY ].keys ()
22- if missing_keys :
23- raise KeyError (f"Missing keys { missing_keys } in augmented molecule nodes" )
24-
25- atom_molecule : Chem .Mol = augmented_mol [self .MAIN_KEY ][WITHIN_ATOMS_EDGE ]
26- if not isinstance (atom_molecule , Chem .Mol ):
27- raise TypeError (
28- f'augmented_mol["{ self .MAIN_KEY } "]["{ WITHIN_ATOMS_EDGE } "] must be an instance of rdkit.Chem.Mol'
29- )
30-
31- prop_list = [self .get_bond_value (bond ) for bond in atom_molecule .GetBonds ()]
32-
33- fg_atom_edges = augmented_mol [self .MAIN_KEY ][ATOM_FG_EDGE ]
34- fg_edges = augmented_mol [self .MAIN_KEY ][WITHIN_FG_EDGE ]
35- fg_graph_node_edges = augmented_mol [self .MAIN_KEY ][FG_GRAPHNODE_EDGE ]
36-
37- if (
38- not isinstance (fg_atom_edges , dict )
39- or not isinstance (fg_edges , dict )
40- or not isinstance (fg_graph_node_edges , dict )
41- ):
42- raise TypeError (
43- f'augmented_mol["{ self .MAIN_KEY } "](["{ ATOM_FG_EDGE } "]/["{ WITHIN_FG_EDGE } "]/["{ FG_GRAPHNODE_EDGE } "]) '
44- f"must be an instance of dict containing its properties"
45- )
46-
47- # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
48- # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
49- # https://mail.python.org/pipermail/python-dev/2017-December/151283.html
50- prop_list .extend ([self .get_bond_value (bond ) for bond in fg_atom_edges .values ()])
51- prop_list .extend ([self .get_bond_value (bond ) for bond in fg_edges .values ()])
52- prop_list .extend (
53- [self .get_bond_value (bond ) for bond in fg_graph_node_edges .values ()]
54- )
55-
56- num_directed_edges = augmented_mol [self .MAIN_KEY ][NUM_EDGES ] // 2
57- assert (
58- len (prop_list ) == num_directed_edges
59- ), f"Number of property values ({ len (prop_list )} ) should be equal to number of half the number of undirected edges i.e. must be equal to { num_directed_edges } "
60-
61- return prop_list
62-
63- @abstractmethod
64- def get_bond_value (self , bond : Chem .rdchem .Bond | Dict ):
65- pass
66-
67- def _check_modify_bond_prop_value (self , bond : Chem .rdchem .Bond | Dict , prop : str ):
68- value = self ._get_bond_prop_value (bond , prop )
69- if not value :
70- # Every atom/node should have given value
71- raise ValueError (f"'{ prop } ' is set but empty." )
72- return value
73-
74- @staticmethod
75- def _get_bond_prop_value (bond : Chem .rdchem .Bond | Dict , prop : str ):
76- if isinstance (bond , Chem .rdchem .Bond ):
77- return bond .GetProp (prop )
78- elif isinstance (bond , dict ):
79- return bond [prop ]
80- else :
81- raise TypeError ("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`." )
82-
83-
84- class AugmentedAtomProperty (AtomProperty , ABC ):
12+ # --------------------- Atom Properties -----------------------------
13+ class AugmentedAtomProperty (pr .AtomProperty , ABC ):
8514 MAIN_KEY = "nodes"
8615
8716 def get_property_value (self , augmented_mol : Dict ):
@@ -145,31 +74,23 @@ def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str):
14574 )
14675
14776
148- class AtomNodeLevel (AugmentedAtomProperty ):
class AugAtomNodeLevel(AugmentedAtomProperty):
    """Node property that exposes the ``NODE_LEVEL`` entry of every atom/node."""

    def __init__(self, encoder: Optional[PropertyEncoder] = None):
        """Initialise the property.

        Args:
            encoder: Encoder to use. When it is falsy (e.g. ``None``), a
                ``OneHotEncoder`` bound to this property is created instead.
        """
        chosen_encoder = encoder if encoder else OneHotEncoder(self)
        super().__init__(chosen_encoder)

    def get_atom_value(self, atom: Chem.rdchem.Atom | Dict):
        """Return the ``NODE_LEVEL`` value of ``atom`` via the checked accessor."""
        node_level = self._check_modify_atom_prop_value(atom, NODE_LEVEL)
        return node_level
15483
15584
156- class AtomFunctionalGroup (AugmentedAtomProperty ):
class AugAtomFunctionalGroup(AugmentedAtomProperty):
    """Node property that exposes the ``"FG"`` entry of every atom/node."""

    def __init__(self, encoder: Optional[PropertyEncoder] = None):
        """Initialise the property.

        Args:
            encoder: Encoder to use. When it is falsy (e.g. ``None``), a
                ``OneHotEncoder`` bound to this property is created instead.
        """
        chosen_encoder = encoder if encoder else OneHotEncoder(self)
        super().__init__(chosen_encoder)

    def get_atom_value(self, atom: Chem.rdchem.Atom | Dict):
        """Return the ``"FG"`` value of ``atom`` via the checked accessor."""
        functional_group = self._check_modify_atom_prop_value(atom, "FG")
        return functional_group
16291
163- def _get_atom_prop_value (self , atom : Chem .rdchem .Atom | Dict , prop : str ):
164- if isinstance (atom , Chem .rdchem .Atom ):
165- return atom .GetProp (prop )
166- elif isinstance (atom , dict ):
167- return atom [prop ]
168- else :
169- raise TypeError ("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`." )
170-
17192
172- class AtomRingSize (AugmentedAtomProperty ):
93+ class AugAtomRingSize (AugmentedAtomProperty ):
17394 def __init__ (self , encoder : Optional [PropertyEncoder ] = None ):
17495 super ().__init__ (encoder or OneHotEncoder (self ))
17596
@@ -186,9 +107,150 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str
186107 return 0
187108
188109
189- class BondLevel (AugmentedBondProperty ):
class AugNodeValueDefaulter(AugmentedAtomProperty, ABC):
    """Mixin giving dict-described nodes a default value of ``0``.

    RDKit atoms are forwarded to the next ``get_atom_value`` in the MRO;
    nodes represented as plain dicts get the constant ``0``.
    """

    def get_atom_value(self, atom: Chem.rdchem.Atom | Dict):
        """Return the property value for ``atom``.

        Args:
            atom: Either an RDKit atom or a dict describing an augmented node.

        Returns:
            The value produced by the next ``get_atom_value`` in the MRO for
            RDKit atoms, or ``0`` for dict nodes.

        Raises:
            TypeError: If ``atom`` is neither an RDKit atom nor a dict.
        """
        if isinstance(atom, dict):
            # Dict nodes always get the default value.
            return 0

        if not isinstance(atom, Chem.rdchem.Atom):
            raise TypeError(
                f"Expected Chem.rdchem.Atom or dict, got {type(atom).__name__} "
            )

        # NOTE(review): super() continues the lookup at AugmentedAtomProperty.
        # If that class defines get_atom_value as a bare abstract stub, this
        # returns None instead of reaching the pr.* mixin -- confirm.
        return super().get_atom_value(atom)
121+
122+
class AugAtomType(AugNodeValueDefaulter, pr.AtomType):
    """``pr.AtomType`` combined with ``AugNodeValueDefaulter`` (dict nodes -> 0)."""


class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds):
    """``pr.NumAtomBonds`` combined with ``AugNodeValueDefaulter`` (dict nodes -> 0)."""


class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge):
    """``pr.AtomCharge`` combined with ``AugNodeValueDefaulter`` (dict nodes -> 0)."""


class AugAtomChirality(AugNodeValueDefaulter, pr.AtomChirality):
    """``pr.AtomChirality`` combined with ``AugNodeValueDefaulter`` (dict nodes -> 0)."""


class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization):
    """``pr.AtomHybridization`` combined with ``AugNodeValueDefaulter`` (dict nodes -> 0)."""


class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs):
    """``pr.AtomNumHs`` combined with ``AugNodeValueDefaulter`` (dict nodes -> 0)."""


class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity):
    """``pr.AtomAromaticity`` combined with ``AugNodeValueDefaulter`` (dict nodes -> 0)."""
142+
143+
144+ # --------------------- Bond Properties ------------------------------
class AugmentedBondProperty(pr.BondProperty, ABC):
    """Base class for bond/edge properties read from an augmented molecule dict.

    The augmented molecule is expected to hold its edge data under
    ``MAIN_KEY``. That entry must contain every key in ``EDGE_LEVELS``: an
    RDKit molecule under ``WITHIN_ATOMS_EDGE`` and dicts of edge descriptions
    under ``ATOM_FG_EDGE``, ``WITHIN_FG_EDGE`` and ``FG_GRAPHNODE_EDGE``.
    """

    MAIN_KEY = "edges"

    def get_property_value(self, augmented_mol: Dict) -> List:
        """Collect one property value per edge of the augmented molecule.

        Values are gathered in a fixed order: RDKit bonds first, then the
        ``ATOM_FG_EDGE``, ``WITHIN_FG_EDGE`` and ``FG_GRAPHNODE_EDGE`` dicts,
        each in its own iteration order.

        Args:
            augmented_mol: Augmented molecule dict containing ``MAIN_KEY``.

        Returns:
            List with the result of ``get_bond_value`` for every edge.

        Raises:
            KeyError: If ``MAIN_KEY`` or any key of ``EDGE_LEVELS`` is missing.
            TypeError: If an edge container has an unexpected type.
            AssertionError: If the number of collected values differs from
                ``augmented_mol[MAIN_KEY][NUM_EDGES] // 2``.
        """
        if self.MAIN_KEY not in augmented_mol:
            raise KeyError(
                f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict"
            )

        edges = augmented_mol[self.MAIN_KEY]

        missing_keys = EDGE_LEVELS - edges.keys()
        if missing_keys:
            # This check inspects the edges entry, so the message names edges.
            raise KeyError(f"Missing keys {missing_keys} in augmented molecule edges")

        atom_molecule: Chem.Mol = edges[WITHIN_ATOMS_EDGE]
        if not isinstance(atom_molecule, Chem.Mol):
            raise TypeError(
                f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol'
            )

        prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()]

        fg_atom_edges = edges[ATOM_FG_EDGE]
        fg_edges = edges[WITHIN_FG_EDGE]
        fg_graph_node_edges = edges[FG_GRAPHNODE_EDGE]
        edge_groups = (fg_atom_edges, fg_edges, fg_graph_node_edges)

        if not all(isinstance(group, dict) for group in edge_groups):
            raise TypeError(
                f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_EDGE}"]) '
                "must be an instance of dict containing its properties"
            )

        # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
        # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
        # https://mail.python.org/pipermail/python-dev/2017-December/151283.html
        for group in edge_groups:
            prop_list.extend(self.get_bond_value(bond) for bond in group.values())

        # One value is collected per edge, while the stored count is halved
        # before comparison (presumably it counts both directions -- confirm).
        expected_count = edges[NUM_EDGES] // 2
        if len(prop_list) != expected_count:
            # Raised explicitly so the check is not stripped under ``python -O``.
            raise AssertionError(
                f"Number of property values ({len(prop_list)}) should be equal to "
                f"half the stored number of edges, i.e. must be equal to {expected_count}"
            )

        return prop_list

    @abstractmethod
    def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
        """Return the property value for a single bond/edge.

        Subclasses must override this. The stub itself returns ``None`` if it
        is ever reached through ``super()``.
        """

    def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str):
        """Fetch ``prop`` from ``bond`` and reject falsy values.

        Raises:
            ValueError: If the fetched value is falsy.
        """
        value = self._get_bond_prop_value(bond, prop)
        if not value:
            # Every bond/edge should have the given value.
            raise ValueError(f"'{prop}' is set but empty.")
        return value

    @staticmethod
    def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str):
        """Read ``prop`` from an RDKit bond (``GetProp``) or an edge dict (key lookup).

        Raises:
            TypeError: If ``bond`` is neither an RDKit bond nor a dict.
        """
        if isinstance(bond, Chem.rdchem.Bond):
            return bond.GetProp(prop)
        if isinstance(bond, dict):
            return bond[prop]
        raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.")
215+
216+
class AugBondLevel(AugmentedBondProperty):
    """Edge property that exposes the ``EDGE_LEVEL`` entry of every bond/edge."""

    def __init__(self, encoder: Optional[PropertyEncoder] = None):
        """Initialise the property.

        Args:
            encoder: Encoder to use. When it is falsy (e.g. ``None``), a
                ``OneHotEncoder`` bound to this property is created instead.
        """
        chosen_encoder = encoder if encoder else OneHotEncoder(self)
        super().__init__(chosen_encoder)

    def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
        """Return the ``EDGE_LEVEL`` value of ``bond`` via the checked accessor."""
        edge_level = self._check_modify_bond_prop_value(bond, EDGE_LEVEL)
        return edge_level
223+
224+
class AugBondValueDefaulter(AugmentedBondProperty, ABC):
    """Mixin giving dict-described edges a default value of ``0``.

    RDKit bonds are delegated to the ``get_bond_value`` implementation that
    follows ``AugmentedBondProperty`` in the MRO of the concrete class
    (e.g. a ``pr.*`` bond property listed as a second base class).
    """

    def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
        """Return the property value for ``bond``.

        Args:
            bond: Either an RDKit bond or a dict describing an augmented edge.

        Returns:
            The delegated value for RDKit bonds, or ``0`` for dict edges.

        Raises:
            TypeError: If ``bond`` is neither an RDKit bond nor a dict.
        """
        if isinstance(bond, Chem.rdchem.Bond):
            # A plain super() would stop at AugmentedBondProperty.get_bond_value,
            # an abstract stub that returns None. Start the lookup after that
            # class so the real implementation further along the MRO is used.
            return super(AugmentedBondProperty, self).get_bond_value(bond)
        if isinstance(bond, dict):
            # Dict edges always get the default value.
            return 0
        raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.")
234+
235+
class AugBondAromaticity(AugBondValueDefaulter, pr.BondAromaticity):
    """``pr.BondAromaticity`` combined with ``AugBondValueDefaulter`` (dict edges -> 0)."""


class AugBondType(AugBondValueDefaulter, pr.BondType):
    """``pr.BondType`` combined with ``AugBondValueDefaulter`` (dict edges -> 0)."""


class AugBondInRing(AugBondValueDefaulter, pr.BondInRing):
    """``pr.BondInRing`` combined with ``AugBondValueDefaulter`` (dict edges -> 0)."""
243+
244+
245+ # --------------------- Molecular Properties ------------------------------
class AugmentedMolecularProperty(pr.MolecularProperty, ABC):
    """Base class for molecule-level properties read from an augmented molecule dict.

    The RDKit molecule is taken from ``augmented_mol[MAIN_KEY]["atom_nodes"]``
    and passed to the next ``get_property_value`` in the MRO.
    """

    def get_property_value(self, augmented_mol: Dict) -> list:
        """Compute the property on the RDKit molecule stored in ``augmented_mol``.

        Args:
            augmented_mol: Augmented molecule dict.

        Returns:
            Whatever the next ``get_property_value`` in the MRO returns.

        Raises:
            AssertionError: If the stored object is not a ``Chem.Mol``.
        """
        # NOTE(review): MAIN_KEY is not defined in this class; presumably it is
        # inherited from pr.MolecularProperty -- confirm.
        mol: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"]
        if not isinstance(mol, Chem.Mol):
            # Raised explicitly so the check is not stripped under ``python -O``.
            raise AssertionError("Molecule should be instance of `Chem.Mol`")
        return super().get_property_value(mol)
251+
252+
class AugMoleculeNumRings(AugmentedMolecularProperty, pr.MoleculeNumRings):
    """``pr.MoleculeNumRings`` evaluated on the molecule of an augmented molecule dict."""


class AugRDKit2DNormalized(AugmentedMolecularProperty, pr.RDKit2DNormalized):
    """``pr.RDKit2DNormalized`` evaluated on the molecule of an augmented molecule dict."""
0 commit comments