-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
248 lines (209 loc) · 8.92 KB
/
Copy pathutils.py
File metadata and controls
248 lines (209 loc) · 8.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import time
from glob import glob
from pathlib import Path
import pickle
import datetime
import torch
import numpy as np
from scipy import ndimage
from pathlib import Path
from torchmetrics.classification import BinaryJaccardIndex
from torcheval.metrics.functional import binary_f1_score
import torch.nn as nn
def erode_and_dilate(source, structure=np.ones((2,2)), it=1):
    """Remove noise and isolated pixels from a binary mask.

    The mask is eroded ``it`` times, which deletes features thinner than the
    structuring element (e.g. power lines, birds, bushes). The survivors are then
    dilated ``3 * it`` times and intersected with the original mask. The result
    is therefore never set where ``source`` is zero.

    Args:
        source: mask array; its dtype is used for the output.
        structure: structuring element used for both erosion and dilation.
        it: number of erosion iterations (dilation uses three times as many).

    Returns:
        Array of ``source.dtype``. It is set only where ``source`` is set and the
        dilated erosion survivors overlap it.
    """
    out_dtype = source.dtype
    survivors = ndimage.binary_erosion(source, structure=structure, iterations=it)
    survivors = survivors.astype(out_dtype)
    grown = ndimage.binary_dilation(survivors, structure=structure, iterations=3 * it)
    grown = grown.astype(out_dtype)
    # Intersect with the input so dilation can only restore, never add, pixels.
    return np.logical_and(source, grown).astype(out_dtype)
def most_recent_file(filenames: list[str]):
    """Return the entry of *filenames* that carries the most recent date.

    Filenames look like ``model/ResNext_B2353_E1_14-Oct-2022.pt``. Everything
    after the last ``_`` (minus the extension) is a ``%d-%b-%Y`` date. The
    answer is rebuilt from the first filename's prefix and extension plus the
    latest date found, then checked to be one of the inputs.

    Raises:
        AssertionError: if *filenames* is empty, or the rebuilt name is not in
            the list (e.g. the files do not share the first file's prefix).
    """
    assert len(filenames)
    first = filenames[0]
    extension = Path(first).suffix
    newest = extract_latest_date(filenames)
    prefix = '_'.join(first.split('_')[:-1])
    candidate = prefix + '_' + newest + extension
    assert candidate in filenames
    return candidate
def extract_latest_date(filenames: list[str]):
    """Return the newest ``%d-%b-%Y`` date embedded in *filenames* as a string.

    The date is the part of each basename after the last ``_`` and before the
    first ``.``. For example, ``'..._14-Oct-2022.pt'`` gives ``14-Oct-2022``.
    The result is re-formatted with ``%d-%b-%Y``, so the day is zero-padded.

    Raises:
        AssertionError: if *filenames* is empty.
        ValueError: if a filename does not end in a parsable date.
    """
    assert len(filenames)

    def _date_of(fname):
        token = Path(fname).name.split('_')[-1].split('.')[0]
        return datetime.datetime.strptime(token, '%d-%b-%Y').date()

    newest = max(_date_of(f) for f in filenames)
    return newest.strftime('%d-%b-%Y')
def latest_epoch(filenames: list[str], latest=True) -> str:
    """Return the filename with the highest ``E<n>`` epoch token.

    Filenames look like ``<prefix>_E<n>_<date>.pt``; the epoch is the
    second-to-last ``_``-separated field. With ``latest=False`` (and at least
    two files), the second-highest epoch is chosen instead. The final epoch is
    sometimes not the best on the validation set, only the most trained one,
    saved for testing.

    The result is rebuilt from the first filename with the chosen epoch
    substituted, then checked to be one of the inputs.

    Raises:
        AssertionError: if *filenames* is empty or the rebuilt name is missing.
    """
    assert len(filenames)
    epochs = sorted(int(Path(f).name.split('_')[-2][1:]) for f in filenames)
    want_runner_up = not latest and len(filenames) >= 2
    chosen = epochs[-2] if want_runner_up else epochs[-1]
    parts = filenames[0].split('_')
    parts[-2] = f'E{chosen}'
    candidate = '_'.join(parts)
    assert candidate in filenames
    return candidate
