diff --git a/graphai/core/embedding/embedding.py b/graphai/core/embedding/embedding.py index 5a9ab042..faa0b121 100644 --- a/graphai/core/embedding/embedding.py +++ b/graphai/core/embedding/embedding.py @@ -236,6 +236,7 @@ def _get_model_output(self, model, text): def _embed(self, model, text, force_split): text_too_large = False result = self._get_model_output(model, text) + if result is None: if force_split: model_max_tokens = self._get_model_max_tokens(model) @@ -251,6 +252,12 @@ def _embed(self, model, text, force_split): result = results.sum(axis=0).flatten() else: text_too_large = True + + # Normalise so that all vectors have unit length + norm = np.linalg.norm(result) + if norm > 0: + result = result / norm + return result, text_too_large def embed(self, text, model_type='all-MiniLM-L12-v2', force_split=True):