Skip to content

Commit 3006ab5

Browse files
committed
Implement FA4 varlen based on Paddle.
1 parent df46cc8 commit 3006ab5

23 files changed

Lines changed: 5086 additions & 1197 deletions

flashmask/flash_mask/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@
6969
except ImportError:
7070
pass # cute module not installed or dependencies missing
7171

72+
# ============================================================
73+
# FA4 varlen / standard interface (framework-routed)
74+
# ============================================================
75+
try:
76+
from .interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine
77+
__all__ += ["flash_attn_func", "flash_attn_varlen_func", "flash_attn_combine"]
78+
except ImportError:
79+
pass
80+
7281
if not _fa3_available and not _fa4_available:
7382
print("[WARNING] flash_mask: neither FA3 nor FA4 is available. "
7483
"Check your installation.")

flashmask/flash_mask/flash_attn_v4/__init__.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,21 @@
99

1010
import cutlass.cute as cute
1111

12-
from .interface import (
13-
flash_attn_func,
14-
flash_attn_varlen_func,
15-
)
16-
17-
from flash_mask.flash_attn_v4.cute_dsl_utils import cute_compile_patched
12+
# Auto-detect framework: prefer torch, fall back to paddle.
13+
try:
14+
import torch # noqa: F401
15+
16+
from flash_mask.flash_attn_v4.torch.interface import (
17+
flash_attn_func,
18+
flash_attn_varlen_func,
19+
)
20+
from flash_mask.flash_attn_v4.torch.cute_dsl_utils import cute_compile_patched
21+
except ImportError:
22+
from flash_mask.flash_attn_v4.paddle.interface import (
23+
flash_attn_func,
24+
flash_attn_varlen_func,
25+
)
26+
from flash_mask.flash_attn_v4.paddle.cute_dsl_utils import cute_compile_patched
1827

1928
# Patch cute.compile to optionally dump SASS
2029
cute.compile = cute_compile_patched

0 commit comments

Comments (0)