From bac01ea8c8bda926d76097b4d98e87f604e8defd Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 1 Feb 2026 12:54:48 +0530 Subject: [PATCH] Fix: Apply pre_attn_layout in torch/vanilla attention paths --- wan/modules/animate/face_blocks.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/wan/modules/animate/face_blocks.py b/wan/modules/animate/face_blocks.py index 69c04150..02b9e32e 100644 --- a/wan/modules/animate/face_blocks.py +++ b/wan/modules/animate/face_blocks.py @@ -65,6 +65,10 @@ def attention( pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] if mode == "torch": + # Apply the layout transformation + q = pre_attn_layout(q) + k = pre_attn_layout(k) + v = pre_attn_layout(v) if attn_mask is not None and attn_mask.dtype != torch.bool: attn_mask = attn_mask.to(q.dtype) x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) @@ -77,6 +81,10 @@ def attention( ) x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] elif mode == "vanilla": + # Apply the layout transformation + q = pre_attn_layout(q) + k = pre_attn_layout(k) + v = pre_attn_layout(v) scale_factor = 1 / math.sqrt(q.size(-1)) b, a, s, _ = q.shape