From 87e05f5f1bbe09c4c53ca5a5f1b1aefb7dfc1d95 Mon Sep 17 00:00:00 2001 From: Vitaliy Date: Sat, 20 Jun 2026 23:44:44 -0400 Subject: [PATCH] fix(attention_dispatch): use correct attr names for FLASH_VARLEN_HUB kernels-community/flash-attn2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `kernels-community/flash-attn2` hub kernel exposes: - `flash_attn_interface._flash_attn_varlen_forward` - `flash_attn_interface._flash_attn_varlen_backward` But `_HUB_KERNELS_REGISTRY[FLASH_VARLEN_HUB]` was referencing: - `flash_attn_interface._wrapped_flash_attn_varlen_forward` - `flash_attn_interface._wrapped_flash_attn_varlen_backward` Those `_wrapped_*` attributes do not exist in the hub kernel, causing `_flash_attention_varlen_hub` to fail with an AttributeError when trying to resolve them. Contrast with `FLASH_HUB` (non-varlen flash-attn2), which correctly uses `_wrapped_flash_attn_forward/backward` — the standard flash-attn2 library does expose those names, but the hub varlen kernel uses the unwrapped form. Also updates the RuntimeError message in `_flash_attention_varlen_hub` to match the corrected attribute names. Fixes #14012. --- src/diffusers/models/attention_dispatch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index d9920a877112..80d925ea0843 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -352,8 +352,8 @@ class _HubKernelConfig: AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", - wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward", - wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward", + wrapped_forward_attr="flash_attn_interface._flash_attn_varlen_forward", + wrapped_backward_attr="flash_attn_interface._flash_attn_varlen_backward", version=1, ), AttentionBackendName.SAGE_HUB: _HubKernelConfig( @@ -1325,8 +1325,8 @@ def _flash_varlen_attention_hub_forward_op( wrapped_backward_fn = config.wrapped_backward_fn if wrapped_forward_fn is None or wrapped_backward_fn is None: raise RuntimeError( - "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and " - "`_wrapped_flash_attn_varlen_backward` for context parallel execution." + "Flash attention varlen hub kernels must expose `_flash_attn_varlen_forward` and " + "`_flash_attn_varlen_backward` for context parallel execution." ) if scale is None: @@ -1419,7 +1419,7 @@ def _flash_varlen_attention_hub_backward_op( wrapped_backward_fn = config.wrapped_backward_fn if wrapped_backward_fn is None: raise RuntimeError( - "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` " + "Flash attention varlen hub kernels must expose `_flash_attn_varlen_backward` " "for context parallel execution." )