11import abc
22import os
3- import torch
43from typing import Optional
54
5+ import torch
6+ import sys
7+ from itertools import islice
8+ import inspect
9+
610
711class PropertyEncoder (abc .ABC ):
812 def __init__ (self , property , ** kwargs ):
@@ -36,11 +40,13 @@ class IndexEncoder(PropertyEncoder):
def __init__(self, property, indices_dir=None, **kwargs):
    """Set up the encoder and load the existing token index from disk.

    Args:
        property: Property object forwarded to the parent initializer.
        indices_dir: Directory holding the index file. When ``None``, the
            directory of the file that defines the concrete encoder class
            is used.
        **kwargs: Extra keyword arguments forwarded to the parent
            initializer.

    Raises:
        OSError: If the file at ``self.index_path`` cannot be opened.
    """
    super().__init__(property, **kwargs)
    if indices_dir is None:
        # Resolve relative to the concrete subclass' source file, so
        # subclasses defined in other modules look next to themselves.
        indices_dir = os.path.dirname(inspect.getfile(self.__class__))
    self.dirname = indices_dir

    # Load the already existing cache: token -> zero-based line number.
    self.cache: dict[str, int] = {}
    with open(self.index_path, "r") as pk:
        for idx, line in enumerate(pk):
            # Keep the first occurrence of a repeated token, matching what
            # a list-based ``index`` lookup returns for the same file.
            self.cache.setdefault(line.strip(), idx)

    # Number of tokens known at load time; entries inserted after this
    # position are the ones that still have to be persisted.
    self.index_length_start = len(self.cache)
    self.offset = 0
4652
@@ -64,19 +70,33 @@ def index_path(self):
6470
def on_finish(self):
    """Append tokens added since the last save to the index file.

    Only the tokens inserted into ``self.cache`` after position
    ``self.index_length_start`` are written, one per line, in insertion
    order. Nothing is written or printed when there are no new tokens.
    After a successful write ``self.index_length_start`` is advanced, so a
    repeated call does not append the same tokens twice.

    Raises:
        RuntimeError: If the interpreter is older than Python 3.7.
        OSError: If the index file cannot be opened for appending.
    """
    total_tokens = len(self.cache)
    if total_tokens <= self.index_length_start:
        return

    print("New tokens added to the cache, Saving them to index token file.....")

    # From Python 3.7 on, the standard dict preserves insertion order and is
    # iterated in that order, which is what makes the slice below valid.
    # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
    # An explicit raise is used because ``assert`` is stripped under -O.
    if sys.version_info < (3, 7):
        raise RuntimeError("This code requires Python 3.7 or higher.")

    new_tokens = list(islice(self.cache, self.index_length_start, total_tokens))

    with open(self.index_path, "a") as pk:
        pk.writelines(f"{c}\n" for c in new_tokens)

    # Mark everything up to here as persisted.
    self.index_length_start = total_tokens

    print(
        f"New {len(new_tokens)} tokens append to index of property {self.property.name} to {self.index_path}..."
    )
    print(
        f"Now, the total length of the index of property {self.property.name} is {total_tokens}"
    )
7494
def encode(self, token):
    """Return a unique number for each token, adding new tokens to the cache.

    Args:
        token: Value to encode; it is converted with ``str`` before lookup.

    Returns:
        A one-element ``torch`` tensor holding the token's cache index plus
        ``self.offset``.
    """
    key = str(token)
    if key not in self.cache:
        # A new token takes the next free position, i.e. the current size.
        self.cache[key] = len(self.cache)
    return torch.tensor([self.cache[key] + self.offset])
80100
81101
82102class OneHotEncoder (IndexEncoder ):
0 commit comments