diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index d9920a877112..80d925ea0843 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -352,8 +352,8 @@ class _HubKernelConfig: AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", - wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward", - wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward", + wrapped_forward_attr="flash_attn_interface._flash_attn_varlen_forward", + wrapped_backward_attr="flash_attn_interface._flash_attn_varlen_backward", version=1, ), AttentionBackendName.SAGE_HUB: _HubKernelConfig( @@ -1325,8 +1325,8 @@ def _flash_varlen_attention_hub_forward_op( wrapped_backward_fn = config.wrapped_backward_fn if wrapped_forward_fn is None or wrapped_backward_fn is None: raise RuntimeError( - "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and " - "`_wrapped_flash_attn_varlen_backward` for context parallel execution." + "Flash attention varlen hub kernels must expose `_flash_attn_varlen_forward` and " + "`_flash_attn_varlen_backward` for context parallel execution." ) if scale is None: @@ -1419,7 +1419,7 @@ def _flash_varlen_attention_hub_backward_op( wrapped_backward_fn = config.wrapped_backward_fn if wrapped_backward_fn is None: raise RuntimeError( - "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` " + "Flash attention varlen hub kernels must expose `_flash_attn_varlen_backward` " "for context parallel execution." )