diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 94e9de69..36430773 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -203,6 +203,7 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any ) * CLS_TOKEN ) + model_kwargs["output_attentions"] = True return dict( features=torch.cat((cls_tokens, batch.x), dim=1), labels=batch.y, diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 60548c5e..c80cdc34 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -58,9 +58,18 @@ def __init__( self._model_hparams = ckpt_file["hyper_parameters"] self._model_hparams.pop("_instantiator", None) self._model_hparams.pop("classes_txt_file_path", None) - self._model = ChebaiBaseNet.load_from_checkpoint( - checkpoint_path, map_location=self.device - ) + try: + self._model = ChebaiBaseNet.load_from_checkpoint( + checkpoint_path, + map_location=self.device, + ) + except Exception: + # models trained from a pretrained checkpoint store an extra `pretrained_checkpoint` path argument; retry loading with it explicitly set to None + self._model = ChebaiBaseNet.load_from_checkpoint( + checkpoint_path, + map_location=self.device, + pretrained_checkpoint=None, + ) assert ( isinstance(self._model, ChebaiBaseNet) and type(self._model) is not ChebaiBaseNet