diff --git a/model/resnet_cbam.py b/model/resnet_cbam.py index f196e34..67f3228 100644 --- a/model/resnet_cbam.py +++ b/model/resnet_cbam.py @@ -28,9 +28,9 @@ def __init__(self, in_planes, ratio=16): self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) - self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False), + self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(), - nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)) + nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) self.sigmoid = nn.Sigmoid() def forward(self, x):