|
1 | 1 | from abc import ABC, abstractmethod |
| 2 | +from types import MappingProxyType |
2 | 3 |
|
3 | 4 | import rdkit.Chem as Chem |
4 | 5 |
|
@@ -49,3 +50,59 @@ def get_bond_value(self, bond: Chem.rdchem.Bond): |
49 | 50 |
|
50 | 51 | class MoleculeProperty(MolecularProperty): |
51 | 52 | """Global property of a molecule.""" |
| 53 | + |
| 54 | + |
| 55 | +class FrozenPropertyAlias(MolecularProperty, ABC): |
| 56 | + """ |
| 57 | + Wrapper base class for augmented graph properties that want to reuse existing molecular properties. |
| 58 | +
|
| 59 | + This class allows augmented graph property classes to inherit both from this wrapper and a standard |
| 60 | + molecular property (from `.properties`), enabling reuse of their encoders and index files without |
| 61 | + modifying them. |
| 62 | +
|
| 63 | + Key Features: |
| 64 | + - Prevents new tokens from being added to the encoder cache by freezing it. |
| 65 | + - Automatically aligns the property name (used for encoder/index resolution) with the inherited |
| 66 | + base property by removing the "Aug" prefix from the class name. |
| 67 | +
|
| 68 | + Usage: |
| 69 | + The derived class should: |
| 70 | + - Inherit from `FrozenPropertyAlias` **and** a valid base molecular property class. |
| 71 | + - Have a name starting with "Aug" (e.g., `AugAtomType`), which will be resolved to `AtomType`. |
| 72 | +
|
| 73 | + Example: |
| 74 | + ```python |
| 75 | + class AugAtomType(FrozenPropertyAlias, AtomType): |
| 76 | + ... |
| 77 | + ``` |
| 78 | + Note: |
| 79 | + Subclass name of this class should with prefix "Aug" for above effect to take place. |
| 80 | +
|
| 81 | + This allows `AugAtomType` to reuse the encoder, index files, and logic of `AtomType` while |
| 82 | + integrating into augmented graph pipelines. |
| 83 | + """ |
| 84 | + |
| 85 | + def __init__(self, encoder: PropertyEncoder | None = None): |
| 86 | + super().__init__(encoder) |
| 87 | + # Lock the encoder's cache to prevent adding new tokens |
| 88 | + if hasattr(self.encoder, "cache") and isinstance(self.encoder.cache, dict): |
| 89 | + self.encoder.cache = MappingProxyType(self.encoder.cache) |
| 90 | + |
| 91 | + @property |
| 92 | + def name(self): |
| 93 | + """ |
| 94 | + Unique identifier for this property, with 'Aug' prefix removed if present. |
| 95 | + This allows the encoder to reuse index files of the corresponding base property. |
| 96 | + """ |
| 97 | + class_name = self.__class__.__name__ |
| 98 | + return class_name[3:] if class_name.startswith("Aug") else class_name |
| 99 | + |
| 100 | + def on_finish(self): |
| 101 | + if ( |
| 102 | + hasattr(self.encoder, "cache") |
| 103 | + and len(self.encoder.cache) > self.encoder.index_length_start |
| 104 | + ): |
| 105 | + raise ValueError( |
| 106 | + f"{self.__class__.__name__} attempted to add new tokens to a {self.encoder.index_path}" |
| 107 | + ) |
| 108 | + super().on_finish() |
0 commit comments