diff --git a/.gitignore b/.gitignore
index de347daa..3ba7a70b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,12 @@ __pycache__/
 tmp_examples*
 new_checkpoint*
 batch_test*
-nohup*
\ No newline at end of file
+nohup*
+wan.egg-info
+build
+Wan2.2-T2V-A14B
+2.4.0
+1.23.5
+Wan2.2-TI2V-5B
+input
+out
diff --git a/requirements.txt b/requirements.txt
index 59274655..f03c3dfc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,5 +12,5 @@ easydict
 ftfy
 dashscope
 imageio-ffmpeg
-flash_attn
+#flash_attn
 numpy>=1.23.5,<2
\ No newline at end of file
diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 4dbbe03f..315db0d9 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -1,5 +1,6 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
+import torch.nn.functional as F
 
 try:
     import flash_attn_interface
@@ -109,25 +110,48 @@ def half(x):
             causal=causal,
             deterministic=deterministic)[0].unflatten(0, (b, lq))
     else:
-        assert FLASH_ATTN_2_AVAILABLE
-        x = flash_attn.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            max_seqlen_q=lq,
-            max_seqlen_k=lk,
-            dropout_p=dropout_p,
-            softmax_scale=softmax_scale,
-            causal=causal,
-            window_size=window_size,
-            deterministic=deterministic).unflatten(0, (b, lq))
-
-    # output
-    return x.type(out_dtype)
+        # ---- SDPA fallback (Flash-Attention v2 not installed) ----
+        # The preprocessing above packed q into (sum(q_lens), h, d) and
+        # k, v into (sum(k_lens), h, d). Unpack them into padded
+        # (b, L, h, d) tensors so the dense SDPA kernel can be used.
+        # Note: window_size (sliding-window attention) is ignored here.
+        def unpack(x, lens, full_len):
+            # fast path: nothing was dropped during packing
+            if x.shape[0] == b * full_len:
+                return x.unflatten(0, (b, full_len))
+            out = x.new_zeros(b, full_len, *x.shape[1:])
+            for i, s in enumerate(torch.split(x, lens.tolist())):
+                out[i, :s.shape[0]] = s
+            return out
+
+        q_ = unpack(q, q_lens, lq).transpose(1, 2)  # (b, h, lq, d)
+        k_ = unpack(k, k_lens, lk).transpose(1, 2)  # (b, h, lk, d)
+        v_ = unpack(v, k_lens, lk).transpose(1, 2)  # (b, h, lk, d)
+
+        # Boolean mask for SDPA: True = attend, False = mask out.
+        # Only key padding is masked; masking out an entire (padded)
+        # query row would turn every logit to -inf and produce NaNs.
+        attn_mask = None
+        if (k_lens < lk).any():
+            attn_mask = torch.zeros(
+                (b, 1, lq, lk), dtype=torch.bool, device=q_.device)
+            for i in range(b):
+                attn_mask[i, 0, :, :k_lens[i]] = True
+
+        # SDPA rejects an explicit attn_mask combined with is_causal=True.
+        assert not (causal and attn_mask is not None)
+
+        x = F.scaled_dot_product_attention(
+            q_, k_, v_,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=bool(causal),
+            scale=softmax_scale,  # None -> default 1/sqrt(head_dim)
+        )  # (b, h, lq, d)
+        x = x.transpose(1, 2).contiguous()  # (b, lq, h, d)
+
+    # output
+    return x.type(out_dtype)
 
 
 def attention(
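
For reference, below is a minimal sanity check of the SDPA fallback against a naive float32 masked-softmax attention. It is not part of the patch; it assumes a CUDA device, that flash-attn is not installed (so the fallback branch is taken), that flash_attention is importable from wan.modules.attention with the (b, L, h, d) input layout used in this file, and that loose tolerances are acceptable because the computation runs in bf16. Shapes, lengths, and tolerances are illustrative only.

# Standalone sanity check for the SDPA fallback (illustrative, not part of the patch).
import torch
from wan.modules.attention import flash_attention

b, lq, lk, h, d = 2, 16, 12, 4, 64
q = torch.randn(b, lq, h, d, dtype=torch.bfloat16, device='cuda')
k = torch.randn(b, lk, h, d, dtype=torch.bfloat16, device='cuda')
v = torch.randn(b, lk, h, d, dtype=torch.bfloat16, device='cuda')
k_lens = torch.tensor([12, 7], dtype=torch.int32, device='cuda')  # second sequence is padded

out = flash_attention(q, k, v, k_lens=k_lens)  # (b, lq, h, d)

# Naive float32 reference with key-padding masking.
qf, kf, vf = (t.float().transpose(1, 2) for t in (q, k, v))   # (b, h, L, d)
logits = qf @ kf.transpose(-2, -1) / d ** 0.5                 # (b, h, lq, lk)
for i in range(b):
    logits[i, :, :, k_lens[i]:] = float('-inf')               # ignore padded keys
ref = (logits.softmax(-1) @ vf).transpose(1, 2)               # (b, lq, h, d)

# Expect True; tolerances are loose to absorb bf16 rounding.
print(torch.allclose(out.float(), ref, atol=2e-2, rtol=2e-2))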