def parse_date(when=None):
    """Return a ``DD-Mon-YYYY`` stamp (e.g. ``14-Oct-2022``) used in file names.

    Uses the calendar year (``%Y``). The previous ``%G`` directive is the
    ISO 8601 week-based year, which differs from the calendar year around New
    Year. For example, 31 Dec 2024 belongs to ISO year 2025 and 1 Jan 2021 to
    ISO year 2020. Such stamps do not round-trip through the ``'%d-%b-%Y'``
    format used to parse dates back out of checkpoint file names.

    Args:
        when: optional ``datetime``/``date`` to format; defaults to now.
    """
    if when is None:
        when = datetime.datetime.now()
    return when.strftime('%d-%b-%Y')
#### Checkpoints ####
def latest_checkpoint(model_name, latest=True) -> str:
    """Return the path of the newest checkpoint for *model_name*.

    Looks at files matching ``model/{model_name}_*.pt`` (directly inside
    ``model/``). It keeps those stamped with the most recent date, prints that
    list, and picks the highest epoch among them. With ``latest=False`` it picks
    the second-highest epoch (see ``latest_epoch``).

    NOTE(review): ``save_checkpoint`` writes into ``model/<name>/<date>/`` by
    default, which this flat glob does not search. Confirm the intended layout.

    Returns:
        The checkpoint path as a ``str`` (as produced by ``glob``). The old
        ``-> Path`` annotation was wrong; a string was always returned.

    Raises:
        AssertionError: if no checkpoint matches (from ``extract_latest_date``).
    """
    checkpoints = glob(f'model/{model_name}_*.pt')
    newest_date = extract_latest_date(checkpoints)
    checkpoints = glob(f'model/{model_name}_*_{newest_date}.pt')
    print(checkpoints)
    return latest_epoch(checkpoints, latest=latest)
def load_checkpoint(path, map_location=None):
    """Load and return a checkpoint saved with ``torch.save``.

    Args:
        path: file path accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``. For example, ``'cpu'`` lets a
            checkpoint written on a GPU machine be opened on a CPU-only one.
            ``None`` (the default) keeps the previous behaviour.

    NOTE: ``torch.load`` is pickle-based. Depending on the installed torch
    version's ``weights_only`` default it may unpickle arbitrary objects, so
    only load trusted files.
    """
    print('=> Loading checkpoint..')
    return torch.load(path, map_location=map_location)
def apply_checkpoint(checkpoint, model):
    """Copy the weights in *checkpoint* (a state dict) into *model* in place.

    Prints a short progress line naming the model's class and returns ``None``.
    Errors raised by ``model.load_state_dict`` propagate unchanged.
    """
    model_kind = type(model).__name__
    print(f'=> Applying checkpoint to {model_kind} model..')
    model.load_state_dict(checkpoint)
def checkpoint_by_date_epoch(modelname, date, epoch):
    """Load ``model/<modelname>/<date>/<modelname>_<epoch>_<date>.pt``.

    *epoch* is inserted verbatim. ``save_checkpoint`` names files
    ``<name>_E<n>_<date>.pt``, so callers presumably pass the ``'E<n>'`` token
    (e.g. ``'E3'``). Confirm against callers.
    """
    return load_checkpoint(f'model/{modelname}/{date}/{modelname}_{epoch}_{date}.pt')


def save_checkpoint(model, epoch, name=None, log=None, save_dir=None, local_model=None):
    """Save model weights by name, epoch and date; optionally save a log as well.

    Weights go to ``<save_dir>/<name>_E<epoch + 1>_<date>.pt``. When *log* is
    given, its entries go to ``<save_dir>/<name>_<date>.txt``, one per line.

    Args:
        model: module whose ``state_dict()`` is saved.
        epoch: epoch index; the file records ``E<epoch + 1>``.
        name: file-name prefix; defaults to ``generate_model_name(model, ...)``.
        log: optional iterable of strings (e.g. error outputs) to write too.
        save_dir: target directory, created if missing. Defaults to
            ``model/<model name>/<date>``.
        local_model: forwarded to ``generate_model_name``. ``None`` keeps the
            historical behaviour: the default directory is built with a falsy
            value and the file name with ``True``. An explicit value is now
            used for both. Previously ``local_model=False`` still built the
            full local-model file name, which reads attributes such as
            ``use_se_block`` that non-local models lack.
    """
    # One timestamp for every name, so files agree even across midnight.
    date = parse_date()
    if save_dir is None:
        save_dir = f'model/{generate_model_name(model, supervised=True, local_model=local_model)}/{date}'
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    if name is None:
        name = generate_model_name(
            model,
            supervised=True,
            local_model=True if local_model is None else local_model,
        )
    state = model.state_dict()
    print('=> Saving checkpoint..')
    torch.save(state, out_dir / f'{name}_E{epoch + 1}_{date}.pt')
    if log is not None:
        with open(out_dir / f'{name}_{date}.txt', 'w') as fh:
            fh.writelines(line + '\n' for line in log)
