From 544625277d2686cb0ae128437145ec1d37924773 Mon Sep 17 00:00:00 2001 From: Hazem Date: Thu, 18 Sep 2025 21:15:33 +0300 Subject: [PATCH] handled the case for T5 tokenizers --- evo_prot_grad/common/sampler.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/evo_prot_grad/common/sampler.py b/evo_prot_grad/common/sampler.py index 47d2f0b..32b938a 100644 --- a/evo_prot_grad/common/sampler.py +++ b/evo_prot_grad/common/sampler.py @@ -136,7 +136,6 @@ def _product_of_experts(self, inputs: List[str]) -> Tuple[List[torch.Tensor], to scores += [expert.temperature * score] # sum scores over experts return ohs, torch.stack(scores, dim=0).sum(dim=0) - def _compute_gradients(self, ohs: List[torch.Tensor], PoE: torch.Tensor) -> torch.Tensor: """Compute the gradients of the product of experts @@ -160,8 +159,17 @@ def _compute_gradients(self, ohs: List[torch.Tensor], PoE: torch.Tensor) -> torc # This checks whether the gradient sequence length # is exactly two more than the input sequence length. if oh_grad.shape[1] == self.chains_oh.shape[1] + 2: - oh_grad = oh_grad[:,1:-1] - summed_grads += [ oh_grad @ expert.expert_to_canonical_order ] + oh_grad = oh_grad[:, 1:-1] + + # some tokenizers add an EOS token to the protein + # sequence like ProtT5 and Ankh, check and remove here + # if necessary. + # This checks whether the gradient sequence length + # is exactly one more than the input sequence length. + elif oh_grad.shape[1] == self.chains_oh.shape[1] + 1: + oh_grad = oh_grad[:, :-1] + + summed_grads += [oh_grad @ expert.expert_to_canonical_order] # sum over experts return torch.stack(summed_grads, dim=0).sum(dim=0)