diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 8ee1da7267..3fa578da29 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # 2D data x = torch.ones([1,1,10,10])/2 y = torch.ones([1,1,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # pseudo-3D data x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices y = torch.ones([1,5,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # 3D data x = torch.ones([1,1,10,10,10])/2 y = torch.ones([1,1,10,10,10])/2 - print(1-SSIMLoss(spatial_dims=3)(x,y)) + print(SSIMLoss(spatial_dims=3)(x,y)) """ ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1) loss: torch.Tensor = 1 - ssim_value