-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathAutoEncoder.py
More file actions
67 lines (54 loc) · 2.07 KB
/
AutoEncoder.py
File metadata and controls
67 lines (54 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from src.network_architectures import encoder_layers, decoder_layers
class AutoEncoder(torch.nn.Module):
    """
    Autoencoder network

    The autoencoder network aims at encoding the image to a latent space and
    decoding the latent space (concatenated with the attributes along the
    channel dimension) to an image.
    The latent space should be invariant to the attributes.
    """

    def __init__(self, lr: float = 0.001):
        """
        lr: learning rate handed to the Adam optimizer stored on this module
        """
        super().__init__()
        # NOTE(review): encoder_layers / decoder_layers are module-level objects
        # imported from src.network_architectures and are assigned here without
        # copying, so presumably every AutoEncoder instance shares the same
        # layers (and weights) - confirm this is intended.
        self.encoder = encoder_layers
        self.decoder = decoder_layers
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    # The return annotation is quoted so it is never evaluated at definition
    # time and therefore stays valid on interpreters without builtin generics.
    def forward(
        self, x, y, display: bool = False
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Encode x, append the attributes y to the latent, and decode the result.

        x: image
        y: attributes; two trailing axes are appended below, so a 2-D tensor
           is expected (presumably (batch, n_attributes) - confirm with caller)
        display: when True, print shape and dtype of every intermediate tensor

        Returns the pair (latent, decoded).
        """
        # first convert x and y to float
        input_x = x.float()
        input_y = y.float()

        # encode x to latent space
        latent = self.encoder(input_x)

        # broadcast y over the spatial grid of the latent; the grid size is read
        # from the latent itself instead of being hard-coded to (2, 2)
        height, width = latent.shape[-2:]
        input_y = input_y.unsqueeze(2).unsqueeze(3)
        input_y = input_y.expand(-1, -1, height, width)

        # concatenate latent and input_y along the channel dimension for the
        # decoder to be able to process the y (attributes)
        latent_y = torch.cat((latent, input_y), dim=1)

        # decode latent_y to output (image)
        decoded = self.decoder(latent_y)

        if display:
            print(
                f"input_x {input_x.shape} {input_x.dtype}\n"
                f"input_y {input_y.shape} {input_y.dtype}\n"
                f"latent {latent.shape} {latent.dtype}\n"
                f"latent_y {latent_y.shape} {latent_y.dtype}\n"
                f"decoded {decoded.shape} {decoded.dtype}"
            )

        return latent, decoded
# %% TESTS
# Smoke test: push one random 256x256 RGB image with 40 attribute values
# through a freshly built AutoEncoder and check both output shapes.
latent, decoded = AutoEncoder()(
    torch.rand(1, 3, 256, 256),
    torch.rand(1, 40),
)

# the latent is expected to have 512 channels on a 2x2 grid
assert (
    tuple(latent.shape) == (1, 512, 2, 2)
), "The inference function does not work properly. Shape issue for latent"

# the reconstruction is expected to have the same shape as the input image
assert (
    tuple(decoded.shape) == (1, 3, 256, 256)
), "The inference function does not work properly. Shape issue for decoded"