-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathopen_clip_patch.py
More file actions
37 lines (30 loc) · 1.53 KB
/
open_clip_patch.py
File metadata and controls
37 lines (30 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from transformers.models.bert import modeling_bert
from open_clip import CustomTextCLIP
from open_clip.hf_model import HFTextEncoder
import torch.nn.functional as F
from torch import TensorType
def patch_encode_text():
    """Monkey-patch open_clip so text encoding can also return attention maps.

    Replaces ``CustomTextCLIP.encode_text`` and ``HFTextEncoder.forward`` with
    versions that accept ``output_attentions`` and ``output_tokens`` keyword
    flags. Both new parameters default to ``False``, so existing callers keep
    their original call signature.
    """

    def encode_text_patched(self, text, normalize: bool = False, output_attentions=False, output_tokens=False):
        """Encode ``text`` with the text tower, optionally returning attentions.

        Args:
            text: Token ids forwarded unchanged to ``self.text``.
            normalize: If True, L2-normalize the features along the last dim.
            output_attentions: If True, also return the attention scores
                produced by the text tower.
            output_tokens: If True, return per-token features instead of the
                pooled embedding.

        Returns:
            ``features``, or ``(features, attn_scores)`` when
            ``output_attentions`` is True.
        """
        # NOTE(review): assumes self.text is a patched HFTextEncoder; other text
        # towers would reject these keyword arguments.
        result = self.text(text, output_attentions=output_attentions, output_tokens=output_tokens)
        if output_attentions:
            features, attn_scores = result
        else:
            features, attn_scores = result, None
        if normalize:
            features = F.normalize(features, dim=-1)
        return (features, attn_scores) if output_attentions else features

    def HFText_encoder_patched(self, x: TensorType, output_attentions=False, output_tokens=False):
        """Run the HF transformer and project its output.

        Args:
            x: Input token ids; positions equal to ``config.pad_token_id`` are
                masked out of attention.
            output_attentions: If True, request attentions from the transformer
                and return them alongside the features.
            output_tokens: If True, project every token's last hidden state;
                otherwise project the pooled sequence embedding.

        Returns:
            ``features``, or ``(features, attentions)`` when
            ``output_attentions`` is True. This holds whether ``output_tokens``
            is set or not.
        """
        # Preserved from the original patch: the per-call flag is written back
        # onto the module, overriding whatever was set at construction time.
        self.output_tokens = output_tokens
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask, output_attentions=output_attentions)
        if output_tokens:
            features = self.proj(out.last_hidden_state)
        else:
            features = self.proj(self.pooler(out, attn_mask))
        if output_attentions:
            # Read attentions by name: positional indexing (out[1]) resolves to
            # pooler_output whenever the transformer has a pooling layer.
            # Presumably a tuple with one attention tensor per layer (per HF
            # ModelOutput docs) -- confirm for the specific backbone.
            return features, out.attentions
        return features

    CustomTextCLIP.encode_text = encode_text_patched
    HFTextEncoder.forward = HFText_encoder_patched