Skip to content

Commit 729fe09

Browse files
committed
Add the usual (non-augmented) properties for the augmented graph, with a default value for non-atom nodes and edges
1 parent cd4ae43 commit 729fe09

5 files changed

Lines changed: 198 additions & 97 deletions

File tree

chebai_graph/preprocessing/properties/__init__.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,22 @@
2121
)
2222

2323
from .augmented_properties import (
24-
AtomNodeLevel,
25-
AtomFunctionalGroup,
26-
AtomRingSize,
27-
BondLevel,
24+
AugAtomNodeLevel,
25+
AugAtomFunctionalGroup,
26+
AugAtomRingSize,
27+
AugBondLevel,
28+
AugAtomType,
29+
AugNumAtomBonds,
30+
AugAtomCharge,
31+
AugAtomChirality,
32+
AugAtomHybridization,
33+
AugAtomNumHs,
34+
AugAtomAromaticity,
35+
AugBondAromaticity,
36+
AugBondType,
37+
AugBondInRing,
38+
AugMoleculeNumRings,
39+
AugRDKit2DNormalized,
2840
)
2941

3042
# isort: on
@@ -46,8 +58,20 @@
4658
"MoleculeNumRings",
4759
"RDKit2DNormalized",
4860
# -------- Augmented Molecular Properties --------
49-
"AtomNodeLevel",
50-
"AtomFunctionalGroup",
51-
"AtomRingSize",
52-
"BondLevel",
61+
"AugAtomNodeLevel",
62+
"AugAtomFunctionalGroup",
63+
"AugAtomRingSize",
64+
"AugBondLevel",
65+
"AugAtomType",
66+
"AugNumAtomBonds",
67+
"AugAtomCharge",
68+
"AugAtomChirality",
69+
"AugAtomHybridization",
70+
"AugAtomNumHs",
71+
"AugAtomAromaticity",
72+
"AugBondAromaticity",
73+
"AugBondType",
74+
"AugBondInRing",
75+
"AugMoleculeNumRings",
76+
"AugRDKit2DNormalized",
5377
]

chebai_graph/preprocessing/properties/augmented_properties.py

Lines changed: 148 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -5,83 +5,12 @@
55

66
from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder
77

8+
from . import properties as pr
89
from .constants import *
9-
from .properties import AtomProperty, BondProperty
1010

1111

12-
class AugmentedBondProperty(BondProperty, ABC):
13-
MAIN_KEY = "edges"
14-
15-
def get_property_value(self, augmented_mol: Dict) -> List:
16-
if self.MAIN_KEY not in augmented_mol:
17-
raise KeyError(
18-
f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict"
19-
)
20-
21-
missing_keys = EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys()
22-
if missing_keys:
23-
raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes")
24-
25-
atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE]
26-
if not isinstance(atom_molecule, Chem.Mol):
27-
raise TypeError(
28-
f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol'
29-
)
30-
31-
prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()]
32-
33-
fg_atom_edges = augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE]
34-
fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE]
35-
fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_EDGE]
36-
37-
if (
38-
not isinstance(fg_atom_edges, dict)
39-
or not isinstance(fg_edges, dict)
40-
or not isinstance(fg_graph_node_edges, dict)
41-
):
42-
raise TypeError(
43-
f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_EDGE}"]) '
44-
f"must be an instance of dict containing its properties"
45-
)
46-
47-
# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
48-
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
49-
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
50-
prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges.values()])
51-
prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()])
52-
prop_list.extend(
53-
[self.get_bond_value(bond) for bond in fg_graph_node_edges.values()]
54-
)
55-
56-
num_directed_edges = augmented_mol[self.MAIN_KEY][NUM_EDGES] // 2
57-
assert (
58-
len(prop_list) == num_directed_edges
59-
), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} "
60-
61-
return prop_list
62-
63-
@abstractmethod
64-
def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
65-
pass
66-
67-
def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str):
68-
value = self._get_bond_prop_value(bond, prop)
69-
if not value:
70-
# Every atom/node should have given value
71-
raise ValueError(f"'{prop}' is set but empty.")
72-
return value
73-
74-
@staticmethod
75-
def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str):
76-
if isinstance(bond, Chem.rdchem.Bond):
77-
return bond.GetProp(prop)
78-
elif isinstance(bond, dict):
79-
return bond[prop]
80-
else:
81-
raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.")
82-
83-
84-
class AugmentedAtomProperty(AtomProperty, ABC):
12+
# --------------------- Atom Properties -----------------------------
13+
class AugmentedAtomProperty(pr.AtomProperty, ABC):
8514
MAIN_KEY = "nodes"
8615

