Hi, thanks for your work. I'm looking at your implementation,
MMA-Diffusion/src/textual_attack.py, line 81 (commit 563091a),
and maybe I'm missing something, but when you do
one_hot.scatter_(
1,
input_ids[0][:control_length].unsqueeze(1),
torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embed_weights.dtype)
)
you also scatter onto the CLS (<|startoftext|>) token at position 0, so the gradient is then computed over it as well.
If I understand correctly, you should do something like
one_hot.scatter_(1, input_ids[0][1:control_length+1].unsqueeze(1), ....)
and then
full_embeds = torch.cat([embeds[:, 0:1], input_embeds, embeds[:, control_length+1:]], dim=1).
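To make the suggestion concrete, here is a minimal self-contained sketch of the proposed indexing (toy shapes and hypothetical values; it assumes a CLIP-style tokenizer where index 0 is <|startoftext|>, which is not part of your repo):

```python
import torch

vocab_size, embed_dim, control_length = 10, 4, 3
embed_weights = torch.randn(vocab_size, embed_dim)

# token ids: [CLS, t1, t2, t3, EOS, pad, pad]; the adversarial
# control tokens start at index 1, after <|startoftext|> at index 0
input_ids = torch.tensor([[0, 5, 6, 7, 2, 2, 2]])

# one-hot over the control tokens only: slice [1:control_length+1]
# skips the CLS token at position 0
one_hot = torch.zeros(control_length, vocab_size, dtype=embed_weights.dtype)
one_hot.scatter_(
    1,
    input_ids[0][1:control_length + 1].unsqueeze(1),
    torch.ones(one_hot.shape[0], 1, dtype=embed_weights.dtype),
)
one_hot.requires_grad_()

# differentiable embeddings for the control tokens
input_embeds = (one_hot @ embed_weights).unsqueeze(0)
embeds = embed_weights[input_ids]  # full (non-differentiable) embeddings

# keep the CLS embedding, splice in the differentiable control
# embeddings, then keep the rest of the sequence unchanged
full_embeds = torch.cat(
    [embeds[:, 0:1], input_embeds, embeds[:, control_length + 1:]],
    dim=1,
)
print(full_embeds.shape)  # torch.Size([1, 7, 4])
```

With this slicing, `one_hot` never touches column 0 for the CLS id, and the CLS embedding is passed through untouched while gradients flow only through the control-token positions.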
What do you think about this?