176 changes: 83 additions & 93 deletions flashmask/flash_mask/cute/block_sparsity.py
@@ -9,7 +9,7 @@
"""

from typing import Tuple, Optional, Callable, List, NamedTuple
import paddle
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@@ -30,18 +30,18 @@ def __new_from_mlir_values__(self, values):
return BlockSparseTensors(*values)


class BlockSparseTensorsPaddle(NamedTuple):
mask_block_cnt: paddle.Tensor
mask_block_idx: paddle.Tensor
full_block_cnt: Optional[paddle.Tensor] = None
full_block_idx: Optional[paddle.Tensor] = None
class BlockSparseTensorsTorch(NamedTuple):
mask_block_cnt: torch.Tensor
mask_block_idx: torch.Tensor
full_block_cnt: Optional[torch.Tensor] = None
full_block_idx: Optional[torch.Tensor] = None


def _expand_sparsity_tensor(
tensor: paddle.Tensor,
tensor: torch.Tensor,
expected_shape: Tuple[int, ...],
tensor_name: str,
) -> paddle.Tensor:
) -> torch.Tensor:
"""Check if we need to expand the tensor to expected shape, and do so if possible."""
needs_expand = tensor.shape != expected_shape
if not needs_expand:
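For orientation, a minimal sketch (mine, not part of the diff) of how the new `BlockSparseTensorsTorch` container might be populated. The shapes follow the `(batch, heads, n_blocks_q[, n_blocks_k])` convention used throughout this file; `B`, `H`, `NQ`, `NK` are hypothetical sizes:

```python
import torch

# Hypothetical sizes: batch=2, heads=8, 16 query blocks, 16 key blocks.
B, H, NQ, NK = 2, 8, 16, 16

sparse = BlockSparseTensorsTorch(
    mask_block_cnt=torch.zeros(B, H, NQ, dtype=torch.int32, device="cuda"),
    mask_block_idx=torch.zeros(B, H, NQ, NK, dtype=torch.int32, device="cuda"),
    # full_block_cnt / full_block_idx default to None: no fully-unmasked
    # blocks are tracked, so every listed block goes through the masked path.
)
```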
@@ -56,20 +56,20 @@

def _check_and_expand_block(
name: str,
cnt: Optional[paddle.Tensor],
idx: Optional[paddle.Tensor],
cnt: Optional[torch.Tensor],
idx: Optional[torch.Tensor],
expected_count_shape: Tuple[int, int, int],
expected_index_shape: Tuple[int, int, int, int],
) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor]]:
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if (cnt is None) != (idx is None):
raise ValueError(
f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
)
if cnt is None or idx is None:
return None, None
if cnt.dtype != paddle.int32 or idx.dtype != paddle.int32:
raise ValueError(f"{name}_block tensors must have dtype paddle.int32")
if cnt.place != idx.place:
if cnt.dtype != torch.int32 or idx.dtype != torch.int32:
raise ValueError(f"{name}_block tensors must have dtype torch.int32")
if cnt.device != idx.device:
raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
if not cnt.is_cuda or not idx.is_cuda:
raise ValueError(f"{name}_block tensors must live on CUDA")
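As a quick illustration (my sketch, not from the PR), the checks above require `cnt` and `idx` to arrive as a matched pair of CUDA int32 tensors; supplying only one of them should raise:

```python
import torch

# Hypothetical misuse: a count tensor without its index companion.
cnt = torch.zeros(2, 8, 16, dtype=torch.int32, device="cuda")
try:
    _check_and_expand_block("mask", cnt, None, (2, 8, 16), (2, 8, 16, 16))
except ValueError as err:
    print(err)  # mask_block_cnt and mask_block_idx must both be provided or both be None
```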
@@ -79,11 +79,11 @@ def _check_and_expand_block(


def normalize_block_sparse_tensors(
tensors: BlockSparseTensorsPaddle,
tensors: BlockSparseTensorsTorch,
*,
expected_count_shape: Tuple[int, int, int],
expected_index_shape: Tuple[int, int, int, int],
) -> BlockSparseTensorsPaddle:
) -> BlockSparseTensorsTorch:
if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

@@ -104,45 +104,33 @@ def normalize_block_sparse_tensors(
expected_count_shape,
expected_index_shape,
)
if full_cnt is not None and mask_cnt.place != full_cnt.place:
if full_cnt is not None and mask_cnt.device != full_cnt.device:
raise ValueError("All block sparse tensors must be on the same device")

return BlockSparseTensorsPaddle(
return BlockSparseTensorsTorch(
mask_block_cnt=mask_cnt,
mask_block_idx=mask_idx,
full_block_cnt=full_cnt,
full_block_idx=full_idx,
)


def is_block_sparsity_enabled(tensors: BlockSparseTensorsPaddle) -> bool:
def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt))


def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsPaddle) -> Optional[BlockSparseTensors]:
def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> Optional[BlockSparseTensors]:
if not is_block_sparsity_enabled(tensors):
return None

mask_block_cnt_tensor = from_dlpack(
tensors.mask_block_cnt.detach(), assumed_align=4
).mark_layout_dynamic(leading_dim=2)
mask_block_idx_tensor = from_dlpack(
tensors.mask_block_idx.detach(), assumed_align=4
).mark_layout_dynamic(leading_dim=3)
full_block_cnt_tensor = (
from_dlpack(tensors.full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=2
)
if tensors.full_block_cnt is not None
else None
)
full_block_idx_tensor = (
from_dlpack(tensors.full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=3
)
if tensors.full_block_idx is not None
else None
)
def _wrap(t, dim):
if t is None:
return None
return from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=dim)

mask_block_cnt_tensor = _wrap(tensors.mask_block_cnt, 2)
mask_block_idx_tensor = _wrap(tensors.mask_block_idx, 3)

full_block_cnt_tensor = _wrap(tensors.full_block_cnt, 2)
full_block_idx_tensor = _wrap(tensors.full_block_idx, 3)

return BlockSparseTensors(
mask_block_cnt_tensor,
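The new `_wrap` helper collapses four near-identical `from_dlpack(...).mark_layout_dynamic(...)` call sites into one, with `dim` selecting the leading dimension: 2 for the 3-D count tensors, 3 for the 4-D index tensors. A hedged usage sketch, reusing the hypothetical `sparse` from earlier:

```python
# Wrap the torch-side tensors for the CuTe kernel; returns None when
# neither the mask nor the full block tensors are set.
cute_sparse = to_cute_block_sparse_tensors(sparse)
if cute_sparse is None:
    ...  # caller would take the dense (non-block-sparse) path
```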
@@ -156,14 +144,14 @@ def compute_block_sparsity(
config: Config,
mask_mod_flex: Optional[Callable],
device: str,
cu_seqlens_q: Optional[paddle.Tensor] = None,
cu_seqlens_k: Optional[paddle.Tensor] = None,
aux_tensors: Optional[List[paddle.Tensor]] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
aux_tensors: Optional[List[torch.Tensor]] = None,
) -> Tuple[
Optional[paddle.Tensor],
Optional[paddle.Tensor],
Optional[paddle.Tensor],
Optional[paddle.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""
Computes block sparsity tensors from a given masking function.
@@ -205,45 +193,45 @@


def _compute_sparsity(
config: Config, device: str, aux_tensors: Optional[List[paddle.Tensor]]
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes block sparsity for fixed-length sequences."""
n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m
n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n

# Pre-allocate output tensors
full_block_cnt = paddle.zeros(
(config.batch_size, config.nheads, n_blocks_q), dtype=paddle.int32
full_block_cnt = torch.zeros(
(config.batch_size, config.nheads, n_blocks_q), dtype=torch.int32, device=device
)
mask_block_cnt = paddle.zeros(
(config.batch_size, config.nheads, n_blocks_q), dtype=paddle.int32
mask_block_cnt = torch.zeros(
(config.batch_size, config.nheads, n_blocks_q), dtype=torch.int32, device=device
)
full_block_idx = paddle.zeros(
(config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=paddle.int32
full_block_idx = torch.zeros(
(config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=torch.int32, device=device
)
mask_block_idx = paddle.zeros(
(config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=paddle.int32
mask_block_idx = torch.zeros(
(config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=torch.int32, device=device
)

# --- Identity Mask ---
# All blocks are fully computed.
if config.mask_mod_name == "identity":
k_blocks = paddle.arange(n_blocks_k)
k_blocks = torch.arange(n_blocks_k, dtype=torch.int32, device=device)
for q_block_idx in range(n_blocks_q):
full_block_cnt[:, :, q_block_idx] = n_blocks_k
full_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks

# --- Identity Partial Mask ---
# All blocks are partially computed (masked).
elif config.mask_mod_name == "identity_partial":
k_blocks = paddle.arange(n_blocks_k)
k_blocks = torch.arange(n_blocks_k, dtype=torch.int32, device=device)
for q_block_idx in range(n_blocks_q):
mask_block_cnt[:, :, q_block_idx] = n_blocks_k
mask_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks

# --- Block Causal Mask ---
elif config.mask_mod_name == "block_causal":
k_blocks = paddle.arange(n_blocks_k)
k_blocks = torch.arange(n_blocks_k, dtype=torch.int32, device=device)
for q_block_idx in range(n_blocks_q):
causal_indices = k_blocks[k_blocks <= q_block_idx]
num_causal_indices = len(causal_indices)
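To make the block-causal branch concrete, a small worked example of my own with hypothetical sizes: query-block row `q` keeps key blocks `0..q`, so with four blocks per side the per-row counts are `[1, 2, 3, 4]`.

```python
import torch

# n_blocks_q = n_blocks_k = 4 (hypothetical): row q keeps key blocks 0..q.
k_blocks = torch.arange(4, dtype=torch.int32)
for q_block_idx in range(4):
    causal_indices = k_blocks[k_blocks <= q_block_idx]
    print(q_block_idx, len(causal_indices), causal_indices.tolist())
# 0 1 [0]
# 1 2 [0, 1]
# 2 3 [0, 1, 2]
# 3 4 [0, 1, 2, 3]
```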
@@ -253,16 +241,18 @@

# --- Causal and Sliding Window Masks ---
elif config.mask_mod_name in ["causal", "sliding_window"]:
q_block_indices = paddle.arange(n_blocks_q)
k_block_indices = paddle.arange(n_blocks_k)
q_block_indices = torch.arange(n_blocks_q, dtype=torch.int32, device=device)
k_block_indices = torch.arange(n_blocks_k, dtype=torch.int32, device=device)

q_starts = q_block_indices * config.tile_m
q_ends = paddle.minimum(
(q_block_indices + 1) * config.tile_m, paddle.to_tensor(config.seqlen_q)
q_ends = torch.minimum(
(q_block_indices + 1) * config.tile_m,
torch.tensor(config.seqlen_q, dtype=torch.int32, device=device)
)
k_starts = k_block_indices * config.tile_n
k_ends = paddle.minimum(
(k_block_indices + 1) * config.tile_n, paddle.to_tensor(config.seqlen_k)
k_ends = torch.minimum(
(k_block_indices + 1) * config.tile_n,
torch.tensor(config.seqlen_k, dtype=torch.int32, device=device)
)

# Expand dims for broadcasting: (n_blocks_q, 1) and (1, n_blocks_k)
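The hunk is folded right after the broadcast comment, so the classification step itself is not shown. A sketch of the usual interval test (my assumption, not the PR's exact code): a (q, k) tile is *full* when every key position in `[k_start, k_end)` is visible to every query position in `[q_start, q_end)`, and *partial* when it merely straddles the causal diagonal.

```python
import torch

# Hypothetical sizes; q_starts/k_ends mirror the vectors built above.
tile_m = tile_n = 128
seqlen_q = seqlen_k = 512
q_starts = torch.arange(seqlen_q // tile_m, dtype=torch.int32) * tile_m
q_ends = q_starts + tile_m
k_starts = torch.arange(seqlen_k // tile_n, dtype=torch.int32) * tile_n
k_ends = k_starts + tile_n

# Causal rule: query position q sees key position k iff k <= q.
# FULL tile: its last key column is visible even to its first query row.
full = (k_ends[None, :] - 1) <= q_starts[:, None]
# PARTIAL tile: overlaps the diagonal without being fully visible.
partial = (~full) & (k_starts[None, :] <= (q_ends[:, None] - 1))
```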
@@ -315,9 +305,9 @@ def _compute_varlen_sparsity(
config: Config,
mask_mod_flex: Callable,
device: str,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes block sparsity for variable-length sequences."""
assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen attention"
assert cu_seqlens_q.shape[0] == config.batch_size + 1
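For readers unfamiliar with the varlen convention: `cu_seqlens_q`/`cu_seqlens_k` hold cumulative sequence lengths of size `batch_size + 1` (hence the assert above), so per-sequence lengths and packed-buffer offsets fall out by differencing. A small illustration with hypothetical values:

```python
import torch

# Three packed sequences of lengths 5, 3, 7.
cu_seqlens_q = torch.tensor([0, 5, 8, 15], dtype=torch.int32)
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]         # tensor([5, 3, 7])
start, end = int(cu_seqlens_q[1]), int(cu_seqlens_q[2])  # rows of sequence 1
```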
@@ -336,19 +326,19 @@ def _compute_varlen_sparsity(
max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n

# Pre-allocate padded output tensors
full_block_cnt = paddle.zeros(
(config.batch_size, config.nheads, max_m_blocks), dtype=paddle.int32
full_block_cnt = torch.zeros(
(config.batch_size, config.nheads, max_m_blocks), dtype=torch.int32, device=device
)
mask_block_cnt = paddle.zeros(
(config.batch_size, config.nheads, max_m_blocks), dtype=paddle.int32
mask_block_cnt = torch.zeros(
(config.batch_size, config.nheads, max_m_blocks), dtype=torch.int32, device=device
)
full_block_idx = paddle.zeros(
full_block_idx = torch.zeros(
(config.batch_size, config.nheads, max_m_blocks, max_n_blocks),
dtype=paddle.int32,
dtype=torch.int32, device=device
)
mask_block_idx = paddle.zeros(
mask_block_idx = torch.zeros(
(config.batch_size, config.nheads, max_m_blocks, max_n_blocks),
dtype=paddle.int32,
dtype=torch.int32, device=device
)

# Process each sequence in the batch individually
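The per-sequence loop itself is folded out of this hunk; the `_compute_*_varlen_blocks` helpers below receive, per sequence, its local block counts and the global offset of its first key block. A rough reconstruction of the dispatch shape (heavily hedged, not the PR's exact code):

```python
# Hypothetical reconstruction of the per-sequence dispatch; the block
# counts reuse the ceil-div pattern visible elsewhere in this file.
for seq_idx in range(config.batch_size):
    q_len = int(cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx])
    k_len = int(cu_seqlens_k[seq_idx + 1] - cu_seqlens_k[seq_idx])
    n_blocks_q = (q_len + config.tile_m - 1) // config.tile_m
    n_blocks_k = (k_len + config.tile_n - 1) // config.tile_n
    # ASSUMPTION: key blocks are indexed globally across the packed buffer.
    first_n_block_global = int(cu_seqlens_k[seq_idx]) // config.tile_n
    # ...mask-specific helper fills the padded cnt/idx tensors in place...
```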
@@ -495,13 +485,13 @@ def _compute_causal_varlen_blocks(

if full_blocks:
full_block_cnt[seq_idx, :, m_local] = len(full_blocks)
full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = paddle.to_tensor(
full_blocks,
full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor(
full_blocks, dtype=torch.int32, device=device
)
if partial_blocks:
mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks)
mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = paddle.to_tensor(
partial_blocks,
mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor(
partial_blocks, dtype=torch.int32, device=device
)


@@ -555,13 +545,13 @@ def _compute_sliding_window_varlen_blocks(

if full_blocks:
full_block_cnt[seq_idx, :, m_local] = len(full_blocks)
full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = paddle.to_tensor(
full_blocks,
full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor(
full_blocks, dtype=torch.int32, device=device
)
if partial_blocks:
mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks)
mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = paddle.to_tensor(
partial_blocks,
mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor(
partial_blocks, dtype=torch.int32, device=device
)


@@ -576,8 +566,8 @@ def _compute_identity_varlen_blocks(
**kwargs,
):
"""Computes identity (all-attend) block sparsity for a single varlen sequence."""
n_blocks_global = paddle.arange(
first_n_block_global, first_n_block_global + n_blocks_k, dtype=paddle.int32
n_blocks_global = torch.arange(
first_n_block_global, first_n_block_global + n_blocks_k, dtype=torch.int32, device=device
)
for m_local in range(n_blocks_q):
full_block_cnt[seq_idx, :, m_local] = n_blocks_k
@@ -641,11 +631,11 @@ def _compute_generic_varlen_blocks(

if full_blocks:
full_block_cnt[seq_idx, h_q, m_local] = len(full_blocks)
full_block_idx[seq_idx, h_q, m_local, : len(full_blocks)] = paddle.to_tensor(
full_blocks,
full_block_idx[seq_idx, h_q, m_local, : len(full_blocks)] = torch.tensor(
full_blocks, dtype=torch.int32, device=device
)
if partial_blocks:
mask_block_cnt[seq_idx, h_q, m_local] = len(partial_blocks)
mask_block_idx[seq_idx, h_q, m_local, : len(partial_blocks)] = paddle.to_tensor(
partial_blocks,
mask_block_idx[seq_idx, h_q, m_local, : len(partial_blocks)] = torch.tensor(
partial_blocks, dtype=torch.int32, device=device
)