def log_time_elapsed(model, elapsed_time, save_dir='model/', local_model=True):
    """Append ``Time elapsed <H:MM:SS...>`` to ``<save_dir>/<model name>_<date>.txt``.

    Args:
        model: model used to derive the file name via ``generate_model_name``.
        elapsed_time: duration in seconds, formatted via ``datetime.timedelta``.
        save_dir: directory for the log file. It is joined with ``pathlib``,
            so a value without a trailing slash (``'model'``) no longer glues
            the directory onto the file name (``modelUNet_...txt``).
        local_model: forwarded to ``generate_model_name``. The default ``True``
            matches the previous call; pass ``False`` for models lacking the
            local-model attributes.
    """
    model_name = generate_model_name(model, local_model=local_model)
    txtpath = Path(save_dir) / f'{model_name}_{parse_date()}.txt'
    stamp = str(datetime.timedelta(seconds=elapsed_time))
    with open(txtpath, 'a') as fh:
        fh.write(f'Time elapsed {stamp}\n')
### Helper functions ###
def elapsed_time_from(start):
    """Return the wall-clock time since *start* (a ``time.time()`` value).

    The ``timedelta`` string is cut at the first ``.``, dropping sub-second
    precision. For example, 3661.2 seconds gives ``'1:01:01'``.
    """
    delta = datetime.timedelta(seconds=time.time() - start)
    whole, _, _ = str(delta).partition('.')
    return whole
def std_mean(arr, dim=(1,2)):
    """Return ``(std, mean)`` of a 3-D array or tensor along *dim*, ignoring NaNs.

    Args:
        arr: 3-D ``numpy.ndarray`` or ``torch.Tensor``. Tensors are detached and
            moved to the CPU first, because ``Tensor.numpy()`` alone rejects
            tensors that require grad or live on a GPU.
        dim: axis or axes reduced by ``np.nanstd`` / ``np.nanmean``. The default
            ``(1, 2)`` leaves one value per entry along axis 0.

    Returns:
        ``(np.nanstd(arr, dim), np.nanmean(arr, dim))``.

    Raises:
        TypeError: if *arr* is neither a tensor nor an ndarray.
        ValueError: if *arr* is not 3-D.
    """
    if isinstance(arr, torch.Tensor):
        arr = arr.detach().cpu().numpy()
    elif not isinstance(arr, np.ndarray):
        raise TypeError(f'std_mean expects a numpy array or torch tensor, got {type(arr).__name__}')
    if arr.ndim != 3:
        raise ValueError(f'std_mean expects a 3-D input, got shape {arr.shape}')
    return np.nanstd(arr, dim), np.nanmean(arr, dim)
# example: UNet -> SeUNet_L5_TR_PC
def generate_model_name(model, supervised=True, local_model=True):
    """Build a descriptive identifier from the model's class and settings.

    With ``local_model`` falsy, the bare class name is returned. Otherwise the
    class name is extended, in this order:
      * UNet / UNet3Plus / AttentionUNet: ``_L<len(model.features)>``, then
        ``_<first two chars of model.upsample, upper-cased>``, then ``_PC`` if
        ``model.partial_conv``;
      * ResNext: ``_B<model.layers joined as digits>``;
      * ``Se`` prefix if ``model.use_se_block``;
      * ``_semi`` suffix unless ``supervised``;
      * ``saronly_``, ``optonly_``, ``seg_``, ``reg_`` prefixes for the
        matching ``*_only`` flags. They are prepended in that order, so ``reg_``
        ends up outermost.

    Raises:
        AttributeError: if a local model lacks one of the attributes read above.
    """
    base = type(model).__name__
    if not local_model:
        return base

    family = base.lower()
    if family in ('unet', 'unet3plus', 'attentionunet'):
        base = f'{base}_L{len(model.features)}_{model.upsample[:2].upper()}'
        if model.partial_conv:
            base += '_PC'
    if family == 'resnext':
        base += '_B' + ''.join(str(block) for block in model.layers)
    if model.use_se_block:
        base = 'Se' + base
    if not supervised:
        base += '_semi'
    for flag, prefix in (
        ('sar_only', 'saronly_'),
        ('opt_only', 'optonly_'),
        ('seg_only', 'seg_'),
        ('reg_only', 'reg_'),
    ):
        if getattr(model, flag):
            base = prefix + base
    return base
