diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 837d573d8c4d..9443dc4440c3 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1523,15 +1523,18 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md - if ( - attn_mask is not None - and attn_mask.ndim == 2 - and attn_mask.shape[0] == query.shape[0] - and attn_mask.shape[1] == key.shape[1] - ): - B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] + if attn_mask is not None: + if ( + attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] + attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() + elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1): + attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1) + attn_mask = ~attn_mask.to(torch.bool) - attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() return attn_mask diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 4bf00f749f25..57f97a808a7c 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -127,8 +127,14 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso query, key = query.to(dtype), key.to(dtype) # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] - if attention_mask is not None and attention_mask.ndim == 2: - attention_mask = attention_mask[:, None, None, :] + if attention_mask is not None: + if attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + if attention_mask.ndim == 4: + # NPU does not support automatic broadcasting for this type; the mask must be expanded. + if attention_mask.device.type == 'npu' and attention_mask.shape[1:3] == (1, 1): + attention_mask = attention_mask.expand(-1, attn.heads, query.shape[1], -1) # Compute joint attention hidden_states = dispatch_attention_fn(