|
56 | 56 | _REQUIRED_XLA_VERSION = "2.2" |
57 | 57 | _REQUIRED_XFORMERS_VERSION = "0.0.29" |
58 | 58 |
|
| 59 | +logger = get_logger(__name__) # pylint: disable=invalid-name |
| 60 | + |
59 | 61 | _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) |
60 | 62 | _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() |
61 | 63 | _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION) |
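The hunk above introduces a single module-level `logger`, replacing the per-backend logger instances (`_flash_attn_logger`, `_aiter_logger`, ...) that were previously created inside each `except` block further down. A minimal sketch of the resulting pattern, using the standard `logging` module as a stand-in for the project's `get_logger` helper (the helper function below is illustrative, not part of the actual file):

```python
import logging

# One logger per module, created once at import time and reused by every
# guarded backend import below (stand-in for get_logger(__name__)).
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def _warn_fallback(backend: str, err: Exception) -> None:
    # Hypothetical helper: every failed optional import funnels through the
    # same shared logger instead of constructing its own instance.
    logger.warning(f"{backend} failed to import: {err}. Falling back to native attention.")
```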
|
70 | 72 | try: |
71 | 73 | from flash_attn import flash_attn_func, flash_attn_varlen_func |
72 | 74 | from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward |
73 | | - except (ImportError, OSError) as e: |
| 75 | + except (ImportError, OSError, RuntimeError) as e: |
74 | 76 | # Handle ABI mismatch or other import failures gracefully. |
75 | 77 | # This can happen when flash_attn was compiled against a different PyTorch version. |
76 | | - _flash_attn_logger = get_logger(__name__) |
77 | | - _flash_attn_logger.warning( |
| 78 | + logger.warning( |
78 | 79 | f"flash_attn is installed but failed to import: {e}. " |
79 | 80 | f"Falling back to native PyTorch attention." |
80 | 81 | ) |
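Every optional backend in this file follows the same guarded-import pattern, and the diff widens each `except` clause to also catch `RuntimeError`, since compiled extensions such as flash_attn can fail while loading their C++/CUDA bindings rather than at the pure-Python import step. A self-contained sketch of the pattern, assuming a hypothetical `some_fast_backend` package (not one of the real backends in this module):

```python
import importlib.util
import logging

logger = logging.getLogger(__name__)

# Availability check in the spirit of the is_*_available() helpers:
# the backend is considered usable only if the package can be found.
_CAN_USE_FAST_BACKEND = importlib.util.find_spec("some_fast_backend") is not None

if _CAN_USE_FAST_BACKEND:
    try:
        # Hypothetical import; real backends bind compiled kernels here.
        from some_fast_backend import fast_attn_func
    except (ImportError, OSError, RuntimeError) as e:
        # RuntimeError covers ABI/version mismatches raised while the
        # extension loads its compiled bindings, not just missing modules.
        logger.warning(f"some_fast_backend failed to import: {e}. Falling back to native attention.")
        _CAN_USE_FAST_BACKEND = False
        fast_attn_func = None
else:
    fast_attn_func = None
```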
|
94 | 95 | try: |
95 | 96 | from flash_attn_interface import flash_attn_func as flash_attn_3_func |
96 | 97 | from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func |
97 | | - except (ImportError, OSError) as e: |
98 | | - _flash_attn_3_logger = get_logger(__name__) |
99 | | - _flash_attn_3_logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.") |
| 98 | + except (ImportError, OSError, RuntimeError) as e: |
| 99 | + logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.") |
100 | 100 | _CAN_USE_FLASH_ATTN_3 = False |
101 | 101 | flash_attn_3_func = None |
102 | 102 | flash_attn_3_varlen_func = None |
|
107 | 107 | if _CAN_USE_AITER_ATTN: |
108 | 108 | try: |
109 | 109 | from aiter import flash_attn_func as aiter_flash_attn_func |
110 | | - except (ImportError, OSError) as e: |
111 | | - _aiter_logger = get_logger(__name__) |
112 | | - _aiter_logger.warning(f"aiter failed to import: {e}. Falling back to native attention.") |
| 110 | + except (ImportError, OSError, RuntimeError) as e: |
| 111 | + logger.warning(f"aiter failed to import: {e}. Falling back to native attention.") |
113 | 112 | _CAN_USE_AITER_ATTN = False |
114 | 113 | aiter_flash_attn_func = None |
115 | 114 | else: |
|
125 | 124 | sageattn_qk_int8_pv_fp16_triton, |
126 | 125 | sageattn_varlen, |
127 | 126 | ) |
128 | | - except (ImportError, OSError) as e: |
129 | | - _sage_logger = get_logger(__name__) |
130 | | - _sage_logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.") |
| 127 | + except (ImportError, OSError, RuntimeError) as e: |
| 128 | + logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.") |
131 | 129 | _CAN_USE_SAGE_ATTN = False |
132 | 130 | sageattn = None |
133 | 131 | sageattn_qk_int8_pv_fp8_cuda = None |
|
150 | 148 | # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the |
151 | 149 | # compiled function. |
152 | 150 | import torch.nn.attention.flex_attention as flex_attention |
153 | | - except (ImportError, OSError) as e: |
154 | | - _flex_logger = get_logger(__name__) |
155 | | - _flex_logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.") |
| 151 | + except (ImportError, OSError, RuntimeError) as e: |
| 152 | + logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.") |
156 | 153 | _CAN_USE_FLEX_ATTN = False |
157 | 154 | flex_attention = None |
158 | 155 | else: |
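The comment preserved in the hunk above is about late binding: importing the module (rather than the function) means later lookups of `flex_attention.flex_attention` see whatever object currently lives on the module, including a `torch.compile`-wrapped version installed afterwards. A small illustration of the difference, assuming a PyTorch build that ships `torch.nn.attention.flex_attention` (2.5+); this is a sketch of the binding behaviour, not code from the PR:

```python
import torch
import torch.nn.attention.flex_attention as flex_attention_module

# Early binding: this name is fixed to the function object as it exists now.
from torch.nn.attention.flex_attention import flex_attention as flex_attention_fn

# If the user later swaps in a compiled version on the module ...
flex_attention_module.flex_attention = torch.compile(flex_attention_module.flex_attention)

# ... only lookups that go through the module attribute see the wrapper;
# the directly imported name still points at the original function.
assert flex_attention_module.flex_attention is not flex_attention_fn
```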
|
162 | 159 | if _CAN_USE_NPU_ATTN: |
163 | 160 | try: |
164 | 161 | from torch_npu import npu_fusion_attention |
165 | | - except (ImportError, OSError) as e: |
166 | | - _npu_logger = get_logger(__name__) |
167 | | - _npu_logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.") |
| 162 | + except (ImportError, OSError, RuntimeError) as e: |
| 163 | + logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.") |
168 | 164 | _CAN_USE_NPU_ATTN = False |
169 | 165 | npu_fusion_attention = None |
170 | 166 | else: |
|
174 | 170 | if _CAN_USE_XLA_ATTN: |
175 | 171 | try: |
176 | 172 | from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention |
177 | | - except (ImportError, OSError) as e: |
178 | | - _xla_logger = get_logger(__name__) |
179 | | - _xla_logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.") |
| 173 | + except (ImportError, OSError, RuntimeError) as e: |
| 174 | + logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.") |
180 | 175 | _CAN_USE_XLA_ATTN = False |
181 | 176 | xla_flash_attention = None |
182 | 177 | else: |
|
186 | 181 | if _CAN_USE_XFORMERS_ATTN: |
187 | 182 | try: |
188 | 183 | import xformers.ops as xops |
189 | | - except (ImportError, OSError) as e: |
190 | | - _xformers_logger = get_logger(__name__) |
191 | | - _xformers_logger.warning(f"xformers failed to import: {e}. Falling back to native attention.") |
| 184 | + except (ImportError, OSError, RuntimeError) as e: |
| 185 | + logger.warning(f"xformers failed to import: {e}. Falling back to native attention.") |
192 | 186 | _CAN_USE_XFORMERS_ATTN = False |
193 | 187 | xops = None |
194 | 188 | else: |
@@ -216,8 +210,6 @@ def wrap(func): |
216 | 210 | _register_fake = register_fake_no_op |
217 | 211 |
|
218 | 212 |
|
219 | | -logger = get_logger(__name__) # pylint: disable=invalid-name |
220 | | - |
221 | 213 | # TODO(aryan): Add support for the following: |
222 | 214 | # - Sage Attention++ |
223 | 215 | # - block sparse, radial and other attention methods |
|