
Commit c1b47c0

sym-bot and claude committed
address review: use single logger and catch RuntimeError
- Move logger to module level instead of creating per-backend loggers
- Add RuntimeError to exception list alongside ImportError and OSError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 21b221e commit c1b47c0
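Both changes follow the same optional-backend probing pattern. Below is a minimal, hypothetical sketch of that pattern, not code from the diff: the backend name `some_backend`, its `attn_func`, and the flag `_CAN_USE_SOME_BACKEND` are placeholders, and the sketch uses the standard-library logger so it runs on its own, whereas the real module uses diffusers' `get_logger` helper.

```python
import logging

# One module-level logger shared by every backend probe (what this commit moves to).
logger = logging.getLogger(__name__)

_CAN_USE_SOME_BACKEND = True  # hypothetical availability flag, normally set by a version check

if _CAN_USE_SOME_BACKEND:
    try:
        # Hypothetical optional backend; this import fails unless the package is installed.
        from some_backend import attn_func as some_backend_attn_func
    except (ImportError, OSError, RuntimeError) as e:
        # ImportError: package missing or broken; OSError/RuntimeError: e.g. an extension
        # compiled against a different PyTorch version failing to load (ABI mismatch).
        logger.warning(f"some_backend failed to import: {e}. Falling back to native attention.")
        _CAN_USE_SOME_BACKEND = False
        some_backend_attn_func = None
else:
    some_backend_attn_func = None
```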

1 file changed: 18 additions & 26 deletions

src/diffusers/models/attention_dispatch.py

@@ -56,6 +56,8 @@
 _REQUIRED_XLA_VERSION = "2.2"
 _REQUIRED_XFORMERS_VERSION = "0.0.29"
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
 _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -70,11 +72,10 @@
     try:
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
-    except (ImportError, OSError) as e:
+    except (ImportError, OSError, RuntimeError) as e:
         # Handle ABI mismatch or other import failures gracefully.
         # This can happen when flash_attn was compiled against a different PyTorch version.
-        _flash_attn_logger = get_logger(__name__)
-        _flash_attn_logger.warning(
+        logger.warning(
             f"flash_attn is installed but failed to import: {e}. "
             f"Falling back to native PyTorch attention."
         )
@@ -94,9 +95,8 @@
     try:
         from flash_attn_interface import flash_attn_func as flash_attn_3_func
         from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
-    except (ImportError, OSError) as e:
-        _flash_attn_3_logger = get_logger(__name__)
-        _flash_attn_3_logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
         _CAN_USE_FLASH_ATTN_3 = False
         flash_attn_3_func = None
         flash_attn_3_varlen_func = None
@@ -107,9 +107,8 @@
 if _CAN_USE_AITER_ATTN:
     try:
         from aiter import flash_attn_func as aiter_flash_attn_func
-    except (ImportError, OSError) as e:
-        _aiter_logger = get_logger(__name__)
-        _aiter_logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
         _CAN_USE_AITER_ATTN = False
         aiter_flash_attn_func = None
 else:
@@ -125,9 +124,8 @@
             sageattn_qk_int8_pv_fp16_triton,
             sageattn_varlen,
         )
-    except (ImportError, OSError) as e:
-        _sage_logger = get_logger(__name__)
-        _sage_logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
         _CAN_USE_SAGE_ATTN = False
         sageattn = None
         sageattn_qk_int8_pv_fp8_cuda = None
@@ -150,9 +148,8 @@
         # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
         # compiled function.
         import torch.nn.attention.flex_attention as flex_attention
-    except (ImportError, OSError) as e:
-        _flex_logger = get_logger(__name__)
-        _flex_logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
         _CAN_USE_FLEX_ATTN = False
         flex_attention = None
 else:
@@ -162,9 +159,8 @@
 if _CAN_USE_NPU_ATTN:
     try:
         from torch_npu import npu_fusion_attention
-    except (ImportError, OSError) as e:
-        _npu_logger = get_logger(__name__)
-        _npu_logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
         _CAN_USE_NPU_ATTN = False
         npu_fusion_attention = None
 else:
@@ -174,9 +170,8 @@
 if _CAN_USE_XLA_ATTN:
     try:
         from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
-    except (ImportError, OSError) as e:
-        _xla_logger = get_logger(__name__)
-        _xla_logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
         _CAN_USE_XLA_ATTN = False
         xla_flash_attention = None
 else:
@@ -186,9 +181,8 @@
 if _CAN_USE_XFORMERS_ATTN:
     try:
         import xformers.ops as xops
-    except (ImportError, OSError) as e:
-        _xformers_logger = get_logger(__name__)
-        _xformers_logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
         _CAN_USE_XFORMERS_ATTN = False
         xops = None
 else:
@@ -216,8 +210,6 @@ def wrap(func):
     _register_fake = register_fake_no_op
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
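For context on the "falling back to native attention" wording repeated above: when an import fails, the corresponding `_CAN_USE_*` flag is set to False and the function alias to None, so dispatch later routes around that backend. A hedged illustration of that shape, reusing the hypothetical names from the sketch above rather than the module's actual dispatch code:

```python
import torch
import torch.nn.functional as F

def attention_with_fallback(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Illustrative only: prefer the optional backend when its import succeeded,
    # otherwise use PyTorch's built-in scaled_dot_product_attention.
    if _CAN_USE_SOME_BACKEND and some_backend_attn_func is not None:
        return some_backend_attn_func(query, key, value)
    return F.scaled_dot_product_attention(query, key, value)
```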
