diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py index 125267cc..6d473eaf 100755 --- a/lightx2v/common/ops/attn/ulysses_attn.py +++ b/lightx2v/common/ops/attn/ulysses_attn.py @@ -26,6 +26,7 @@ def apply( attention_type="flash_attn2", seq_p_group=None, use_fp8_comm=False, + use_tensor_fusion=False, enable_head_parallel=False, img_first=True, **kwargs, @@ -44,6 +45,8 @@ def apply( 返回: torch.Tensor: 计算得到的注意力结果 """ + use_qkv_fusion = use_tensor_fusion + if len(q.shape) == 4: q = q.reshape(-1, q.shape[-2], q.shape[-1]) k = k.reshape(-1, k.shape[-2], k.shape[-1]) @@ -94,48 +97,125 @@ def apply( txt_q, txt_k, txt_v = q[:txt_qkv_len, :, :].contiguous(), k[:txt_qkv_len, :, :].contiguous(), v[:txt_qkv_len, :, :].contiguous() img_q, img_k, img_v = q[txt_qkv_len:, :, :].contiguous(), k[txt_qkv_len:, :, :].contiguous(), v[txt_qkv_len:, :, :].contiguous() - img_qkv = torch.stack([img_q, img_k, img_v], dim=0).reshape(3, img_qkv_len, world_size, shard_heads, hidden_dims) - original_dtype = img_qkv.dtype + if use_qkv_fusion: + img_qkv = torch.stack([img_q, img_k, img_v], dim=0).reshape(3, img_qkv_len, world_size, shard_heads, hidden_dims) + original_dtype = img_qkv.dtype + else: + img_q = img_q.reshape(img_qkv_len, world_size, shard_heads, hidden_dims) + img_k = img_k.reshape(img_qkv_len, world_size, shard_heads, hidden_dims) + img_v = img_v.reshape(img_qkv_len, world_size, shard_heads, hidden_dims) + original_dtype = img_q.dtype if enable_head_parallel: - img_qkv = img_qkv.permute(3, 2, 1, 0, 4).contiguous() # (shard_heads, world_size, img_qkv_len, 3, hidden_dims) - output_qkv = torch.empty_like(img_qkv) + if use_qkv_fusion: + img_qkv = img_qkv.permute(3, 2, 1, 0, 4).contiguous() # (shard_heads, world_size, img_qkv_len, 3, hidden_dims) + output_qkv = torch.empty_like(img_qkv) + else: + img_q = img_q.permute(2, 1, 0, 3).contiguous() # (shard_heads, world_size, img_qkv_len, hidden_dims) + img_k = img_k.permute(2, 1, 0, 3).contiguous() + img_v = img_v.permute(2, 1, 0, 3).contiguous() + output_q = torch.empty_like(img_q) + output_k = torch.empty_like(img_k) + output_v = torch.empty_like(img_v) # 通信图像的查询、键和值 if use_fp8_comm: - img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims)) - img_qkv_fp8 = img_qkv_fp8.reshape(shard_heads, world_size, img_qkv_len, 3, hidden_dims) - img_qkv_scale = img_qkv_scale.reshape(shard_heads, world_size, img_qkv_len, 3, 1) - output_qkv_fp8 = torch.empty_like(img_qkv_fp8) - output_qkv_scale = torch.empty_like(img_qkv_scale) - comm_fp8_works = [] - comm_scale_works = [] - for h in range(shard_heads): - work_fp8 = dist.all_to_all_single(output_qkv_fp8[h], img_qkv_fp8[h], group=seq_p_group, async_op=True) - work_scale = dist.all_to_all_single(output_qkv_scale[h], img_qkv_scale[h], group=seq_p_group, async_op=True) - comm_fp8_works.append(work_fp8) - comm_scale_works.append(work_scale) + if use_qkv_fusion: + img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims)) + img_qkv_fp8 = img_qkv_fp8.reshape(shard_heads, world_size, img_qkv_len, 3, hidden_dims) + img_qkv_scale = img_qkv_scale.reshape(shard_heads, world_size, img_qkv_len, 3, 1) + output_qkv_fp8 = torch.empty_like(img_qkv_fp8) + output_qkv_scale = torch.empty_like(img_qkv_scale) + comm_fp8_works = [] + comm_scale_works = [] + for h in range(shard_heads): + work_fp8 = dist.all_to_all_single(output_qkv_fp8[h], img_qkv_fp8[h], group=seq_p_group, async_op=True) + work_scale = dist.all_to_all_single(output_qkv_scale[h], img_qkv_scale[h], group=seq_p_group, async_op=True) + comm_fp8_works.append(work_fp8) + comm_scale_works.append(work_scale) + else: + img_q_fp8, img_q_scale = quant_fp8_vllm(img_q.reshape(-1, hidden_dims)) + img_k_fp8, img_k_scale = quant_fp8_vllm(img_k.reshape(-1, hidden_dims)) + img_v_fp8, img_v_scale = quant_fp8_vllm(img_v.reshape(-1, hidden_dims)) + img_q_fp8 = img_q_fp8.reshape(shard_heads, world_size, img_qkv_len, hidden_dims) + img_k_fp8 = img_k_fp8.reshape(shard_heads, world_size, img_qkv_len, hidden_dims) + img_v_fp8 = img_v_fp8.reshape(shard_heads, world_size, img_qkv_len, hidden_dims) + img_q_scale = img_q_scale.reshape(shard_heads, world_size, img_qkv_len, 1) + img_k_scale = img_k_scale.reshape(shard_heads, world_size, img_qkv_len, 1) + img_v_scale = img_v_scale.reshape(shard_heads, world_size, img_qkv_len, 1) + output_q_fp8 = torch.empty_like(img_q_fp8) + output_k_fp8 = torch.empty_like(img_k_fp8) + output_v_fp8 = torch.empty_like(img_v_fp8) + output_q_scale = torch.empty_like(img_q_scale) + output_k_scale = torch.empty_like(img_k_scale) + output_v_scale = torch.empty_like(img_v_scale) + comm_fp8_works = [] + comm_scale_works = [] + for h in range(shard_heads): + work_q_fp8 = dist.all_to_all_single(output_q_fp8[h], img_q_fp8[h], group=seq_p_group, async_op=True) + work_k_fp8 = dist.all_to_all_single(output_k_fp8[h], img_k_fp8[h], group=seq_p_group, async_op=True) + work_v_fp8 = dist.all_to_all_single(output_v_fp8[h], img_v_fp8[h], group=seq_p_group, async_op=True) + work_q_scale = dist.all_to_all_single(output_q_scale[h], img_q_scale[h], group=seq_p_group, async_op=True) + work_k_scale = dist.all_to_all_single(output_k_scale[h], img_k_scale[h], group=seq_p_group, async_op=True) + work_v_scale = dist.all_to_all_single(output_v_scale[h], img_v_scale[h], group=seq_p_group, async_op=True) + comm_fp8_works.append(work_q_fp8) + comm_fp8_works.append(work_k_fp8) + comm_fp8_works.append(work_v_fp8) + comm_scale_works.append(work_q_scale) + comm_scale_works.append(work_k_scale) + comm_scale_works.append(work_v_scale) else: - comm_works = [] - for h in range(shard_heads): - work = dist.all_to_all_single(output_qkv[h], img_qkv[h], group=seq_p_group, async_op=True) - comm_works.append(work) + if use_qkv_fusion: + comm_works = [] + for h in range(shard_heads): + work = dist.all_to_all_single(output_qkv[h], img_qkv[h], group=seq_p_group, async_op=True) + comm_works.append(work) + else: + comm_works = [] + for h in range(shard_heads): + work_q = dist.all_to_all_single(output_q[h], img_q[h], group=seq_p_group, async_op=True) + work_k = dist.all_to_all_single(output_k[h], img_k[h], group=seq_p_group, async_op=True) + work_v = dist.all_to_all_single(output_v[h], img_v[h], group=seq_p_group, async_op=True) + comm_works.append(work_q) + comm_works.append(work_k) + comm_works.append(work_v) # 逐个head完成Attention计算 single_head = 1 head_attns = [] for h in range(shard_heads): if use_fp8_comm: - comm_fp8_works[h].wait() - comm_scale_works[h].wait() - output_qkv[h] = dequant_fp8_vllm(output_qkv_fp8[h], output_qkv_scale[h], original_dtype) + if use_qkv_fusion: + comm_fp8_works[h].wait() + comm_scale_works[h].wait() + output_qkv[h] = dequant_fp8_vllm(output_qkv_fp8[h], output_qkv_scale[h], original_dtype) + else: + comm_fp8_works[3 * h].wait() + comm_fp8_works[3 * h + 1].wait() + comm_fp8_works[3 * h + 2].wait() + comm_scale_works[3 * h].wait() + comm_scale_works[3 * h + 1].wait() + comm_scale_works[3 * h + 2].wait() + output_q[h] = dequant_fp8_vllm(output_q_fp8[h], output_q_scale[h], original_dtype) + output_k[h] = dequant_fp8_vllm(output_k_fp8[h], output_k_scale[h], original_dtype) + output_v[h] = dequant_fp8_vllm(output_v_fp8[h], output_v_scale[h], original_dtype) else: - comm_works[h].wait() - - qkv = output_qkv[h].reshape(global_img_seqlen, 3, single_head, hidden_dims).transpose(0, 1) - shard_img_q = qkv[0] # (global_img_seqlen, single_head, hidden_dims) - shard_img_k = qkv[1] - shard_img_v = qkv[2] + if use_qkv_fusion: + comm_works[h].wait() + else: + comm_works[3 * h].wait() + comm_works[3 * h + 1].wait() + comm_works[3 * h + 2].wait() + + if use_qkv_fusion: + qkv = output_qkv[h].reshape(global_img_seqlen, 3, single_head, hidden_dims).transpose(0, 1) + shard_img_q = qkv[0] # (global_img_seqlen, single_head, hidden_dims) + shard_img_k = qkv[1] + shard_img_v = qkv[2] + else: + shard_img_q = output_q[h].reshape(global_img_seqlen, single_head, hidden_dims) + shard_img_k = output_k[h].reshape(global_img_seqlen, single_head, hidden_dims) + shard_img_v = output_v[h].reshape(global_img_seqlen, single_head, hidden_dims) # 处理文本的查询、键和值,选择当前进程的当前头 shard_txt_q = txt_q[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :] @@ -160,27 +240,71 @@ def apply( attn = torch.cat(head_attns, dim=1) else: - img_qkv = img_qkv.permute(2, 1, 0, 3, 4).contiguous() # (world_size, img_qkv_len, 3, shard_heads, hidden_dims) + if use_qkv_fusion: + img_qkv = img_qkv.permute(2, 1, 0, 3, 4).contiguous() # (world_size, img_qkv_len, 3, shard_heads, hidden_dims) + else: + img_q = img_q.permute(1, 0, 2, 3).contiguous() # (world_size, img_q_len, shard_heads, hidden_dims) + img_k = img_k.permute(1, 0, 2, 3).contiguous() # (world_size, img_k_len, shard_heads, hidden_dims) + img_v = img_v.permute(1, 0, 2, 3).contiguous() # (world_size, img_v_len, shard_heads, hidden_dims) # 通信图像的查询、键和值 if use_fp8_comm: - img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims)) - img_qkv_fp8 = img_qkv_fp8.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims) - img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1) - output_qkv_fp8 = torch.empty_like(img_qkv_fp8) - output_qkv_scale = torch.empty_like(img_qkv_scale) - dist.all_to_all_single(output_qkv_fp8, img_qkv_fp8, group=seq_p_group) - dist.all_to_all_single(output_qkv_scale, img_qkv_scale, group=seq_p_group) - output_qkv = dequant_fp8_vllm(output_qkv_fp8, output_qkv_scale, original_dtype) + if use_qkv_fusion: + img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims)) + img_qkv_fp8 = img_qkv_fp8.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims) + img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1) + output_qkv_fp8 = torch.empty_like(img_qkv_fp8) + output_qkv_scale = torch.empty_like(img_qkv_scale) + dist.all_to_all_single(output_qkv_fp8, img_qkv_fp8, group=seq_p_group) + dist.all_to_all_single(output_qkv_scale, img_qkv_scale, group=seq_p_group) + output_qkv = dequant_fp8_vllm(output_qkv_fp8, output_qkv_scale, original_dtype) + else: + img_q_fp8, img_q_scale = quant_fp8_vllm(img_q.reshape(-1, hidden_dims)) + img_k_fp8, img_k_scale = quant_fp8_vllm(img_k.reshape(-1, hidden_dims)) + img_v_fp8, img_v_scale = quant_fp8_vllm(img_v.reshape(-1, hidden_dims)) + img_q_fp8 = img_q_fp8.reshape(world_size, img_qkv_len, shard_heads, hidden_dims) + img_k_fp8 = img_k_fp8.reshape(world_size, img_qkv_len, shard_heads, hidden_dims) + img_v_fp8 = img_v_fp8.reshape(world_size, img_qkv_len, shard_heads, hidden_dims) + img_q_scale = img_q_scale.reshape(world_size, img_qkv_len, shard_heads, 1) + img_k_scale = img_k_scale.reshape(world_size, img_qkv_len, shard_heads, 1) + img_v_scale = img_v_scale.reshape(world_size, img_qkv_len, shard_heads, 1) + output_q_fp8 = torch.empty_like(img_q_fp8) + output_k_fp8 = torch.empty_like(img_k_fp8) + output_v_fp8 = torch.empty_like(img_v_fp8) + output_q_scale = torch.empty_like(img_q_scale) + output_k_scale = torch.empty_like(img_k_scale) + output_v_scale = torch.empty_like(img_v_scale) + dist.all_to_all_single(output_q_fp8, img_q_fp8, group=seq_p_group) + dist.all_to_all_single(output_k_fp8, img_k_fp8, group=seq_p_group) + dist.all_to_all_single(output_v_fp8, img_v_fp8, group=seq_p_group) + dist.all_to_all_single(output_q_scale, img_q_scale, group=seq_p_group) + dist.all_to_all_single(output_k_scale, img_k_scale, group=seq_p_group) + dist.all_to_all_single(output_v_scale, img_v_scale, group=seq_p_group) + output_q = dequant_fp8_vllm(output_q_fp8, output_q_scale, original_dtype) + output_k = dequant_fp8_vllm(output_k_fp8, output_k_scale, original_dtype) + output_v = dequant_fp8_vllm(output_v_fp8, output_v_scale, original_dtype) else: - output_qkv = torch.empty_like(img_qkv) - dist.all_to_all_single(output_qkv, img_qkv, group=seq_p_group) + if use_qkv_fusion: + output_qkv = torch.empty_like(img_qkv) + dist.all_to_all_single(output_qkv, img_qkv, group=seq_p_group) + else: + output_q = torch.empty_like(img_q) + output_k = torch.empty_like(img_k) + output_v = torch.empty_like(img_v) + dist.all_to_all_single(output_q, img_q, group=seq_p_group) + dist.all_to_all_single(output_k, img_k, group=seq_p_group) + dist.all_to_all_single(output_v, img_v, group=seq_p_group) # 完成Attention计算 - qkv = output_qkv.reshape(global_img_seqlen, 3, shard_heads, hidden_dims).transpose(0, 1) - shard_img_q = qkv[0] # (global_img_seqlen, shard_head, hidden_dims) - shard_img_k = qkv[1] - shard_img_v = qkv[2] + if use_qkv_fusion: + qkv = output_qkv.reshape(global_img_seqlen, 3, shard_heads, hidden_dims).transpose(0, 1) + shard_img_q = qkv[0] # (global_img_seqlen, shard_head, hidden_dims) + shard_img_k = qkv[1] + shard_img_v = qkv[2] + else: + shard_img_q = output_q.reshape(global_img_seqlen, shard_heads, hidden_dims) + shard_img_k = output_k.reshape(global_img_seqlen, shard_heads, hidden_dims) + shard_img_v = output_v.reshape(global_img_seqlen, shard_heads, hidden_dims) # 处理文本的查询、键和值,选择当前进程的当前头 shard_txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]