#!/usr/bin/env python
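"""Thin wrappers around OpenCLIP for open-vocabulary querying.

OpenCLIPNetwork embeds images and text with a ViT-B-16 CLIP model and scores
per-pixel CLIP features against user queries via a pairwise softmax with
canonical negative phrases (the relevancy formulation popularized by LERF).
OpenCLIPTextEncoder is a lightweight text-only counterpart.
"""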
import torch
import torchvision
import open_clip


class OpenCLIPNetwork:
    def __init__(self, device):
        # Resize to CLIP's 224x224 input and apply the standard CLIP
        # normalization statistics; inputs are expected as float tensors
        # already scaled to [0, 1].
        self.process = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((224, 224)),
                torchvision.transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.clip_model_type = "ViT-B-16"
        self.clip_model_pretrained = "laion2b_s34b_b88k"
        self.clip_n_dims = 512
        model, _, _ = open_clip.create_model_and_transforms(
            self.clip_model_type,
            pretrained=self.clip_model_pretrained,
            precision="fp16",
        )
        model.eval()
        self.tokenizer = open_clip.get_tokenizer(self.clip_model_type)
        self.model = model.to(device)

        # Canonical negative phrases against which every positive query is
        # compared; the positive set starts as a single placeholder phrase.
        self.negatives = ("object", "things", "stuff", "texture")
        self.positives = (" ",)
        with torch.no_grad():
            tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(device)
            self.pos_embeds = model.encode_text(tok_phrases)
            tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.negatives]).to(device)
            self.neg_embeds = model.encode_text(tok_phrases)
        # L2-normalize so the dot products below are cosine similarities.
        self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)
        self.neg_embeds /= self.neg_embeds.norm(dim=-1, keepdim=True)

    @torch.no_grad()
    def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor:
        # embed: (N, clip_n_dims) flattened image embeddings, e.g. 32768 x 512.
        # Returns (N, 2) softmax probabilities [positive, negative] for the
        # query `positive_id`, taken against its most competitive negative.
        phrases_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0)
        p = phrases_embeds.to(embed.dtype)
        output = torch.mm(embed, p.T)  # cosine similarity to every phrase
        positive_vals = output[..., positive_id : positive_id + 1]
        negative_vals = output[..., len(self.positives) :]
        repeated_pos = positive_vals.repeat(1, len(self.negatives))
        # Pair the positive score with each negative, softmax each pair at
        # temperature 10, then keep the pair where the positive probability
        # is lowest, i.e. the hardest negative.
        sims = torch.stack((repeated_pos, negative_vals), dim=-1)
        softmax = torch.softmax(10 * sims, dim=-1)
        best_id = softmax[..., 0].argmin(dim=1)
        return torch.gather(softmax, 1, best_id[..., None, None].expand(best_id.shape[0], len(self.negatives), 2))[
            :, 0, :
        ]

    def encode_image(self, input, mask=None):
        processed_input = self.process(input).half()
        # NOTE: the stock open_clip encode_image takes no `mask` argument;
        # forwarding it assumes a patched model that accepts one.
        return self.model.encode_image(processed_input, mask=mask)

    def encode_text(self, text_list, device):
        text = self.tokenizer(text_list).to(device)
        return self.model.encode_text(text)

    def set_positives(self, text_list):
        # Replace the positive queries and re-encode them on the same device
        # as the cached negative embeddings.
        self.positives = text_list
        with torch.no_grad():
            tok_phrases = torch.cat(
                [self.tokenizer(phrase) for phrase in self.positives]
            ).to(self.neg_embeds.device)
            self.pos_embeds = self.model.encode_text(tok_phrases)
        self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)

    def set_semantics(self, text_list):
        # Encode a closed label set for dense semantic prediction. Use the
        # device of the cached negatives instead of hard-coding "cuda".
        self.semantic_labels = text_list
        with torch.no_grad():
            tok_phrases = torch.cat(
                [self.tokenizer(phrase) for phrase in self.semantic_labels]
            ).to(self.neg_embeds.device)
            self.semantic_embeds = self.model.encode_text(tok_phrases)
        self.semantic_embeds /= self.semantic_embeds.norm(dim=-1, keepdim=True)

    def get_semantic_map(self, sem_map: torch.Tensor) -> torch.Tensor:
        # sem_map: (n_levels, H, W, C) per-pixel CLIP embeddings.
        # Returns (n_levels, H, W) label indices; pixels whose best match is
        # one of the canonical negatives are set to -1.
        n_levels, h, w, c = sem_map.shape
        pos_num = self.semantic_embeds.shape[0]
        phrases_embeds = torch.cat([self.semantic_embeds, self.neg_embeds], dim=0)
        p = phrases_embeds.to(sem_map.dtype)
        sem_pred = torch.zeros(n_levels, h, w)
        for i in range(n_levels):
            output = torch.mm(sem_map[i].view(-1, c), p.T)
            softmax = torch.softmax(10 * output, dim=-1)
            sem_pred[i] = torch.argmax(softmax, dim=-1).view(h, w)
            sem_pred[i][sem_pred[i] >= pos_num] = -1
        return sem_pred.long()

    def get_max_across(self, sem_map):
        # sem_map: (n_levels, H, W, C). Returns a relevancy map of shape
        # (n_levels, n_phrases, H, W) holding the positive probability of
        # each query phrase at each pixel and level.
        n_phrases = len(self.positives)
        n_phrases_sims = [None for _ in range(n_phrases)]
        n_levels, h, w, _ = sem_map.shape
        clip_output = sem_map.permute(1, 2, 0, 3).flatten(0, 1)  # (H*W, n_levels, C)
        n_levels_sims = [None for _ in range(n_levels)]
        for i in range(n_levels):
            for j in range(n_phrases):
                probs = self.get_relevancy(clip_output[..., i, :], j)
                pos_prob = probs[..., 0:1]
                n_phrases_sims[j] = pos_prob
            n_levels_sims[i] = torch.stack(n_phrases_sims)
        relev_map = torch.stack(n_levels_sims).view(n_levels, n_phrases, h, w)
        return relev_map


class OpenCLIPTextEncoder:
    # Text-only convenience wrapper around the same ViT-B-16 checkpoint.
    def __init__(self, device="cuda"):
        self.clip_model_type = "ViT-B-16"
        self.clip_model_pretrained = "laion2b_s34b_b88k"
        model, _, _ = open_clip.create_model_and_transforms(
            self.clip_model_type,
            pretrained=self.clip_model_pretrained,
            precision="fp16",
        )
        model.eval()
        self.model = model.to(device)
        self.tokenizer = open_clip.get_tokenizer(self.clip_model_type)
        self.device = device

    @torch.no_grad()
    def encode(self, texts) -> torch.Tensor:
        # Returns L2-normalized fp16 text embeddings, shape (len(texts), 512).
        tokens = self.tokenizer(texts).to(self.device)
        embeds = self.model.encode_text(tokens)
        return embeds / embeds.norm(dim=-1, keepdim=True)
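

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumptions: a CUDA
# device is available and the laion2b_s34b_b88k weights download on first use.
# The random tensor below merely stands in for a real flattened map of
# per-pixel CLIP embeddings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda"
    clip_net = OpenCLIPNetwork(device)
    clip_net.set_positives(["flower", "grass"])

    # Fake (N, 512) half-precision, L2-normalized embeddings.
    embeds = torch.randn(4096, clip_net.clip_n_dims, device=device).half()
    embeds = embeds / embeds.norm(dim=-1, keepdim=True)

    relevancy = clip_net.get_relevancy(embeds, positive_id=0)
    print(relevancy.shape)  # torch.Size([4096, 2]): [positive, hardest negative]

    text_encoder = OpenCLIPTextEncoder(device)
    print(text_encoder.encode(["a photo of a flower"]).shape)  # (1, 512)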