Skip to content

Commit d5ca17c

Browse files
authored
Merge pull request #15 from ChEB-AI/fix/optimize_index_encoder
Optimize Index Encoder for constant time search
2 parents b2a5c1d + 775c188 commit d5ca17c

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

chebai_graph/preprocessing/property_encoder.py

Lines changed: 32 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,12 @@
11
import abc
22
import os
3-
import torch
43
from typing import Optional
54

5+
import torch
6+
import sys
7+
from itertools import islice
8+
import inspect
9+
610

711
class PropertyEncoder(abc.ABC):
812
def __init__(self, property, **kwargs):
@@ -36,11 +40,13 @@ class IndexEncoder(PropertyEncoder):
3640
def __init__(self, property, indices_dir=None, **kwargs):
3741
super().__init__(property, **kwargs)
3842
if indices_dir is None:
39-
indices_dir = os.path.dirname(__file__)
43+
indices_dir = os.path.dirname(inspect.getfile(self.__class__))
4044
self.dirname = indices_dir
4145
# load already existing cache
4246
with open(self.index_path, "r") as pk:
43-
self.cache = [x.strip() for x in pk]
47+
self.cache: dict[str, int] = {
48+
token.strip(): idx for idx, token in enumerate(pk)
49+
}
4450
self.index_length_start = len(self.cache)
4551
self.offset = 0
4652

@@ -64,19 +70,33 @@ def index_path(self):
6470

6571
def on_finish(self):
6672
"""Save cache"""
67-
with open(self.index_path, "w") as pk:
68-
new_length = len(self.cache) - self.index_length_start
69-
pk.writelines([f"{c}\n" for c in self.cache])
70-
print(
71-
f"saved index of property {self.property.name} to {self.index_path}, "
72-
f"index length: {len(self.cache)} (new: {new_length})"
73-
)
73+
total_tokens = len(self.cache)
74+
if total_tokens > self.index_length_start:
75+
print("New tokens added to the cache, Saving them to index token file.....")
76+
77+
assert sys.version_info >= (
78+
3,
79+
7,
80+
), "This code requires Python 3.7 or higher."
81+
# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
82+
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
83+
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
84+
new_tokens = list(islice(self.cache, self.index_length_start, total_tokens))
85+
86+
with open(self.index_path, "a") as pk:
87+
pk.writelines([f"{c}\n" for c in new_tokens])
88+
print(
89+
f"New {len(new_tokens)} tokens append to index of property {self.property.name} to {self.index_path}..."
90+
)
91+
print(
92+
f"Now, the total length of the index of property {self.property.name} is {total_tokens}"
93+
)
7494

7595
def encode(self, token):
7696
"""Returns a unique number for each token, automatically adds new tokens to the cache."""
7797
if not str(token) in self.cache:
78-
self.cache.append(str(token))
79-
return torch.tensor([self.cache.index(str(token)) + self.offset])
98+
self.cache[(str(token))] = len(self.cache)
99+
return torch.tensor([self.cache[str(token)] + self.offset])
80100

81101

82102
class OneHotEncoder(IndexEncoder):

0 commit comments

Comments (0)