Skip to content

Commit e678187

Browse files
committed
[update] use fvcore - temp3
1 parent f6fb58c commit e678187

1 file changed

Lines changed: 68 additions & 3 deletions

File tree

compressai_vision/utils/measure_complexity.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,17 @@ def _cast(x):
293293
"aten::conv2d": conv_flop_jit,
294294
"aten::_convolution": conv_flop_jit,
295295
"aten::cudnn_convolution": conv_flop_jit,
296-
# element-wise ops
296+
# element-wise ops (out-of-place)
297297
"aten::add": elemwise_flop_jit,
298+
"aten::add_": elemwise_flop_jit,
298299
"aten::mul": elemwise_flop_jit,
299300
"aten::div": elemwise_flop_jit,
300301
"aten::abs": elemwise_flop_jit,
301302
"aten::reciprocal": elemwise_flop_jit,
302303
"aten::round": elemwise_flop_jit,
303-
"aten::leaky_relu": elemwise_flop_jit,
304+
"aten::leaky_relu": elemwise_flop_jit,
305+
# pooling
306+
"aten::max_pool2d": max_pool2d_flop_jit,
304307
})
305308
total_flops = flops.total()
306309

@@ -557,4 +560,66 @@ def forward(self, x_dummy: torch.Tensor) -> torch.Tensor:
557560
def elemwise_flop_jit(inputs, outputs):
558561
# outputs can be Tensor or tuple/list of Tensors
559562
out = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
560-
return prod(get_shape(out)) # 1 flop per output element (approx.)
563+
return prod(get_shape(out)) # 1 flop per output element (approx.)
564+
565+
def max_pool2d_flop_jit(inputs, outputs):
566+
"""
567+
Approximate FLOPs for max_pool2d.
568+
569+
Convention:
570+
- For each output element, max-pool performs (kH*kW - 1) comparisons.
571+
- We count comparisons as 1 FLOP each (approx).
572+
"""
573+
# aten::max_pool2d signature (typical):
574+
# inputs = [x, kernel_size, stride, padding, dilation, ceil_mode]
575+
x = inputs[0]
576+
y = outputs[0]
577+
578+
out_numel = _value_numel(y)
579+
if out_numel == 0:
580+
return 0
581+
582+
k = _to_ivalue(inputs[1], default=None) # could be int or (kH,kW) or list
583+
if isinstance(k, int):
584+
kH, kW = k, k
585+
elif isinstance(k, (list, tuple)) and len(k) == 2:
586+
kH, kW = int(k[0]), int(k[1])
587+
else:
588+
# Fallback: if kernel size is not statically available, assume 1x1
589+
kH, kW = 1, 1
590+
591+
# comparisons per output = kH*kW - 1
592+
return int(out_numel) * max(int(kH) * int(kW) - 1, 0)
593+
594+
def _value_sizes(v):
595+
"""
596+
Get static tensor sizes from torch._C.Value (JIT IR value).
597+
Returns a list like [N, C, H, W] or None if unknown.
598+
"""
599+
try:
600+
t = v.type()
601+
if hasattr(t, "sizes") and t.sizes() is not None:
602+
return list(t.sizes())
603+
except Exception:
604+
pass
605+
return None
606+
607+
def _value_numel(v):
608+
sizes = _value_sizes(v)
609+
if not sizes or any(s is None for s in sizes):
610+
return 0
611+
n = 1
612+
for s in sizes:
613+
n *= int(s)
614+
return n
615+
616+
def _to_ivalue(v, default=None):
617+
"""
618+
Try to materialize constant from torch._C.Value if it is a constant.
619+
Works for many prim::Constant-derived Values.
620+
"""
621+
try:
622+
return v.toIValue()
623+
except Exception:
624+
return default
625+

0 commit comments

Comments
 (0)