
Does lightning attention cause improper pthread parallelization in memory? #16

@Leopold2333

Description


I tried to replace the Attention layer in a typical Transformer with lightning attention; however, when I run the code, an unexpected error occurs.

Here is what I did:

  1. The original code snippet:

```python
from math import sqrt

import numpy as np
import torch
import torch.nn as nn

# TriangularCausalMask comes from the project's masking utilities


class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        # per-head query/key dot products: (B, L, H, E) x (B, S, H, E) -> (B, H, L, S)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None
```
  2. The replaced code snippet:

```python
class LightningAttention(nn.Module):
    def __init__(self, n_heads, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
        super(LightningAttention, self).__init__()
        self.output_attention = output_attention
        self.n_heads = n_heads  # _build_slope_tensor needs one decay slope per head
        self.s = _build_slope_tensor(self.n_heads).to(torch.float32)
        self.s.requires_grad = False

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # lightning_attn_func expects head-first (B, H, L, E) layout
        queries = queries.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        values = values.permute(0, 2, 1, 3)
        # note: V comes back in (B, H, L, D) layout, not FullAttention's (B, L, H, D)
        V = lightning_attn_func(queries, keys, values, self.s.to(queries.device))

        return V.contiguous(), None
```
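Since the point of the experiment is causal processing, here is a rough, unoptimized NumPy sketch of the kind of causal linear-attention recurrence (with per-head exponential decay) that lightning attention computes; the exact decay form and the semantics of the slope tensor are my assumptions, not taken from the kernel source:

```python
import numpy as np

def linear_attn_causal(q, k, v, slopes):
    # q, k, v: (B, H, L, E) -- the head-first layout the permutes above produce
    # slopes: (H,) per-head decay rates (assumed role of _build_slope_tensor's output)
    B, H, L, E = q.shape
    out = np.zeros_like(v)
    for b in range(B):
        for h in range(H):
            decay = np.exp(-slopes[h])
            kv = np.zeros((E, v.shape[-1]))  # running decayed sum of outer(k_t, v_t)
            for t in range(L):
                kv = decay * kv + np.outer(k[b, h, t], v[b, h, t])
                out[b, h, t] = q[b, h, t] @ kv  # only positions <= t contribute
    return out

B, H, L, E = 1, 2, 5, 4
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, H, L, E)) for _ in range(3))
o = linear_attn_causal(q, k, v, np.array([0.5, 1.0]))
print(o.shape)  # (1, 2, 5, 4)

# causality check: perturbing the last value must not change earlier outputs
v2 = v.copy()
v2[:, :, -1] += 1.0
o2 = linear_attn_causal(q, k, v2, np.array([0.5, 1.0]))
print(np.allclose(o[:, :, :-1], o2[:, :, :-1]))  # True
```

The Triton kernel computes this recurrence blockwise on the GPU, which is why the layout has to be head-first before the call.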

I only intended to test whether lightning attention can process data in a causal manner. After the replacement, the model completes one training epoch normally, but then it hits a segmentation fault or `LLVM ERROR: pthread_join failed: Invalid argument` on the validation DataLoader. Specifically, the error occurs at this line:

```python
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
```

None of this happens with the original FullAttention layer.

I searched for solutions but got confused. Some people recommended setting num_workers to 0 and persistent_workers to False to check whether the errors are related to parallelization, and indeed that makes them go away. I have noticed that lightning attention uses thread-block operations via Triton; could these cause improper resource contention when multiple DataLoader workers run in parallel? (I don't know much about this, but at the moment I can only avoid the issue by disabling parallel data loading, which greatly slows down data loading T_T)
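For anyone hitting the same crash, the workaround described above amounts to building the validation loader with worker processes disabled (the `TensorDataset` stand-in and batch size below are placeholders, not the experiment's real data pipeline):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# stand-in validation dataset; the real one comes from the experiment's data pipeline
vali_dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))

# workaround: no worker subprocesses, so Triton's runtime threads are not
# mixed with forked DataLoader worker processes
vali_loader = DataLoader(
    vali_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,            # load batches in the main process
    persistent_workers=False, # must be False when num_workers == 0
)

for batch_x, batch_y in vali_loader:
    pass  # iterates normally; no worker processes are spawned
```

This trades loading speed for stability; it is a diagnostic workaround, not a fix.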

Could the author help figure out where the problem might be?
