-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_interface.py
More file actions
231 lines (217 loc) · 10.1 KB
/
model_interface.py
File metadata and controls
231 lines (217 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import random
import torch
import torch.nn.functional as F
class ModelInterface:
    """
    Single entry point for every query made to the wrapped classifiers.

    All queries to classifiers/models should happen via this class.
    It is a wrapper over a set of models that:
      - tracks the number of model calls made (``model_calls``)
      - implements the logic to pick a model for each query
      - implements the definition of an adversarial example under several
        noise models: 'deterministic', 'dropout', 'smoothing', 'cropping',
        'stochastic', and 'bayesian'.
    """

    def __init__(self, models, bounds=(0, 1), n_classes=None, slack=0.10, noise='deterministic',
                 new_adv_def=False, device=None, flip_prob=0.0, smoothing_noise=0., crop_size=None):
        """
        :param models: sequence of model wrappers; each exposes ``.model`` (a torch
            module) plus ``get_probs``, ``get_grads`` and ``ask_model``.
        :param bounds: (low, high) valid pixel range used for clamping.
        :param n_classes: number of classes; required by the 'stochastic' noise
            model and the deprecated ``forward`` path.
        :param slack: slack proportion used by the deprecated ``forward``.
        :param noise: one of 'deterministic', 'dropout', 'smoothing', 'cropping',
            'stochastic', 'bayesian'.
        :param new_adv_def: alternative adversarial definition used by ``forward``.
        :param device: torch device the models (and sampled noise) live on.
        :param flip_prob: probability of flipping a prediction ('stochastic').
        :param smoothing_noise: scale of the Gaussian noise ('smoothing').
        :param crop_size: side length of the random square crop ('cropping').
        """
        self.models = models
        self.bounds = bounds
        self.n_classes = n_classes
        self.model_calls = 0  # running count of individual model queries
        self.slack_prop = slack
        self.noise = noise
        self.new_adversarial_def = new_adv_def
        self.device = device
        self.send_models_to_device()
        self.flip_prob = flip_prob
        self.smoothing_noise = smoothing_noise
        self.crop_size = crop_size

    def send_models_to_device(self):
        """Move every wrapped torch module to ``self.device``."""
        for model in self.models:
            model.model = model.model.to(self.device)

    def sample_bernoulli(self, probs):
        """Draw one Bernoulli sample per entry of ``probs``.

        Counts as ``probs.numel()`` model calls, since each probability
        stands in for one (noisy) model query.
        """
        self.model_calls += probs.numel()
        return torch.bernoulli(probs)

    @staticmethod
    def _match_decision(prediction, label, targeted):
        """Return 1.0 where ``prediction`` counts as adversarial success, else 0.0.

        A targeted attack succeeds when the prediction equals the target label;
        an untargeted attack succeeds when it differs from the true label.
        """
        if targeted:
            return (prediction == label) * 1.0
        return (prediction != label) * 1.0

    def decision(self, batch, label, num_queries=1, targeted=False):
        """
        Query each image in ``batch`` ``num_queries`` times.

        :param batch: batch of images, shape (N, H, W) or (N, H, W, C).
        :param label: true label (or target label when ``targeted``) of the
            original image being attacked.
        :param num_queries: number of (possibly noisy) queries per image.
        :param targeted: whether the attack is targeted.
        :return: decisions of shape (len(batch), num_queries).
        """
        self.model_calls += batch.shape[0] * num_queries
        # Tile the whole batch num_queries times so a single _decision call
        # answers every query. Copies of the batch are stacked first, hence
        # the view(num_queries, len(batch)) + transpose below.
        if batch.ndim == 3:
            new_batch = batch.repeat(num_queries, 1, 1)
        else:
            new_batch = batch.repeat(num_queries, 1, 1, 1)
        decisions = self._decision(new_batch, label, targeted)
        return decisions.view(-1, len(batch)).transpose(0, 1)

    def _decision(self, batch, label, targeted=False):
        """
        Decide, per image, whether it is adversarial under ``self.noise``.

        :param batch: a batch of images.
        :param label: true label (or target label when ``targeted``) of the
            original image being attacked.
        :param targeted: if True, ``label`` is the targeted label.
        :return: float tensor of per-image decisions (1.0 = adversarial).
        :raises RuntimeError: for an unknown ``self.noise`` value.
        """
        # 'dropout' relies on stochasticity inside the model itself, so the
        # query path is identical to 'deterministic'.
        if self.noise in ('deterministic', 'dropout'):
            prediction = self.get_probs_(images=batch).argmax(dim=1)
            return self._match_decision(prediction, label, targeted)
        elif self.noise == 'smoothing':
            # Randomized smoothing: add Gaussian noise, clamp back into the
            # valid pixel range, then query.
            rv = torch.randn(size=batch.shape, device=self.device)
            batch_ = torch.clamp(batch + self.smoothing_noise * rv,
                                 self.bounds[0], self.bounds[1])
            prediction = self.get_probs_(images=batch_).argmax(dim=1)
            return self._match_decision(prediction, label, targeted)
        elif self.noise == 'cropping':
            # Random square crop per image, resized back to the original size.
            # NOTE(review): assumes channels-last layout, i.e. (N, H, W[, C]),
            # with square spatial dims (size taken from batch.shape[1]) - confirm.
            size = batch.shape[1]
            x_start = torch.randint(low=0, high=size + 1 - self.crop_size, size=(1, len(batch)))[0]
            x_end = x_start + self.crop_size
            y_start = torch.randint(low=0, high=size + 1 - self.crop_size, size=(1, len(batch)))[0]
            y_end = y_start + self.crop_size
            cropped = [b[x_start[i]:x_end[i], y_start[i]:y_end[i]] for i, b in enumerate(batch)]
            cropped_batch = torch.stack(cropped)
            if cropped_batch.ndim == 4:
                # (N, H, W, C) -> (N, C, H, W) for interpolate, then back.
                resized = F.interpolate(cropped_batch.permute(0, 3, 1, 2), size, mode='bilinear')
                resized = resized.permute(0, 2, 3, 1)
            else:
                resized = F.interpolate(cropped_batch.unsqueeze(dim=1), size, mode='bilinear')
                resized = resized.squeeze(dim=1)
            prediction = self.get_probs_(images=resized).argmax(dim=1)
            return self._match_decision(prediction, label, targeted)
        elif self.noise == 'stochastic':
            num_queries = 1  # TODO: this should be removed. num_queries is not supported by this function now
            probs = self.get_probs_(images=batch)
            # Draw a random *wrong* label: sample from the first n_classes-1
            # labels and remap any collision with `label` onto the last label,
            # so a flipped prediction is never the true/target label itself.
            rand_pred = torch.randint(self.n_classes - 1, size=(len(batch), num_queries), device=self.device)
            rand_pred[rand_pred == label] = self.n_classes - 1
            prediction = probs.argmax(dim=1).view(-1, 1).repeat(1, num_queries)
            indices_to_flip = torch.rand(size=(len(batch), num_queries), device=self.device) < self.flip_prob
            prediction[indices_to_flip] = rand_pred[indices_to_flip]
            return self._match_decision(prediction, label, targeted)
        elif self.noise == 'bayesian':
            # Treat the model's probability of `label` as a Bernoulli parameter.
            # NOTE(review): assumes `label` is a scalar class index - confirm
            # against callers.
            probs = self.get_probs_(images=batch)[:, label]
            if targeted:
                return torch.bernoulli(probs)
            return torch.bernoulli(1 - probs)
        else:
            raise RuntimeError(f'Unknown Noise type: {self.noise}')

    def decision_with_logits(self, batch, true_label):
        """
        Same as decision() but instead of decisions it returns probability-like
        vectors over classes. Used for white-box attacks.

        :param batch: a batch of images.
        :param true_label: kept for interface compatibility; not used here.
        :return: tensor of shape (len(batch), num_classes).
        :raises RuntimeError: for an unknown ``self.noise`` value.
        """
        probs = self.get_probs_(images=batch)
        self.model_calls += batch.shape[0]
        if self.noise == 'deterministic':
            # One-hot vector on the argmax class.
            ans = torch.zeros_like(probs)
            ans[torch.arange(len(probs)), probs.argmax(axis=1)] = 1
            return ans
        elif self.noise == 'stochastic':
            # Flip mass spread uniformly over the n_classes-1 non-argmax labels.
            ans = torch.ones_like(probs) * self.flip_prob / (self.n_classes - 1)
            ans[torch.arange(len(probs)), probs.argmax(axis=1)] = 1 - self.flip_prob
            return ans
        elif self.noise == 'bayesian':
            return probs
        else:
            raise RuntimeError(f'Unknown Noise type: {self.noise}')

    def get_probs_(self, images):
        """
        Return class probabilities for ``images`` from one randomly chosen model.

        WARNING
        This function should only be used for capturing statistics.
        It should not be a part of a decision based attack.
        """
        m_id = random.randrange(len(self.models))
        return self.models[m_id].get_probs(images)

    def get_probs(self, image):
        """
        Return class probabilities for a single ``image`` (batched internally)
        from one randomly chosen model.

        WARNING
        This function should only be used for capturing statistics.
        It should not be a part of a decision based attack.
        """
        m_id = random.randrange(len(self.models))
        return self.models[m_id].get_probs(image[None])

    def get_grads(self, images, true_label):
        """
        Return gradients w.r.t. ``images`` from one randomly chosen model.

        WARNING
        This function should only be used for capturing statistics.
        It should not be a part of a decision based attack.
        """
        m_id = random.randrange(len(self.models))
        return self.models[m_id].get_grads(images, true_label)

    def forward(self, images, a, freq, average=False, remember=True):
        """
        Deprecated legacy query path; use decision() instead.

        Queries one randomly chosen model ``freq`` times per image and decides
        adversariality from label frequencies.

        The original code decorated this method with ``@DeprecationWarning``,
        which replaced the function with a non-callable warning *instance*
        (calling it raised TypeError). A runtime warning preserves the
        deprecation intent while keeping the method callable.
        """
        import warnings
        warnings.warn('ModelInterface.forward is deprecated; use decision() instead.',
                      DeprecationWarning, stacklevel=2)
        if type(images) != torch.Tensor:
            images = torch.tensor(images).to(self.device)
        slack = self.slack_prop * freq
        batch = torch.stack(tuple(images))
        m_id = random.randrange(len(self.models))
        if self.noise == 'deterministic':
            labels = self.models[m_id].ask_model(batch)
            ans = (labels != a.true_label) * 1
            self.model_calls += len(images)
        else:
            inp_batch = batch.repeat(freq, 1, 1)
            outs = self.models[m_id].ask_model(inp_batch).reshape(freq, len(images)).T
            self.model_calls += (len(images) * freq)
            N = self.n_classes
            # Per-image histogram of the freq predicted labels: offset each
            # image's labels into its own bucket of size N, then bincount.
            flat_ids = outs + (N * torch.arange(outs.shape[0]).to(self.device))[:, None]
            freqs = torch.bincount(flat_ids.flatten(), minlength=N * outs.shape[0]).view(-1, N)
            true_freqs = freqs[:, a.true_label]
            r = list(range(self.n_classes))
            false_freqs = torch.max(freqs[:, r[:a.true_label] + r[a.true_label + 1:]], dim=1)[0]
            if self.new_adversarial_def:
                ans = (true_freqs < 0.5 * freq) * 1
            else:
                ans = (false_freqs > true_freqs + slack) * 1
        if remember:
            # Remember the closest adversarial image found so far on `a`.
            for i in range(len(images)):
                if ans[i] == 1:
                    distance = a.calculate_distance(images[i], self.bounds)
                    if a.distance > distance:
                        a.distance = distance
                        a.perturbed = images[i]
        if average and self.noise != 'deterministic':
            adv_prob = 1 - (true_freqs / freq)
            return adv_prob
        else:
            return ans