forked from TanGeeGo/Optical-Generative-models
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
73 lines (59 loc) · 2.37 KB
/
utils.py
File metadata and controls
73 lines (59 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import math
import torch
import torch.nn.functional as F
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
if not isinstance(arr, torch.Tensor):
arr = torch.from_numpy(arr)
res = arr[timesteps].float().to(timesteps.device)
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
def roll_torch(tensor, shift, axis):
    """Circularly shift `tensor` by `shift` along `axis` (numpy.roll / MATLAB circshift).

    :param tensor: input tensor.
    :param shift: integer shift; positive moves elements toward higher indices,
        negative toward lower indices. Any magnitude is accepted.
    :param axis: dimension to roll along; negative values count from the end.
    :return: the rolled tensor (the input itself when the effective shift is 0).
    """
    if shift == 0:
        return tensor
    if axis < 0:
        axis += tensor.dim()
    dim_size = tensor.size(axis)
    # Normalize the shift into [0, dim_size). The previous arithmetic produced
    # a negative narrow() length whenever |shift| >= dim_size; Python's modulo
    # also maps negative shifts to the equivalent positive right-shift.
    shift %= dim_size
    if shift == 0:
        return tensor
    # A right-roll moves the last `shift` elements to the front.
    tail = tensor.narrow(axis, dim_size - shift, shift)
    head = tensor.narrow(axis, 0, dim_size - shift)
    return torch.cat([tail, head], axis)
def ifftshift(tensor):
    """Inverse FFT shift for tensors shaped [minibatch, channels, height, width(, 2)].

    Rolls the height (dim 2) and width (dim 3) axes backward by half their
    size, undoing :func:`fftshift`.
    """
    height = tensor.size(2)
    width = tensor.size(3)
    shifted = roll_torch(tensor, -(height // 2), 2)
    return roll_torch(shifted, -(width // 2), 3)
def fftshift(tensor):
    """FFT shift for tensors shaped [minibatch, channels, height, width(, 2)].

    Rolls the height (dim 2) and width (dim 3) axes forward by half their
    size, moving the zero-frequency component to the center.
    """
    height = tensor.size(2)
    width = tensor.size(3)
    shifted = roll_torch(tensor, height // 2, 2)
    return roll_torch(shifted, width // 2, 3)
def kl_divergence_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the batch-mean KL divergence between softmax distributions.

    Both inputs are shifted by their scalar means before the softmax — a
    constant offset that leaves the resulting distributions unchanged but
    keeps the logits centered — then KL(target || output) is computed along
    dim 1 with 'batchmean' reduction.
    """
    log_probs = F.log_softmax(output - output.mean(), dim=1)
    probs = F.softmax(target - target.mean(), dim=1)
    return F.kl_div(log_probs, probs, reduction='batchmean')