Skip to content

FlopCountAnalysis issues #153

@cavalleria

Description

@cavalleria

I used the code below to measure the FLOPs of a multi-head attention (MHA) module, but it reports the same FLOPs and parameter count for nhead=4 and nhead=8.

import torch
import torch.nn as nn


class MHAModel(nn.Module):
    """Thin wrapper that applies ``nn.MultiheadAttention`` as self-attention.

    Args:
        dim: Embedding dimension, passed to ``nn.MultiheadAttention`` as its
            first (``embed_dim``) argument.
        nhead: Number of attention heads.
        dropout: Dropout probability forwarded to ``nn.MultiheadAttention``.

    NOTE(review): in PyTorch's ``nn.MultiheadAttention`` the projection
    weights are sized by ``embed_dim`` alone, so the parameter count (and
    the projection / attention-matmul FLOPs) is presumably independent of
    ``nhead``. Identical counts for nhead=4 and nhead=8 are therefore likely
    expected, not a FLOP-counter bug. Confirm against the PyTorch docs.
    """

    def __init__(self, dim, nhead, dropout):
        super().__init__()
        # Use the constructor argument ``dim``. The original read the
        # module-level global ``dim_out`` here, which ignored ``dim`` and
        # raised NameError whenever that global did not exist.
        self.mha = nn.MultiheadAttention(
            dim, nhead, dropout=dropout, batch_first=True
        )

    def forward(self, x):
        """Run self-attention with query = key = value = ``x``.

        With ``batch_first=True``, ``x`` is expected as (batch, seq, dim).
        Only the attention output is returned. The second element of the
        ``nn.MultiheadAttention`` result (the attention weights) is
        discarded.
        """
        out, _ = self.mha(x, x, x)
        return out


from fvcore.nn import FlopCountAnalysis, flop_count_table

# Model / input configuration for the FLOP measurement.
dim_out = 448   # embedding dimension
seq_len = 300   # sequence length of the dummy input
nhead = 4       # number of attention heads
dropout = 0.1

net = MHAModel(dim=dim_out, nhead=nhead, dropout=dropout)
net.eval()  # inference mode (disables dropout) before tracing
# Dummy input laid out as (batch, seq, dim) to match batch_first=True.
data = torch.randn((1, seq_len, dim_out))
# ``(data)`` is just ``data`` in parentheses, not a tuple. Pass the model
# inputs as an explicit one-element tuple instead.
flops = FlopCountAnalysis(net, (data,))
print(flop_count_table(flops, max_depth=4))

[image: screenshot of the output, not preserved]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions