Matmul Flops #108

@breuera

The matmul flop counts seem to be off by 2x.

I tested the code on a simple MLP which reads as:

import torch.nn

## @package eml.mlp.Module
#  Simple MultiLayer Perceptron (MLP) with fixed dimensions.
#
#  The MLP assumes a 28^2 input image and 10 output classes.
#  These are the dimensions of the Fashion MNIST dataset.
class Model( torch.nn.Module ):
  ## Initializes the class.
  #  @param self object pointer.
  def __init__( self ):
    super( Model, self ).__init__()
    ## flattens the input
    self.m_flatten = torch.nn.Flatten()
    ## layers of the MLP: 3x(linear + relu)
    self.m_layers = torch.nn.Sequential( torch.nn.Linear( 28*28, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 10 ) )

  ## Forward pass with the given input.
  #  @param self object pointer.
  #  @param i_input input for the forward pass.
  #  @return output of the MLP.
  def forward( self,
               i_input ):
    l_flatten = self.m_flatten( i_input )
    l_result = self.m_layers( l_flatten )
    return l_result

I embedded this in some code; the crucial piece is:

l_model = eml.mlp.model.Model()
[...]
print( l_model )

#
# flop count code
# https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
#
import fvcore.nn

l_x, l_y = next(iter(l_data_loader_train))

print( l_x.size() )

l_flops = fvcore.nn.FlopCountAnalysis( l_model,
                                       l_x )

print( l_flops.by_module_and_operator() )

print( fvcore.nn.flop_count_table( l_flops ) )

This returns:

Model(
  (m_flatten): Flatten(start_dim=1, end_dim=-1)
  (m_layers): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
torch.Size([64, 1, 28, 28])
{'': Counter({'addmm': 42795008}), 'm_flatten': Counter(), 'm_layers': Counter({'addmm': 42795008}), 'm_layers.0': Counter({'addmm': 25690112}), 'm_layers.1': Counter(), 'm_layers.2': Counter({'addmm': 16777216}), 'm_layers.3': Counter(), 'm_layers.4': Counter({'addmm': 327680})}
| module     | #parameters or shape   | #flops   |
|:-----------|:-----------------------|:---------|
| m_layers   | 0.67M                  | 42.795M  |
|  0         |  0.402M                |  25.69M  |
|   0.weight |   (512, 784)           |          |
|   0.bias   |   (512,)               |          |
|  2         |  0.263M                |  16.777M |
|   2.weight |   (512, 512)           |          |
|   2.bias   |   (512,)               |          |
|  4         |  5.13K                 |  0.328M  |
|   4.weight |   (10, 512)            |          |
|   4.bias   |   (10,)                |          |

Let's take the first linear layer as an example: matrix A in https://pytorch.org/docs/stable/generated/torch.nn.Linear.html has shape (512, 784).
Matrix x has shape (64, 784), since the example is batched.
Computing the result C=xA^T requires 64*512*784 multiplications and 64*512*783 additions, i.e., 2*64*512*784 - 64*512 floating point operations.
However, the example uses a bias, which puts 64*512 additions on top -> 2*64*512*784 = 51,380,224 flops in total; the tool reports 25,690,112 for the first layer. Btw: I am not sure why the bias doesn't show up separately.
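The arithmetic for all three layers can be checked with a few lines of plain Python. This is only a sketch: the helper names `linear_flops` and `linear_macs` are made up for illustration and are not part of fvcore, and the reported values are copied from the output above.

```python
# Batch size and (in_features, out_features) of the three Linear layers.
BATCH = 64
LAYERS = {"m_layers.0": (784, 512), "m_layers.2": (512, 512), "m_layers.4": (512, 10)}
# addmm counts copied from the fvcore output in this issue.
REPORTED = {"m_layers.0": 25690112, "m_layers.2": 16777216, "m_layers.4": 327680}


def linear_flops(batch, in_features, out_features):
    """Multiplications and additions of x @ A.T + b, counted separately."""
    # Each of the batch*out_features outputs needs in_features multiplications
    # and in_features-1 additions for the dot product ...
    matmul = batch * out_features * (2 * in_features - 1)
    # ... plus one addition for the bias.
    bias = batch * out_features
    return matmul + bias


def linear_macs(batch, in_features, out_features):
    """Number of multiply-add pairs, i.e. M*N*K."""
    return batch * in_features * out_features


for name, (k, n) in LAYERS.items():
    flops = linear_flops(BATCH, k, n)
    macs = linear_macs(BATCH, k, n)
    print(name, "2*M*N*K =", flops, "M*N*K =", macs, "reported =", REPORTED[name])
```

For every layer the reported value equals M*N*K, which is half of the count obtained when multiplications and additions are counted separately.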

I believe that the line below is off, since the op C+=AB (with M, N, K as in BLAS) takes 2*M*N*K operations, not M*N*K:

flop = prod(input_shapes[0]) * input_shapes[-1][-1]
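To make the difference concrete, here is a minimal standalone sketch of that expression evaluated for the first layer. It assumes that `input_shapes` holds the shapes of x and of the transposed weight, i.e. (64, 784) and (784, 512); the function `addmm_count` and its `ops_per_multiply_add` parameter are hypothetical and only serve to compare the two conventions.

```python
from math import prod


def addmm_count(input_shapes, ops_per_multiply_add=1):
    """Evaluate the quoted expression, optionally scaled per multiply-add.

    input_shapes is assumed to be [shape of x, shape of A^T].
    """
    # prod(input_shapes[0]) is M*K, input_shapes[-1][-1] is N.
    return ops_per_multiply_add * prod(input_shapes[0]) * input_shapes[-1][-1]


# Assumed shapes for the first layer: x is (64, 784), A^T is (784, 512).
shapes = [(64, 784), (784, 512)]
print(addmm_count(shapes))     # M*N*K:   25690112, the value the tool reports
print(addmm_count(shapes, 2))  # 2*M*N*K: 51380224, the value computed by hand
```

With a factor of 1 the expression counts each multiply-add pair once; with a factor of 2 it matches the hand count that includes the bias additions.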
