@@ -293,14 +293,17 @@ def _cast(x):
293293 "aten::conv2d" : conv_flop_jit ,
294294 "aten::_convolution" : conv_flop_jit ,
295295 "aten::cudnn_convolution" : conv_flop_jit ,
296- # element-wise ops
296+ # element-wise ops (in-place and out-of-place variants)
297297 "aten::add" : elemwise_flop_jit ,
298+ "aten::add_" : elemwise_flop_jit ,
298299 "aten::mul" : elemwise_flop_jit ,
299300 "aten::div" : elemwise_flop_jit ,
300301 "aten::abs" : elemwise_flop_jit ,
301302 "aten::reciprocal" : elemwise_flop_jit ,
302303 "aten::round" : elemwise_flop_jit ,
303- "aten::leaky_relu" : elemwise_flop_jit ,
304+ "aten::leaky_relu" : elemwise_flop_jit ,
305+ # pooling
306+ "aten::max_pool2d" : max_pool2d_flop_jit ,
304307 })
305308 total_flops = flops .total ()
306309
@@ -557,4 +560,66 @@ def forward(self, x_dummy: torch.Tensor) -> torch.Tensor:
def elemwise_flop_jit(inputs, outputs):
    """Approximate FLOPs for an element-wise op: 1 FLOP per output element."""
    # Handlers may receive a bare Tensor value or a tuple/list of them;
    # only the first output determines the element count.
    if isinstance(outputs, (tuple, list)):
        out = outputs[0]
    else:
        out = outputs
    return prod(get_shape(out))
564+
def max_pool2d_flop_jit(inputs, outputs):
    """
    Approximate FLOPs for aten::max_pool2d.

    Convention:
      - Each output element is produced by (kH*kW - 1) pairwise max
        comparisons over the pooling window.
      - Each comparison counts as 1 FLOP (approximation).

    Args:
        inputs: JIT node inputs; typically
            [x, kernel_size, stride, padding, dilation, ceil_mode].
        outputs: JIT node outputs; outputs[0] is the pooled tensor value.

    Returns:
        int: estimated FLOP count; 0 when the static output shape is
        unknown (or the output is genuinely empty).
    """
    y = outputs[0]

    out_numel = _value_numel(y)
    if out_numel == 0:
        return 0

    # kernel_size may be an int, a 1-element, or a 2-element list/tuple
    # (PyTorch broadcasts a single value to both spatial dims).
    k = _to_ivalue(inputs[1], default=None)
    if isinstance(k, int):
        kH, kW = k, k
    elif isinstance(k, (list, tuple)) and len(k) == 1:
        kH = kW = int(k[0])
    elif isinstance(k, (list, tuple)) and len(k) == 2:
        kH, kW = int(k[0]), int(k[1])
    else:
        # Kernel size not statically available: fall back to 1x1, which
        # contributes 0 comparisons rather than guessing a window size.
        kH, kW = 1, 1

    # (kH*kW - 1) comparisons per output element.
    return int(out_numel) * max(kH * kW - 1, 0)
593+
594+ def _value_sizes (v ):
595+ """
596+ Get static tensor sizes from torch._C.Value (JIT IR value).
597+ Returns a list like [N, C, H, W] or None if unknown.
598+ """
599+ try :
600+ t = v .type ()
601+ if hasattr (t , "sizes" ) and t .sizes () is not None :
602+ return list (t .sizes ())
603+ except Exception :
604+ pass
605+ return None
606+
def _value_numel(v):
    """Return the static element count of JIT value *v*, or 0 if unknown."""
    dims = _value_sizes(v)
    # Treat missing/partial shape information as "unknown" (0 elements).
    if not dims or None in dims:
        return 0
    total = 1
    for d in dims:
        total *= int(d)
    return total
615+
616+ def _to_ivalue (v , default = None ):
617+ """
618+ Try to materialize constant from torch._C.Value if it is a constant.
619+ Works for many prim::Constant-derived Values.
620+ """
621+ try :
622+ return v .toIValue ()
623+ except Exception :
624+ return default
625+
0 commit comments