Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,7 @@ def forward(
return_softmax,
block_table,
is_grad_enabled,
zero_tensors: bool = False,
):
is_grad = is_grad_enabled and any(
x.requires_grad for x in [q, k, v]
Expand Down Expand Up @@ -949,6 +950,7 @@ def forward(
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
zero_tensors = zero_tensors,
)
if is_grad:
ctx.save_for_backward(
Expand All @@ -963,6 +965,7 @@ def forward(
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
ctx.zero_tensors = zero_tensors

out = out_padded[..., :head_size_og]
return out if not return_softmax else (out, softmax_lse, S_dmask)
Expand Down Expand Up @@ -998,11 +1001,12 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
zero_tensors = ctx.zero_tensors
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
Expand Down Expand Up @@ -1394,6 +1398,7 @@ def flash_attn_varlen_func(
deterministic=False,
return_attn_probs=False,
block_table=None,
zero_tensors = False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
Expand Down Expand Up @@ -1468,6 +1473,7 @@ def flash_attn_varlen_func(
return_attn_probs,
block_table,
torch.is_grad_enabled(),
zero_tensors,
)


Expand Down
Loading