diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index a53b4a3108a..818c013bbdd 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -921,6 +921,7 @@ def forward(
         return_softmax,
         block_table,
         is_grad_enabled,
+        zero_tensors: bool = False,
     ):
         is_grad = is_grad_enabled and any(
             x.requires_grad for x in [q, k, v]
@@ -949,6 +950,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            zero_tensors=zero_tensors,
         )
         if is_grad:
             ctx.save_for_backward(
@@ -963,6 +965,7 @@ def forward(
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        ctx.zero_tensors = zero_tensors
         out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
@@ -998,11 +1001,12 @@ def backward(ctx, dout, *args):
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
+            zero_tensors=ctx.zero_tensors,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -1394,6 +1398,7 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    zero_tensors=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1468,6 +1473,7 @@
         return_attn_probs,
         block_table,
         torch.is_grad_enabled(),
+        zero_tensors,
     )