8716
def get_property_value(self, augmented_mol: Dict):
@@ -145,31 +74,23 @@ def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str):
14574
)
14675

14776

148-
class AtomNodeLevel(AugmentedAtomProperty):
77+
class AugAtomNodeLevel(AugmentedAtomProperty):
14978
def __init__(self, encoder: Optional[PropertyEncoder] = None):
15079
super().__init__(encoder or OneHotEncoder(self))
15180

15281
def get_atom_value(self, atom: Chem.rdchem.Atom | Dict):
15382
return self._check_modify_atom_prop_value(atom, NODE_LEVEL)
15483

15584

156-
class AtomFunctionalGroup(AugmentedAtomProperty):
85+
class AugAtomFunctionalGroup(AugmentedAtomProperty):
15786
def __init__(self, encoder: Optional[PropertyEncoder] = None):
15887
super().__init__(encoder or OneHotEncoder(self))
15988

16089
def get_atom_value(self, atom: Chem.rdchem.Atom | Dict):
16190
return self._check_modify_atom_prop_value(atom, "FG")
16291

163-
def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str):
164-
if isinstance(atom, Chem.rdchem.Atom):
165-
return atom.GetProp(prop)
166-
elif isinstance(atom, dict):
167-
return atom[prop]
168-
else:
169-
raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.")
170-
17192

172-
class AtomRingSize(AugmentedAtomProperty):
93+
class AugAtomRingSize(AugmentedAtomProperty):
17394
def __init__(self, encoder: Optional[PropertyEncoder] = None):
17495
super().__init__(encoder or OneHotEncoder(self))
17596

@@ -186,9 +107,150 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str
186107
return 0
187108

188109

189-
class BondLevel(AugmentedBondProperty):
110+
class AugNodeValueDefaulter(AugmentedAtomProperty, ABC):
111+
def get_atom_value(self, atom: Chem.rdchem.Atom | Dict):
112+
if isinstance(atom, Chem.rdchem.Atom):
113+
# Delegate to superclass method for atom
114+
return super().get_atom_value(atom)
115+
elif isinstance(atom, dict):
116+
return 0
117+
else:
118+
raise TypeError(
119+
f"Expected Chem.rdchem.Atom or dict, got {type(atom).__name__}"
120+
)
121+
122+
123+
class AugAtomType(AugNodeValueDefaulter, pr.AtomType): ...
124+
125+
126+
class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): ...
127+
128+
129+
class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): ...
130+
131+
132+
class AugAtomChirality(AugNodeValueDefaulter, pr.AtomChirality): ...
133+
134+
135+
class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): ...
136+
137+
138+
class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): ...
139+
140+
141+
class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): ...
142+
143+
144+
# --------------------- Bond Properties ------------------------------
145+
class AugmentedBondProperty(pr.BondProperty, ABC):
146+
MAIN_KEY = "edges"
147+
148+
def get_property_value(self, augmented_mol: Dict) -> List:
149+
if self.MAIN_KEY not in augmented_mol:
150+
raise KeyError(
151+
f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict"
152+
)
153+
154+
missing_keys = EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys()
155+
if missing_keys:
156+
raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes")
157+
158+
atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE]
159+
if not isinstance(atom_molecule, Chem.Mol):
160+
raise TypeError(
161+
f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol'
162+
)
163+
164+
prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()]
165+
166+
fg_atom_edges = augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE]
167+
fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE]
168+
fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_EDGE]
169+
170+
if (
171+
not isinstance(fg_atom_edges, dict)
172+
or not isinstance(fg_edges, dict)
173+
or not isinstance(fg_graph_node_edges, dict)
174+
):
175+
raise TypeError(
176+
f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_EDGE}"]) '
177+
f"must be an instance of dict containing its properties"
178+
)
179+
180+
# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
181+
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
182+
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
183+
prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges.values()])
184+
prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()])
185+
prop_list.extend(
186+
[self.get_bond_value(bond) for bond in fg_graph_node_edges.values()]
187+
)
188+
189+
num_directed_edges = augmented_mol[self.MAIN_KEY][NUM_EDGES] // 2
190+
assert (
191+
len(prop_list) == num_directed_edges
192+
), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} "
193+
194+
return prop_list
195+
196+
@abstractmethod
197+
def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
198+
pass
199+
200+
def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str):
201+
value = self._get_bond_prop_value(bond, prop)
202+
if not value:
203+
# Every bond/edge should have a non-empty value for the given property
204+
raise ValueError(f"'{prop}' is set but empty.")
205+
return value
206+
207+
@staticmethod
208+
def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str):
209+
if isinstance(bond, Chem.rdchem.Bond):
210+
return bond.GetProp(prop)
211+
elif isinstance(bond, dict):
212+
return bond[prop]
213+
else:
214+
raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.")
215+
216+
217+
class AugBondLevel(AugmentedBondProperty):
190218
def __init__(self, encoder: Optional[PropertyEncoder] = None):
191219
super().__init__(encoder or OneHotEncoder(self))
192220

