Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion chebifier/prediction_models/electra_predictor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np

from .nn_predictor import NNPredictor
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self, model_name: str, ckpt_path: str, **kwargs):
f"Initialised Electra model {self.model_name} (device: {self.predictor.device})"
)

def explain_smiles(self, smiles) -> dict:
def explain_smiles(self, smiles) -> Optional[dict]:
from chebai.preprocessing.reader import EMBEDDING_OFFSET

# Add dummy labels because the collate function requires them.
Expand Down Expand Up @@ -69,4 +71,6 @@ def explain_smiles(self, smiles) -> dict:
]
for a in result["attentions"]
]
if len(graphs) == 0:
return None
return {"graphs": graphs}
8 changes: 7 additions & 1 deletion chebifier/prediction_models/nn_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ def __init__(
):
super().__init__(model_name, **kwargs)
self.batch_size = kwargs.get("batch_size", None)
# compile_model will run the model in eager mode, which gives better performance, but does not return intermediate states
# such as attention weights. Therefore, ELECTRA attention graphs will only work with compile_model=False.
compile_model = kwargs.get("compile_model", True)
# If batch_size is not provided, it will be set to the default batch size used during training in Predictor
self.predictor: Predictor = Predictor(ckpt_path, self.batch_size)
self.predictor: Predictor = Predictor(
ckpt_path, self.batch_size, compile_model=compile_model
)

@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
Expand Down Expand Up @@ -51,4 +56,5 @@ def calculate_results(self, batch):
dat = self.predictor._model._process_batch(
collator(batch).to(self.predictor.device), 0
)

return self.predictor._model(dat, **dat["model_kwargs"])