-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpredict_testset.py
More file actions
375 lines (337 loc) · 16.9 KB
/
Copy pathpredict_testset.py
File metadata and controls
375 lines (337 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
import torch
import rasterio
import numpy as np
from pathlib import Path
from glob import glob
import pickle
import torch.nn as nn
from torchvision.transforms import Normalize, Compose
from models import AttentionUNet, UNet, UNet3Plus, ResNext
from geoutils import save_np_array_to_img, file_meta
import datetime
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from scipy.interpolate import interpn
from geoutils import file_meta
from utils import (
load_checkpoint, most_recent_file,
parse_date, apply_checkpoint, nanmean,
generate_model_name
)
def _make_dataset_cfg(labels_bands, labels_mean, labels_std, pkl_path):
    """Return one PredictTestset dataset configuration as a fresh dict.

    The optical/SAR input bands and their normalisation statistics are shared
    by every dataset; only the label bands, the label statistics and the
    dataset pickle glob differ. A new dict (with new lists) is built on every
    call, so two configurations never share mutable objects.
    """
    return {
        'opt_image_bands': [1, 2, 3, 4],  # optical bands 5-8 are not used
        'sar_image_polarizations': [1],
        'labels_bands': labels_bands,
        'data_stats': {
            'opt_mean': 1065.408447265625,
            'opt_std': 1222.510009765625,
            'sar_mean': 250.40419006347656,
            'sar_std': 141.63661193847656,
            'labels_mean': labels_mean,
            'labels_std': labels_std,
        },
        'pkl_path': pkl_path,
    }


# Keyword arguments for PredictTestset (see `**cfg[ds_name]` in main), per model family.
cfg = {
    # Single label band; its statistics equal the 'p95' (last) entries of the
    # four-label ['avg', 'cov', 'dns', 'p95'] lists of 'resnext_data'.
    'unet_data': _make_dataset_cfg(
        labels_bands=[1],
        labels_mean=[3.10381],
        labels_std=[5.437437],
        pkl_path='data/datasets/pkl/dataset_unet_*.pkl',
    ),
    # All four label bands, ordered ['avg', 'cov', 'dns', 'p95'].
    'resnext_data': _make_dataset_cfg(
        labels_bands=[1, 2, 3, 4],
        labels_mean=[2.6482544, 27.901573, 21.948093, 3.10381],
        labels_std=[4.6751084, 42.40032, 34.249233, 5.437437],
        pkl_path='data/datasets/pkl/dataset_resnext_*.pkl',
    ),
}
"""
Unpickles the most recent dataset and makes predictions on all test (split value 1) tiles.
"""
class PredictTestset():
    """Unpickle the most recent dataset and predict all test (split value 1) tiles.

    Predictions and per-pixel absolute / squared errors are written into
    full-size rasters. Only the interior of each tile (``margin`` pixels away
    from the tile border) is kept, so every prediction has spatial context.
    Per-tile mean errors are accumulated for :meth:`calculate_errors`.
    """

    # Label variables written to disk and reported, in channel order.
    # The remaining variables ('avg', 'cov', 'dns') are currently disabled.
    label_names = ('p95',)

    def __init__(self, model: nn.Module, data_stats: dict,
                 opt_image_bands: list[int] = [1, 2, 3, 4], sar_image_polarizations: list[int] = [1],
                 labels_bands: list[int] = [1, 2, 3, 4],
                 opt_transforms: 'Compose | None' = None, sar_transforms: 'Compose | None' = None,
                 output_activation=False, labels_transforms: 'Compose | None' = None,
                 bayesian: bool = False, pkl_path: str = 'data/datasets/pkl/*.pkl', supervised=True, name=None):
        """Load the newest dataset pickle and allocate the result/error rasters.

        Args:
            model: network to evaluate. Its ``opt_only``, ``sar_only``,
                ``in_channels`` and ``sar_channels`` attributes are read.
            data_stats: must contain 'opt_mean', 'opt_std', 'sar_mean' and
                'sar_std' whenever a default input transform is used.
            opt_image_bands, sar_image_polarizations, labels_bands: 1-based
                channel numbers kept by the default transforms. The list
                defaults are only read, never mutated.
            opt_transforms, sar_transforms, labels_transforms: optional
                replacements for the default ToTensor/SelectChannels(/Normalize)
                pipelines.
            output_activation: stored only; this class does not use it.
            bayesian: also allocate ``mean_result``/``variance_result`` and
                write them in :meth:`write_to_disk`.
            pkl_path: glob pattern; the file chosen by ``most_recent_file``
                among the matches is loaded.
            supervised, name: when ``name`` is None it is generated from the
                model and ``supervised``.

        Raises:
            AssertionError: if the model is both ``opt_only`` and ``sar_only``.
            FileNotFoundError: if no file matches ``pkl_path``.
        """
        assert not (model.opt_only and model.sar_only)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = model
        self.arch = model.__class__.__name__.lower()
        self.stats = data_stats
        self.bayesian = bayesian
        self.output_activation = output_activation
        self.supervised = supervised
        self.name = generate_model_name(self.model, self.supervised) if name is None else name
        self.directory = self.name.split('_')[0].lower()
        self.performed_predictions = False

        pkl_files = glob(pkl_path)
        if not pkl_files:
            raise FileNotFoundError(f"No dataset pickle matches '{pkl_path}'")
        pkl_file = Path(most_recent_file(pkl_files))
        # NOTE: pickle can execute arbitrary code; only load locally produced datasets.
        with pkl_file.open('rb') as pkl:
            print(f"Loading dataset '{pkl_file}'...")
            dataset = pickle.load(pkl)

        self.labels = dataset['labels']
        self.images = dataset['images']
        self.patch_size = dataset['patch_size']
        self.margin = dataset['margin']
        # total number of pixels to predict in each patch
        self.n_points = (self.patch_size - 2 * self.margin) ** 2
        # number of label channels in the label raster
        c = len(self.labels)
        h, w = self.labels[0].shape
        # Pixels without ground truth. `labels == np.nan` is always False
        # (NaN never compares equal to anything), so np.isnan is required.
        nan_labels = torch.from_numpy(np.isnan(self.labels))

        # Prediction raster: NaN wherever predict() has not written a value.
        # Allocated in every mode because predict() always writes it.
        self.result = torch.full((c, h, w), float('nan'), dtype=float)
        if bayesian:
            # NOTE(review): predict() does not fill mean_result/variance_result yet.
            self.mean_result = torch.zeros((c, h, w), dtype=float)
            self.mean_result[nan_labels] = np.nan
            self.variance_result = torch.zeros((c, h, w), dtype=float)
            self.variance_result[nan_labels] = np.nan
        # Per-pixel error rasters: 0 where labels are missing or nothing was predicted.
        self.mae_error = torch.zeros((c, h, w), dtype=float)
        self.mse_error = torch.zeros((c, h, w), dtype=float)
        # Running sums of per-tile mean errors, one entry per label channel.
        self.total_me_error = torch.zeros(c, dtype=float)
        self.total_mae_error = torch.zeros(c, dtype=float)
        self.total_mse_error = torch.zeros(c, dtype=float)

        self.ds_split_mask = dataset['split_mask'].squeeze(0)
        self.image_indices = dataset['image_indices']
        self.locations = dataset['test']

        opt_channels = np.array(opt_image_bands) - 1
        sar_channels = np.array(sar_image_polarizations) - 1
        labels_channels = np.array(labels_bands) - 1
        if opt_transforms is None:
            opt_transforms = Compose([ToTensor(), SelectChannels(opt_channels),
                                      Normalize(self.stats['opt_mean'], self.stats['opt_std'])])
        if sar_transforms is None:
            sar_transforms = Compose([ToTensor(), SelectChannels(sar_channels),
                                      Normalize(self.stats['sar_mean'], self.stats['sar_std'])])
        # Previously the labels_transforms argument was silently ignored.
        if labels_transforms is None:
            labels_transforms = Compose([ToTensor(), SelectChannels(labels_channels)])
        self.opt_transforms = opt_transforms
        self.sar_transforms = sar_transforms
        self.labels_transforms = labels_transforms

    def predict_all(self):
        """Predict every test tile in order, then mark predictions as performed."""
        print(f'Performing predictions with {self.margin} px margin..')
        self.model.eval()
        for index in range(len(self)):
            self.predict(index)
        self.performed_predictions = True

    def predict(self, index):
        """Predict test tile ``index`` and store prediction/errors for its interior.

        The model expects batches, so the input is unsqueezed to a batch of
        one and the prediction is squeezed back afterwards.
        """
        i, j = self.locations[index]
        i_slice = slice(i, i + self.patch_size)
        j_slice = slice(j, j + self.patch_size)
        # consider only pixels within margin to allow for sufficient spatial context
        i_margin_slice = slice(i + self.margin, i + self.patch_size - self.margin)
        j_margin_slice = slice(j + self.margin, j + self.patch_size - self.margin)
        # Same interior in patch coordinates. Written as patch_size - margin
        # rather than -margin so margin == 0 keeps the whole patch instead of
        # producing the empty slice(0, 0).
        patch_slice = slice(self.margin, self.patch_size - self.margin)

        opt_idx, sar_idx = self.image_indices['opt'], self.image_indices['sar']
        opt_patch = self.opt_transforms(self.images[opt_idx][:, i_slice, j_slice])
        sar_patch = self.sar_transforms(self.images[sar_idx][:, i_slice, j_slice])
        # input patch: optical channels followed by SAR channels, as a batch of one
        x = torch.cat([opt_patch, sar_patch], dim=0).unsqueeze(0).to(device=self.device)
        if self.model.opt_only:
            x = x[:, :self.model.in_channels]
        if self.model.sar_only:
            x = x[:, -self.model.sar_channels:]

        # ground truth patch (stays on the CPU, like the prediction below)
        y = self.labels_transforms(self.labels[:, i_slice, j_slice])
        mask = torch.isnan(y)
        # Inference only: skip building an autograd graph for every tile.
        with torch.no_grad():
            y_pred = self.model(x)
        # if self.output_activation: apply sigmoid to cov, dns (currently disabled)
        y_pred = y_pred.detach().cpu()

        self.total_me_error += nanmean((y_pred - y.unsqueeze(dim=0)), mask.unsqueeze(dim=0), edge=self.margin).squeeze()
        self.total_mae_error += nanmean((y_pred - y.unsqueeze(dim=0)).abs(), mask.unsqueeze(dim=0), edge=self.margin).squeeze()
        self.total_mse_error += nanmean((y_pred - y.unsqueeze(dim=0)) ** 2, mask.unsqueeze(dim=0), edge=self.margin).squeeze()

        y_pred = y_pred.squeeze(0)
        error_mae = (y_pred - y).abs()
        error_mae[mask] = 0
        error_mse = (y_pred - y) ** 2
        error_mse[mask] = 0
        y_pred[mask] = np.nan
        # only save results within margin
        self.result[:, i_margin_slice, j_margin_slice] = y_pred[:, patch_slice, patch_slice]
        self.mae_error[:, i_margin_slice, j_margin_slice] = error_mae[:, patch_slice, patch_slice]
        self.mse_error[:, i_margin_slice, j_margin_slice] = error_mse[:, patch_slice, patch_slice]

    def write_to_disk(self, gt_meta):
        """Save prediction and error maps as single-band float32 GeoTIFFs.

        Files go to ``result/<name>/``, one per map and per label in
        ``label_names``.

        Args:
            gt_meta: raster metadata dict passed to ``save_np_array_to_img``.
                It is copied; 'driver', 'count' and 'dtype' are overridden in
                the copy.
        """
        assert self.performed_predictions
        print(f'Writing {self.name} predictions to disk..')
        meta = gt_meta.copy()
        meta['driver'] = 'GTiff'
        meta['count'] = 1
        meta['dtype'] = 'float32'
        out_dir = f'result/{self.name}'
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        # one timestamp per call so all files written together share it
        date = parse_date()
        if self.bayesian:
            maps = [('mean_', self.mean_result), ('variance_', self.variance_result),
                    ('mae_', self.mae_error), ('mse_', self.mse_error)]
        else:
            maps = [('', self.result), ('mae_', self.mae_error), ('mse_', self.mse_error)]
        for i, label in enumerate(self.label_names):
            for prefix, raster in maps:
                save_np_array_to_img(np.expand_dims(raster.numpy()[i], 0), meta,
                                     f'{out_dir}/{prefix}{label}_{self.name}_{date}.tif')

    def write_error_to_latex_table(self):
        """Suggested replacement for write_error_to_disk; not implemented yet.

        See https://pypi.org/project/latextable/
        """
        pass

    def write_error_to_disk(self):
        """Write the calculate_errors() summary to ``result/<name>/<name>_<date>.txt``."""
        assert self.performed_predictions
        err = str(self.calculate_errors())
        out_dir = f'result/{self.name}'
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        with open(f'{out_dir}/{self.name}_{parse_date()}.txt', 'w', encoding='utf-8') as fh:
            fh.write(err)

    def calculate_errors(self):
        """Summarise the accumulated per-tile errors for every label in ``label_names``.

        The per-tile means accumulated by :meth:`predict` are averaged over all
        test tiles. The ``*_norm`` and ``*_range`` variants are percentages
        relative to the mean and to the value range (max - min) of the
        test-set labels (split mask == 1).

        Returns:
            dict mapping metric name ('mae', 'mae_norm', ..., 'mbe') to a list
            of ``'<label>: <value>'`` strings.
        """
        assert self.performed_predictions
        test_labels = self.labels[:, self.ds_split_mask == 1]
        y_mean = np.nanmean(test_labels, axis=1)
        y_range = np.nanmax(test_labels, axis=1) - np.nanmin(test_labels, axis=1)

        n_tiles = len(self.locations)
        mae_nums = (self.total_mae_error / n_tiles).numpy()
        mse_nums = (self.total_mse_error / n_tiles).numpy()
        rmse_nums = np.sqrt(mse_nums)
        me_nums = (self.total_me_error / n_tiles).numpy()

        results = {key: [] for key in ('mae', 'mae_norm', 'mae_range',
                                       'mse', 'mse_norm', 'mse_range',
                                       'rmse', 'rmse_norm', 'rmse_range',
                                       'mbe')}
        for i, label in enumerate(self.label_names):
            results['mae'].append(f'{label}: {mae_nums[i]:.2f}')
            results['mae_norm'].append(f'{label}: {100 * (mae_nums[i] / y_mean[i])}')
            results['mae_range'].append(f'{label}: {100 * (mae_nums[i] / y_range[i])}')
            results['mse'].append(f'{label}: {mse_nums[i]}')
            results['mse_norm'].append(f'{label}: {100 * (mse_nums[i] / y_mean[i])}')
            results['mse_range'].append(f'{label}: {100 * (mse_nums[i] / y_range[i])}')
            results['rmse'].append(f'{label}: {rmse_nums[i]}')
            results['rmse_norm'].append(f'{label}: {100 * (rmse_nums[i] / y_mean[i])}')
            results['rmse_range'].append(f'{label}: {100 * (rmse_nums[i] / y_range[i])}')
            results['mbe'].append(f'{label}: {me_nums[i]}')
        return results

    def prediction_label_matrix(self):
        """Return ``(y_pred, y)`` pairs of the test-set pixels for every label channel.

        Only test pixels (split mask == 1) whose prediction and label are
        finite in every channel are kept. All channels therefore hold the same
        pixels in the same order. The previous element-wise filter followed by
        a reshape misaligned pixels (or failed) when the NaN pattern differed
        between channels.

        Returns:
            tensor of shape ``(n_channels, n_pixels, 2)``; ``[..., 0]`` is the
            prediction and ``[..., 1]`` the label.
        """
        assert self.performed_predictions
        test_mask = torch.as_tensor(self.ds_split_mask == 1)
        # predictions and labels from the test set, shape (n_channels, n_test_pixels)
        predictions = self.result[:, test_mask]
        y = torch.from_numpy(self.labels)[:, test_mask].to(predictions.dtype)
        assert predictions.shape == y.shape
        keep = (predictions.isfinite() & y.isfinite()).all(dim=0)
        return torch.stack((predictions[:, keep], y[:, keep]), dim=-1)

    def __len__(self):
        """Number of test tiles."""
        return len(self.locations)
class ToTensor(nn.Module):
    """
    Wrap a NumPy array as a torch tensor, keeping its layout unchanged.

    Unlike torchvision.transforms.ToTensor, no dimension permutation or value
    rescaling happens here: the array is handed to torch.from_numpy as-is, so
    the returned tensor shares memory with the input array.
    """
    def forward(self, array):
        tensor = torch.from_numpy(array)
        return tensor
class SelectChannels(nn.Module):
    """
    Keep only the given channels of a (C, H, W) or (B, C, H, W) tensor.

    Tensors with any other number of dimensions raise ValueError.
    """
    def __init__(self, channels: np.ndarray):
        super().__init__()
        # Indices along the channel axis (callers in this file pass 0-based indices).
        self.channels = channels

    def forward(self, tensor):
        if tensor.ndim not in (3, 4):
            raise ValueError(f'tensor.ndim should be in [3,4], got {tensor.ndim}')
        # One leading full slice skips the batch axis when the tensor has one.
        leading = (slice(None),) * (tensor.ndim - 3)
        return tensor[leading + (self.channels, slice(None), slice(None))]
def main():
    """Evaluate every configured model on the test set and write its outputs.

    ``models[k]`` is restored from ``checkpoints[k]``. Then, for each model,
    all test tiles are predicted, the prediction/error rasters are written as
    GeoTIFFs and the error summary is written to a text file.

    Raises:
        ValueError: if the hand-maintained ``models`` and ``checkpoints`` lists
            differ in length (``zip`` would otherwise silently drop entries).
    """
    torch.cuda.empty_cache()
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    # All models
    models = [  # PSE-UNet
        UNet(in_channels=5, out_channels=1, features=[64, 128, 256, 512, 1024], upsample='transpose',
             use_se_block=True, entry_block=True, partial_conv=True, dropout=True, init_weights=False
             ).to(device=DEVICE),
        # AttentionUnet
        AttentionUNet(in_channels=5, out_channels=1, features=[32, 64, 128, 256, 512], use_se_block=False,
                      upsample='bilinear', partial_conv=False, entry_block=True, dropout=True, sar_channels=1
                      ).to(device=DEVICE),
        # AttentionUnet opt only
        AttentionUNet(in_channels=4, out_channels=1, features=[32, 64, 128, 256, 512], use_se_block=False,
                      upsample='bilinear', partial_conv=False, entry_block=True, dropout=True, opt_only=True
                      ).to(device=DEVICE),
        # AttentionUnet sar only
        AttentionUNet(in_channels=1, out_channels=1, features=[32, 64, 128, 256, 512], use_se_block=False,
                      upsample='bilinear', partial_conv=False, entry_block=True, dropout=True, sar_only=True
                      ).to(device=DEVICE),
        # UNet
        UNet(in_channels=5, out_channels=1, features=[32, 64, 128, 256, 512], use_se_block=False,
             upsample='bilinear', partial_conv=False, entry_block=True, dropout=True, sar_channels=1
             ).to(device=DEVICE),
        # SeUnet,
        UNet(in_channels=5, out_channels=1, upsample='transpose', features=[128, 256, 512], use_se_block=True,
             partial_conv=False, entry_block=False, dropout=True, sar_channels=1, init_weights=False,
             sar_only=False, opt_only=False).to(device=DEVICE)
        # UNet3+
        # UNet3Plus(in_channels=5, out_channels=4, upsample='bilinear', features=[64, 128, 256, 512, 1024], partial_conv=False, entry_block=False, dropout=True, sar_channels=1, use_se_block=False, init_weights=True).to(device=DEVICE),
        # ResNeXt
        # ResNext(in_channels=5, out_channels=4, layers=[2,3,5,3], groups=32, width_per_group=4, use_pixel_shortcut=True, use_entry_block=True, use_se_block=False, num_sar_channels=1, bayesian=False).to(device=DEVICE)
    ]
    # checkpoints[k] belongs to models[k]
    checkpoints = [
        'model/SeUNet_L5_TR_PC/12-Sep-2023/SeUNet_L5_TR_PC_E25_12-Sep-2023.pt',
        'model/AttentionUNet/12-Sep-2023/AttentionUNet_E11_12-Sep-2023.pt',
        'model/optonly_AttentionUNet_L5_BI/13-Sep-2023/optonly_AttentionUNet_L5_BI_E26_13-Sep-2023.pt',
        'model/saronly_AttentionUNet_L5_BI/14-Sep-2023/saronly_AttentionUNet_L5_BI_E11_14-Sep-2023.pt',
        'model/UNet_L5_BI/12-Sep-2023/UNet_L5_BI_E29_12-Sep-2023.pt',
        'model/SeUNet_L3_TR/12-Sep-2023/SeUNet_L3_TR_E19_12-Sep-2023.pt'
        # 'model/UNet_L5_BI/12-Sep-2023/UNet_L5_BI_E29_12-Sep-2023.pt'
    ]
    if len(models) != len(checkpoints):
        raise ValueError(
            f'Got {len(models)} models but {len(checkpoints)} checkpoints; '
            'they must be listed pairwise.'
        )
    for model, ckpt_path in zip(models, checkpoints):
        ckpt = load_checkpoint(ckpt_path)
        apply_checkpoint(ckpt, model)

    # Predictions
    meta = file_meta()
    for model in models:
        ds_name = 'resnext_data' if isinstance(model, ResNext) else 'unet_data'
        name = generate_model_name(model)
        test_ds = PredictTestset(model, name=name, bayesian=False, output_activation=True, **cfg[ds_name])
        test_ds.predict_all()
        # write error and prediction tifs to disk
        test_ds.write_to_disk(meta)
        # save errors to a text file
        test_ds.write_error_to_disk()
        # release the dataset and its full-size rasters before the next model
        del test_ds


if __name__ == '__main__':
    main()