-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_2.py
More file actions
121 lines (95 loc) · 4.95 KB
/
test_2.py
File metadata and controls
121 lines (95 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import math
import torch.nn
import torch.optim
import torchvision
import numpy as np
from model import *
from imp_subnet import *
import config as c
import datasets
import modules.Unet_common as common
# Select the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def load(name, net, optim):
    """Restore network weights and, best-effort, optimizer state from a checkpoint.

    Args:
        name: Path of a checkpoint readable by ``torch.load``. It must be a
            mapping with a ``'net'`` entry (a state dict) and may carry an
            ``'opt'`` entry (an optimizer state dict).
        net: Module that receives the ``'net'`` weights via ``load_state_dict``.
        optim: Optimizer whose state is restored from ``'opt'`` if possible.

    Entries of ``'net'`` whose key contains ``'tmp_var'`` are dropped before
    loading. Network loading errors propagate. A failure to restore the
    optimizer (missing ``'opt'`` key, mismatched parameter groups, ...) is
    reported on stdout together with its cause and otherwise ignored.
    """
    state_dicts = torch.load(name)
    # Skip 'tmp_var' entries; presumably auxiliary tensors that the model's own
    # state_dict does not contain (strict loading would reject them).
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    net.load_state_dict(network_state_dict)
    try:
        optim.load_state_dict(state_dicts['opt'])
    except Exception as err:
        # Deliberately best-effort, but never swallow KeyboardInterrupt/SystemExit
        # (as a bare ``except:`` would) and always say *why* it failed.
        print(f'Cannot load optimizer for some reason or other: {err!r}')
def gauss_noise(shape, device='cuda'):
    """Return a float tensor of i.i.d. standard-normal samples.

    Args:
        shape: Shape of the tensor to draw (any shape ``torch.randn`` accepts).
        device: Where to allocate the samples. The default ``'cuda'`` (current
            CUDA device) matches the original ``.cuda()`` behaviour; pass
            ``'cpu'`` to use this without a GPU.

    Returns:
        A tensor of the requested shape on ``device``.

    Note: samples are now drawn by the target device's RNG in a single call,
    so exact values for a given seed differ from the old per-row CPU draws;
    the distribution is the same.
    """
    # One draw directly on the device replaces a Python loop that sampled each
    # row on the host and copied it across individually.
    return torch.randn(shape, device=device)
def computePSNR(origin, pred, max_val=255.0):
    """Return the peak signal-to-noise ratio, in dB, between two arrays.

    Args:
        origin: Reference array (anything ``np.asarray`` accepts).
        pred: Array of the same shape to compare against ``origin``.
        max_val: Peak signal value. The default 255.0 suits 8-bit images;
            pass 1.0 for data scaled to [0, 1].

    Returns:
        ``10 * log10(max_val**2 / MSE)``, computed in float32. Returns ``100``
        when the MSE is below 1e-10, so that identical inputs do not divide
        by zero.
    """
    origin = np.asarray(origin, dtype=np.float32)
    pred = np.asarray(pred, dtype=np.float32)
    mse = np.mean((origin - pred) ** 2)
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(max_val ** 2 / mse)
# --- Model construction ------------------------------------------------------
# Both sub-networks are constructed, moved to the GPU and initialised with
# init_model(). The order construct -> .cuda() -> init_model is kept as-is
# because construction/initialisation may consume random numbers.
net1 = Model_1()
net2 = Model_2()
net1.cuda()
net2.cuda()
init_model(net1)
init_model(net2)

# Wrap each network in DataParallel over the configured GPU ids.
net1 = torch.nn.DataParallel(net1, device_ids=c.device_ids)
net2 = torch.nn.DataParallel(net2, device_ids=c.device_ids)

# Only gradient-requiring parameters are handed to the optimisers.
params_trainable1 = [p for p in net1.parameters() if p.requires_grad]
params_trainable2 = [p for p in net2.parameters() if p.requires_grad]

# Optimisers and LR schedulers are created so that load() has an optimiser to
# restore checkpoint state into; nothing in the code shown here steps them.
optim1 = torch.optim.Adam(
    params_trainable1, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay,
)
optim2 = torch.optim.Adam(
    params_trainable2, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay,
)
weight_scheduler1 = torch.optim.lr_scheduler.StepLR(optim1, step_size=c.weight_step, gamma=c.gamma)
weight_scheduler2 = torch.optim.lr_scheduler.StepLR(optim2, step_size=c.weight_step, gamma=c.gamma)

# Forward / inverse transform modules (presumably a discrete wavelet
# transform pair, judging by the names) used to enter and leave the
# frequency domain the networks operate in.
dwt = common.DWT()
iwt = common.IWT()

# Restore pretrained weights when enabled; otherwise the networks keep the
# weights produced by init_model() above.
if c.pretrain:
    load(c.PRETRAIN_PATH2 + c.suffix_pretrain + '_1.pt', net1, optim1)
    load(c.PRETRAIN_PATH2 + c.suffix_pretrain + '_2.pt', net2, optim2)
# Evaluation pass over the test set: gradients are disabled and both networks
# are switched to eval mode. For every batch, net1 and net2 are run forward
# (embedding) and then in reverse (recovery), and seven image files are
# written per batch index.
with torch.no_grad():
    net1.eval()
    net2.eval()
    for i, x in enumerate(datasets.testloader):
        x = x.to(device)
        # Split the batch along dim 0 into three equal thirds: cover,
        # secret_1, secret_2. Any remainder (batch size not a multiple of 3)
        # is dropped. NOTE(review): assumes each batch is a stack of images
        # along dim 0 with at least 3 entries -- confirm against datasets.
        cover = x[:x.shape[0] // 3] # channels = 3
        secret_1 = x[x.shape[0] // 3: 2 * (x.shape[0] // 3)]
        secret_2 = x[2 * (x.shape[0] // 3): 3 * (x.shape[0] // 3)]
        # Move all three into the transform domain. The channel counts in the
        # comments below assume c.channels_in == 3 (one segment = 4 * 3 = 12).
        cover_dwt = dwt(cover) # channels = 12
        secret_dwt_1 = dwt(secret_1)
        secret_dwt_2 = dwt(secret_2)

        # ---- Stage 1 forward: net1 sees cover + both secrets. ----
        input_dwt_1 = torch.cat((cover_dwt, secret_dwt_1, secret_dwt_2), 1) # channels = 36
        # The output is read as consecutive segments of width 4 * c.channels_in;
        # segment k starts at channel k * 4 * c.channels_in.
        output_dwt_1 = net1(input_dwt_1) # channels = 72 [stego, z, z_key, global_key, local_key, key_input]
        output_steg_dwt_1 = output_dwt_1.narrow(1, 0, 4 * c.channels_in) # channels = 12
        z_dwt_1 = output_dwt_1.narrow(1, 4 * c.channels_in, 4 * c.channels_in)
        # Segment 2 (z_key, offset 8 * c.channels_in) is not extracted.
        global_key_dwt = output_dwt_1.narrow(1, 12 * c.channels_in, 4 * c.channels_in)
        # local_key_dwt_1, key_1 and z_1 are not used later in this loop.
        local_key_dwt_1 = output_dwt_1.narrow(1, 16 * c.channels_in, 4 * c.channels_in)
        key_dwt_1 = output_dwt_1.narrow(1, 20 * c.channels_in, 4 * c.channels_in)
        output_steg_1 = iwt(output_steg_dwt_1) # channels = 3
        key_1 = iwt(key_dwt_1)
        z_1 = iwt(z_dwt_1)

        # ---- Stage 2 forward: net2 sees stage-1 stego + secret_2 + global key. ----
        input_dwt_2 = torch.cat((output_steg_dwt_1, secret_dwt_2, global_key_dwt), 1) # channels = 36
        output_dwt_2 = net2(input_dwt_2) # channels = 48 [stego, z, z_key, local_key]
        output_steg_dwt_2 = output_dwt_2.narrow(1, 0, 4 * c.channels_in) # channels = 12
        # local_key_dwt_2 is not used later in this loop.
        local_key_dwt_2 = output_dwt_2.narrow(1, 12 * c.channels_in, 4 * c.channels_in)
        output_steg_2 = iwt(output_steg_dwt_2) # channels = 3

        # ---- Stage 2 reverse: only the final stego is fed back into net2. ----
        output_rev_dwt_2 = output_steg_dwt_2 # channels = 12
        rev_dwt_2 = net2(output_rev_dwt_2, rev=True) # channels = 48 [stego, secret, key, rev_z]
        rev_steg_dwt_1 = rev_dwt_2.narrow(1, 0, 4 * c.channels_in) # channels = 12
        rev_secret_dwt_2 = rev_dwt_2.narrow(1, 4 * c.channels_in, 4 * c.channels_in) # channels = 12
        # rev_steg_1 is not saved or used later in this loop.
        rev_steg_1 = iwt(rev_steg_dwt_1) # channels = 3
        rev_secret_2 = iwt(rev_secret_dwt_2) # channels = 3

        # ---- Stage 1 reverse: the stego recovered by net2 is fed back into net1. ----
        # Only the secret segment (index 1) is extracted here; secret_2 is
        # recovered from the net2 reverse pass above, not from net1.
        output_rev_dwt_1 = rev_steg_dwt_1 # channels = 12
        rev_dwt_1 = net1(output_rev_dwt_1, rev=True) # channels = 48 [cover, secret, key, rev_z]
        rev_secret_dwt = rev_dwt_1.narrow(1, 4 * c.channels_in, 4 * c.channels_in) # channels = 12
        rev_secret_1 = iwt(rev_secret_dwt)

        # Save inputs, stegos and recovered secrets. Each file name is the
        # configured path prefix plus the zero-padded batch index.
        # NOTE(review): the target directories are assumed to exist already;
        # nothing here creates them.
        torchvision.utils.save_image(cover, c.TEST_PATH2_cover + '%.5d.png' % i)
        torchvision.utils.save_image(secret_1, c.TEST_PATH2_secret_1 + '%.5d.png' % i)
        torchvision.utils.save_image(secret_2, c.TEST_PATH2_secret_2 + '%.5d.png' % i)
        torchvision.utils.save_image(output_steg_1, c.TEST_PATH2_steg_1 + '%.5d.png' % i)
        torchvision.utils.save_image(rev_secret_1, c.TEST_PATH2_secret_rev_1 + '%.5d.png' % i)
        torchvision.utils.save_image(output_steg_2, c.TEST_PATH2_steg_2 + '%.5d.png' % i)
        torchvision.utils.save_image(rev_secret_2, c.TEST_PATH2_secret_rev_2 + '%.5d.png' % i)