I tried to replace the attention layer in a typical Transformer with lightning attention, but when I run the code an unexpected error occurs.
Here is what I tried to do:
- The original code snippet:

```python
from math import sqrt

import numpy as np
import torch
import torch.nn as nn

# TriangularCausalMask is provided elsewhere in the codebase.

class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        # full softmax attention with an optional causal mask
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None
```
- The replacement code snippet:

```python
import torch
import torch.nn as nn

# lightning_attn_func and _build_slope_tensor are imported from the
# lightning-attention package.

class LightningAttention(nn.Module):
    def __init__(self, n_heads, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
        super(LightningAttention, self).__init__()
        self.output_attention = output_attention
        self.n_heads = n_heads
        # per-head decay slopes, kept fixed (not trained)
        self.s = _build_slope_tensor(self.n_heads).to(torch.float32)
        self.s.requires_grad = False

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # move the head axis forward: (B, L, H, E) -> (B, H, L, E) for the Triton kernel
        queries = queries.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        values = values.permute(0, 2, 1, 3)
        V = lightning_attn_func(queries, keys, values, self.s.to(queries.device))
        return V.contiguous(), None
```
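For reference, exercising the replaced layer in isolation looks roughly like the sketch below (the sizes, the bf16 dtype, and the GPU placement are just placeholders; the inputs follow the `(B, L, H, E)` layout of the snippets above, and a GPU is needed because the kernel is written in Triton):

```python
import torch

# placeholder sizes for a quick smoke test
B, L, H, E = 2, 96, 8, 64

attn = LightningAttention(n_heads=H).cuda()
q = torch.randn(B, L, H, E, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, L, H, E, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, L, H, E, device="cuda", dtype=torch.bfloat16)

out, _ = attn(q, k, v, attn_mask=None)
# note: the output is not permuted back, so it stays in whatever layout
# lightning_attn_func returns
print(out.shape)
```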
I only intended to test whether lightning attention can process the data in a causal manner. After the replacement, the model completes one training epoch normally, but it then crashes with a segmentation fault or `LLVM ERROR: pthread_join failed: Invalid argument` when iterating over the validation DataLoader. Specifically, the error happens at this line:

```python
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
```

None of this happens with the original `FullAttention` layer.
I searched for solutions but got confused. Some suggestions recommend setting `num_workers` to 0 and `persistent_workers` to False to check whether the error is related to worker parallelism, and indeed that makes the crash go away. I have noticed that lightning attention launches thread-block operations via Triton; could that compete for resources with the DataLoader worker processes? (I don't know much about this, but for now the only way I can avoid the issue is to disable parallel data loading, which greatly slows down data loading T_T)
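For reference, the validation loader with the workaround applied looks roughly like this (the dataset variable and the batch size are placeholders for my actual setup):

```python
from torch.utils.data import DataLoader

# Workaround that avoids the crash: load validation batches in the main
# process (no worker subprocesses) and do not keep workers alive.
vali_loader = DataLoader(
    vali_dataset,              # placeholder: my validation Dataset object
    batch_size=32,             # placeholder value
    shuffle=False,
    num_workers=0,             # previously > 0; 0 disables worker processes
    persistent_workers=False,  # previously True
)
```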
Could the author help figure out where the problem might be?