Skip to content

Commit ed6787d

Browse files
committed
add wrapper to use mol prop for augmented graph
1 parent 33d60a9 commit ed6787d

2 files changed

Lines changed: 62 additions & 3 deletions

File tree

chebai_graph/preprocessing/properties/augmented_properties.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from abc import ABC, abstractmethod
1+
from abc import ABC
22
from typing import Dict, List, Optional
33

44
from rdkit import Chem
55

66
from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder
77

88
from . import properties as pr
9+
from .base import FrozenPropertyAlias
910
from .constants import *
1011

1112

@@ -103,7 +104,8 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str
103104
return 0
104105

105106

106-
class AugNodeValueDefaulter(AugmentedAtomProperty, ABC):
107+
class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC):
108+
107109
def get_atom_value(self, atom: Chem.rdchem.Atom | Dict):
108110
if isinstance(atom, Chem.rdchem.Atom):
109111
# Delegate to superclass method for atom
@@ -245,7 +247,7 @@ def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
245247
return self._check_modify_bond_prop_value(bond, EDGE_LEVEL)
246248

247249

248-
class AugBondValueDefaulter(AugmentedBondProperty, ABC):
250+
class AugBondValueDefaulter(AugmentedBondProperty, FrozenPropertyAlias, ABC):
249251
def get_bond_value(self, bond: Chem.rdchem.Bond | Dict):
250252
if isinstance(bond, Chem.rdchem.Bond):
251253
# Delegate to superclass method for bond

chebai_graph/preprocessing/properties/base.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from types import MappingProxyType
23

34
import rdkit.Chem as Chem
45

@@ -49,3 +50,59 @@ def get_bond_value(self, bond: Chem.rdchem.Bond):
4950

5051
class MoleculeProperty(MolecularProperty):
    """Molecular property that is global, i.e. describes a molecule as a whole."""
53+
54+
55+
class FrozenPropertyAlias(MolecularProperty, ABC):
    """
    Wrapper base class for augmented graph properties that reuse existing molecular properties.

    This class allows augmented graph property classes to inherit both from this wrapper and a
    standard molecular property (from `.properties`), enabling reuse of their encoders and index
    files without modifying them.

    Key Features:
        - Prevents new tokens from being added to the encoder cache by freezing it.
        - Automatically aligns the property name (used for encoder/index resolution) with the
          inherited base property by removing the "Aug" prefix from the class name.

    Usage:
        The derived class should:
        - Inherit from `FrozenPropertyAlias` **and** a valid base molecular property class.
        - Have a name starting with "Aug" (e.g., `AugAtomType`), which is resolved to `AtomType`.

    Example:
        ```python
        class AugAtomType(FrozenPropertyAlias, AtomType):
            ...
        ```

    Note:
        The name of a subclass must start with the prefix "Aug" for the name aliasing to take
        effect; otherwise the class name is used unchanged.

    This allows `AugAtomType` to reuse the encoder, index files, and logic of `AtomType` while
    integrating into augmented graph pipelines.
    """

    def __init__(self, encoder: PropertyEncoder | None = None):
        """
        Initialise the property and freeze the encoder's token cache.

        Args:
            encoder: Encoder forwarded unchanged to the parent initialiser.
        """
        super().__init__(encoder)
        # Lock the encoder's cache to prevent adding new tokens. Only plain dict caches are
        # wrapped, so a cache that is already a read-only proxy is not wrapped a second time.
        cache = getattr(self.encoder, "cache", None)
        if isinstance(cache, dict):
            self.encoder.cache = MappingProxyType(cache)

    @property
    def name(self):
        """
        Unique identifier for this property, with the 'Aug' prefix removed if present.

        This allows the encoder to reuse index files of the corresponding base property.
        """
        return self.__class__.__name__.removeprefix("Aug")

    def on_finish(self):
        """
        Verify that no tokens were added to the encoder cache, then delegate to the parent.

        Raises:
            ValueError: If the cache holds more entries than `index_length_start` recorded
                on the encoder.
        """
        cache = getattr(self.encoder, "cache", None)
        index_length_start = getattr(self.encoder, "index_length_start", None)
        # The check only applies to encoders exposing both a cache and its initial length;
        # for any other encoder there is nothing to compare, so the check is skipped.
        if (
            cache is not None
            and index_length_start is not None
            and len(cache) > index_length_start
        ):
            index_path = getattr(self.encoder, "index_path", "<unknown index>")
            raise ValueError(
                f"{self.__class__.__name__} attempted to add new tokens to the frozen index {index_path}"
            )
        super().on_finish()

0 commit comments

Comments
 (0)