diff --git a/analyze/utils.py b/analyze/utils.py index f78be1f6..7f90d834 100644 --- a/analyze/utils.py +++ b/analyze/utils.py @@ -374,7 +374,7 @@ def attnmap_mamba(regs, mode="CB", ret="all", absnorm=0, scale=1, verbose=False, mask = torch.tril(dts.new_ones((L, L))) dts = torch.nn.functional.softplus(dts + delta_bias[:, None]).view(B, G, D, L) - dw_logs = As.view(G, D, N)[None, :, :, None] * dts[:,:,:,None,:] # (B, G, D, N, L) + dw_logs = As.view(G, D, N)[None, :, :][:,:,:,:,None] * dts[:,:,:,None,:] # (B, G, D, N, L) ws = torch.cumsum(dw_logs, dim=-1).exp() if mode == "CB":