def pickle_object(object, path):
    """Serialize *object* to the file at *path* with ``pickle``, overwriting it.

    The parameter name shadows the ``object`` builtin. It is kept because
    callers may pass it by keyword.
    """
    with open(path, mode='wb') as sink:
        pickle.dump(object, sink)
def unpickle_object(path):
    """Load and return the object pickled at *path*, printing the path first.

    The progress message used to interpolate the open file object (its
    ``<_io.BufferedReader name=...>`` repr) instead of the path. It now prints
    the path itself.

    Presumably used for cached datasets shaped like
    ``dataset['name'] = (locations, loc_to_images_map, offsets)``. Confirm.

    WARNING: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(path, 'rb') as pkl:
        print(f"Loading dataset '{path}'...")
        return pickle.load(pkl)
def write_to_disk(text: list[str], path):
    """Write the strings in *text* to *path*, joined by newlines.

    The file is overwritten. No newline is added after the last entry.
    """
    with open(path, 'w') as out:
        out.write('\n'.join(text))


def nanmean(tensor, mask, edge=0):
    """Mean of a patch tensor over the positions where *mask* is 0.

    Args:
        tensor: PyTorch tensor representing a patch. It is indexed as 4-D
            (presumably ``[B, C, H, W]``; confirm). It is cloned, not modified.
        mask: marks NaN positions in the ground truth with 1; those entries are
            zeroed and excluded from the count. Presumably the same shape as
            *tensor*.
        edge: border width to exclude on the last two axes, e.g. to consider
            only the centre pixels for the loss.

    Returns:
        With ``edge > 0``: one mean per index of axis 1 (sums over axes 0, 2, 3).
        Otherwise: a 0-d tensor (sums over all four axes).
        Counts are clamped to at least 1, so a fully masked input gives 0.

    NOTE(review): the two branches return different shapes. Mask values other
    than 0 and 1 are kept in the sum but left out of the count. Confirm both
    are intended.
    """
    zeroed = tensor.clone()
    zeroed[mask == 1] = 0
    if edge > 0:
        inner = (slice(None), slice(None), slice(edge, -edge), slice(edge, -edge))
        count = (mask[inner] == 0).sum(dim=(0, 2, 3)).clamp(min=1)
        return zeroed[inner].sum(dim=(0, 2, 3)) / count
    count = (mask == 0).sum(dim=(0, 1, 2, 3)).clamp(min=1)
    return zeroed.sum(dim=(0, 1, 2, 3)) / count
def nanIoU(pred,target, mask, test=False, device= 'cpu'):
    """Binary IoU (Jaccard index) of channel 0, computed on unmasked pixels only.

    Args:
        pred: predictions; only ``pred[:, :1]`` is scored and NaNs become 0.
            A sigmoid is applied unless *test* is truthy.
        target: ground truth; only ``target[:, :1]`` is used and NaNs become 0.
        mask: pixels where ``mask == 0`` are scored. It must work as a boolean
            index into ``pred[:, :1]`` (presumably the same shape).
        test: when truthy, *pred* is assumed to be probabilities already.
        device: device the ``BinaryJaccardIndex`` metric is moved to.

    Returns:
        The IoU at threshold 0.5 as returned by ``BinaryJaccardIndex``. If that
        value is NaN, the Python float ``1.0`` is returned instead.
    """
    target = torch.nan_to_num(target[:, :1], nan=0.0)
    pred = torch.nan_to_num(pred[:, :1], nan=0.0)
    if not test:
        pred = torch.nn.Sigmoid()(pred)
    keep = mask == 0
    metric = BinaryJaccardIndex(threshold=0.5).to(device)
    iou = metric(pred[keep], target[keep])
    # A NaN IoU is reported as a perfect score; note the float (not tensor) type.
    return 1.0 if torch.isnan(iou) else iou
def nanFscore(pred,target, mask, test=False, device= 'cpu'):
    """Binary F1 score of channel 0, computed on unmasked pixels only.

    Args:
        pred: predictions; only ``pred[:, :1]`` is scored and NaNs become 0.
            A sigmoid is applied unless *test* is truthy.
        target: ground truth; only ``target[:, :1]`` is used and NaNs become 0.
        mask: pixels where ``mask == 0`` are scored. It must work as a boolean
            index into ``pred[:, :1]`` (presumably the same shape).
        test: when truthy, *pred* is assumed to be probabilities already.
        device: device the resulting score tensor is moved to.

    Returns:
        ``binary_f1_score`` at threshold 0.5, moved to *device*. Unlike
        ``nanIoU``, a NaN result is returned as-is.

    NOTE(review): *target* stays floating point here. Confirm the installed
    torcheval ``binary_f1_score`` accepts float targets.
    """
    target = torch.nan_to_num(target[:, :1], nan=0.0)
    pred = torch.nan_to_num(pred[:, :1], nan=0.0)
    if not test:
        pred = torch.nn.Sigmoid()(pred)
    keep = mask == 0
    score = binary_f1_score(pred[keep], target[keep], threshold=0.5)
    return score.to(device)
class ToTensor(nn.Module):
    """Wrap a NumPy array as a tensor, leaving its dimensions in place.

    Unlike ``torchvision.transforms.ToTensor``, this does no dimension
    permutation and no value rescaling. It only calls ``torch.from_numpy``,
    which shares memory with the input array.
    """

    def forward(self, array):
        wrapped = torch.from_numpy(array)
        return wrapped
class SelectChannels(nn.Module):
    """Keep only the given channels of a ``[C, H, W]`` or ``[B, C, H, W]`` tensor.

    All other channels are dropped. The channel axis is axis 0 for 3-D input
    and axis 1 for 4-D input. Any other rank raises ``ValueError``.
    """

    def __init__(self, channels: np.ndarray):
        """Store the channel indices to keep.

        Args:
            channels: channel indices, used directly as a tensor index (e.g. an
                integer array or a list). The previous ``np.array`` annotation
                named a function, not a type.
        """
        super().__init__()
        self.channels = channels

    def forward(self, tensor):
        if tensor.ndim == 3:
            return tensor[self.channels, :, :]
        elif tensor.ndim == 4:
            return tensor[:, self.channels, :, :]
        raise ValueError(f'tensor.ndim should be in [3,4], got {tensor.ndim}')
def normalize_val(data):
    """Min-max scale *data* to ``[0, 1]``.

    Constant input previously divided 0 by 0 and produced NaNs (with a NumPy
    warning). It now maps to all zeros, keeping the true-division result dtype.
    NaNs in *data* still propagate, as before.
    """
    lo = np.min(data)
    span = np.max(data) - lo
    if span == 0:
        # Constant input: every value equals the minimum, so each maps to 0.
        span = 1
    return (data - lo) / span
def main():
    """Smoke-test ``nanIoU`` on a tiny hand-made example and print the result.

    Fixes in the previous call:
      * ``'cpu'`` was the 4th positional argument, which binds to ``test``
        (silently skipping the sigmoid) rather than ``device``.
      * The inputs were 3-D with a 2-D mask, so ``pred[:, :1][mask == 0]``
        could not index: mask ``(3, 3)`` against a sliced ``(1, 1, 3)`` tensor.
      * The computed value was never shown.

    All tensors are now ``[1, 1, 3, 3]``.
    """
    pred = torch.tensor([[[[0.0, 0.2, 0.5], [0.0, 0.2, 0.5], [0.0, 0.2, 0.5]]]])
    target = torch.tensor([[[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]]])
    # Non-zero mask entries (the middle column) are excluded from the score.
    mask = torch.tensor([[[[0.0, 5.0, 0.0], [0.0, 5.0, 0.0], [0.0, 5.0, 0.0]]]])
    value = nanIoU(pred, target, mask, device='cpu')
    print(value)


if __name__ == '__main__':
    main()