From 3fb84420d6c91372e4904f1ff87ea04de03c4f6d Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 24 Feb 2026 12:10:42 +0100 Subject: [PATCH] fix checkpoint loading for electra, return attentions --- chebai/models/electra.py | 1 + chebai/result/prediction.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 94e9de69..36430773 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -203,6 +203,7 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any ) * CLS_TOKEN ) + model_kwargs["output_attentions"] = True return dict( features=torch.cat((cls_tokens, batch.x), dim=1), labels=batch.y, diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 60548c5e..c80cdc34 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -58,9 +58,18 @@ def __init__( self._model_hparams = ckpt_file["hyper_parameters"] self._model_hparams.pop("_instantiator", None) self._model_hparams.pop("classes_txt_file_path", None) - self._model = ChebaiBaseNet.load_from_checkpoint( - checkpoint_path, map_location=self.device - ) + try: + self._model = ChebaiBaseNet.load_from_checkpoint( + checkpoint_path, + map_location=self.device, + ) + except Exception: + # models trained on a pretrained checkpoint have an additional path argument that we need to set to None + self._model = ChebaiBaseNet.load_from_checkpoint( + checkpoint_path, + map_location=self.device, + pretrained_checkpoint=None, + ) assert ( isinstance(self._model, ChebaiBaseNet) and type(self._model) is not ChebaiBaseNet