Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 57 additions & 53 deletions ip_adapter/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def forward(

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
ip_hidden_states = None
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
Expand All @@ -165,31 +166,32 @@ def forward(
hidden_states = attn.batch_to_head_dim(hidden_states)

# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

if xformers_available:
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
else:
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

hidden_states = hidden_states + self.scale * ip_hidden_states
if xformers_available:
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
else:
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask

hidden_states = hidden_states + self.scale * ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
Expand Down Expand Up @@ -368,6 +370,7 @@ def forward(

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
ip_hidden_states = None
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
Expand Down Expand Up @@ -399,37 +402,38 @@ def forward(
hidden_states = hidden_states.to(query.dtype)

# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
#print(self.attn_map.shape)

ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)

# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask

hidden_states = hidden_states + self.scale * ip_hidden_states
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
#print(self.attn_map.shape)

ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)

# region control
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
ip_hidden_states = ip_hidden_states * mask

hidden_states = hidden_states + self.scale * ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
Expand Down