diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 40bae605d02..85abd459961 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -98,9 +98,7 @@ def _loss_inter_union( xkis2 = torch.min(x2, x2g) ykis2 = torch.min(y2, y2g) - intsctk = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + intsctk = (xkis2 - xkis1).clamp(min=0) * (ykis2 - ykis1).clamp(min=0) unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk return intsctk, unionk