From 75ade76b867bcd72b5f05853461d20af5de0ce8a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 24 Feb 2026 12:15:03 +0100 Subject: [PATCH] make compilation optional, return None for empty attention graph --- chebifier/prediction_models/electra_predictor.py | 6 +++++- chebifier/prediction_models/nn_predictor.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/chebifier/prediction_models/electra_predictor.py b/chebifier/prediction_models/electra_predictor.py index a8355d1..bf9f7e3 100644 --- a/chebifier/prediction_models/electra_predictor.py +++ b/chebifier/prediction_models/electra_predictor.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np from .nn_predictor import NNPredictor @@ -40,7 +42,7 @@ def __init__(self, model_name: str, ckpt_path: str, **kwargs): f"Initialised Electra model {self.model_name} (device: {self.predictor.device})" ) - def explain_smiles(self, smiles) -> dict: + def explain_smiles(self, smiles) -> Optional[dict]: from chebai.preprocessing.reader import EMBEDDING_OFFSET # Add dummy labels because the collate function requires them. @@ -69,4 +71,6 @@ def explain_smiles(self, smiles) -> dict: ] for a in result["attentions"] ] + if len(graphs) == 0: + return None return {"graphs": graphs} diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index 92962f2..971a42d 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -20,8 +20,13 @@ def __init__( ): super().__init__(model_name, **kwargs) self.batch_size = kwargs.get("batch_size", None) + # compile_model will run the model in eager mode, which gives better performance, but does not return intermediate states + # such as attention weights. Therefore, ELECTRA attention graphs will only work with compile_model=False. 
+ compile_model = kwargs.get("compile_model", True) # If batch_size is not provided, it will be set to default batch size used during training in Predictor - self.predictor: Predictor = Predictor(ckpt_path, self.batch_size) + self.predictor: Predictor = Predictor( + ckpt_path, self.batch_size, compile_model=compile_model + ) @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> list: @@ -51,4 +56,5 @@ def calculate_results(self, batch): dat = self.predictor._model._process_batch( collator(batch).to(self.predictor.device), 0 ) + return self.predictor._model(dat, **dat["model_kwargs"])