193221
def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
194222
return self._check_modify_bond_prop_value(bond, EDGE_LEVEL)
223+
224+
225+
class AugBondValueDefaulter(AugmentedBondProperty, ABC):
226+
def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
227+
if isinstance(bond, Chem.rdchem.Bond):
228+
# Delegate to superclass method for bond
229+
return super().get_bond_value(bond)
230+
elif isinstance(bond, dict):
231+
return 0
232+
else:
233+
raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.")
234+
235+
236+
class AugBondAromaticity(AugBondValueDefaulter, pr.BondAromaticity): ...
237+
238+
239+
class AugBondType(AugBondValueDefaulter, pr.BondType): ...
240+
241+
242+
class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): ...
243+
244+
245+
# --------------------- Molecular Properties ------------------------------
246+
class AugmentedMolecularProperty(pr.MolecularProperty, ABC):
247+
def get_property_value(self, augmented_mol: Dict) -> list:
248+
mol: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"]
249+
assert isinstance(mol, Chem.Mol), "Molecule should be instance of `Chem.Mol`"
250+
return super().get_property_value(mol)
251+
252+
253+
class AugMoleculeNumRings(AugmentedMolecularProperty, pr.MoleculeNumRings): ...
254+
255+
256+
class AugRDKit2DNormalized(AugmentedMolecularProperty, pr.RDKit2DNormalized): ...

chebai_graph/preprocessing/properties/properties.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def on_finish(self):
3232
def __str__(self):
3333
return self.name
3434

35-
def get_property_value(self, mol: Chem.rdchem.Mol | Dict):
36-
raise NotImplementedError
35+
@abstractmethod
36+
def get_property_value(self, mol: Chem.rdchem.Mol | Dict): ...
3737

3838

3939
class AtomProperty(MolecularProperty, ABC):

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,6 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes):
396396
ring_size = len(connected_atoms)
397397
fg_nodes[self._num_of_nodes] = {
398398
NODE_LEVEL: FG_NODE_LEVEL,
399-
# E.g., Fused Ring has size "5-6", indicating size of each connected ring in fused ring
400399
"FG": f"RING_{ring_size}",
401400
"RING": ring_size,
402401
}
@@ -405,6 +404,8 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes):
405404
ring_prop = atom.GetProp("RING")
406405
if not ring_prop:
407406
raise ValueError("Atom does not have a ring size set")
407+
# TODO: decide whether the max or the average ring size should be used here
408+
# An atom belonging to multiple rings of a fused ring has a "RING" value such as "5-6", listing the size of each ring
408409
max_ring_size = max(list(map(int, ring_prop.split("-"))))
409410
atom.SetProp("FG", f"RING_{max_ring_size}")
410411

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphFGAugmentorReader
2+
init_args:
3+
properties:
4+
- chebai_graph.preprocessing.properties.AugAtomType
5+
- chebai_graph.preprocessing.properties.AugNumAtomBonds
6+
- chebai_graph.preprocessing.properties.AugAtomCharge
7+
- chebai_graph.preprocessing.properties.AugAtomAromaticity
8+
- chebai_graph.preprocessing.properties.AugAtomHybridization
9+
- chebai_graph.preprocessing.properties.AugAtomNumHs
10+
- chebai_graph.preprocessing.properties.AugBondType
11+
- chebai_graph.preprocessing.properties.AugBondInRing
12+
- chebai_graph.preprocessing.properties.AugBondAromaticity
13+
#- chebai_graph.preprocessing.properties.AugMoleculeNumRings
14+
- chebai_graph.preprocessing.properties.AugRDKit2DNormalized

0 commit comments

Comments
 (0)