From 421adce52df859877c52c9fa2d4c9491eec57282 Mon Sep 17 00:00:00 2001 From: Chet <282654568@qq.com> Date: Thu, 6 Jun 2024 15:57:30 +0800 Subject: [PATCH] Update utils.py The vanilla version want to broadcast the **L** dim, but actually broadcast the **N** dim --- analyze/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/analyze/utils.py b/analyze/utils.py index f78be1f6..7f90d834 100644 --- a/analyze/utils.py +++ b/analyze/utils.py @@ -374,7 +374,7 @@ def attnmap_mamba(regs, mode="CB", ret="all", absnorm=0, scale=1, verbose=False, mask = torch.tril(dts.new_ones((L, L))) dts = torch.nn.functional.softplus(dts + delta_bias[:, None]).view(B, G, D, L) - dw_logs = As.view(G, D, N)[None, :, :, None] * dts[:,:,:,None,:] # (B, G, D, N, L) + dw_logs = As.view(G, D, N)[None, :, :][:,:,:,:,None] * dts[:,:,:,None,:] # (B, G, D, N, L) ws = torch.cumsum(dw_logs, dim=-1).exp() if mode == "CB":