-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtool.py
More file actions
48 lines (37 loc) · 1.41 KB
/
tool.py
File metadata and controls
48 lines (37 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import torch
def float_or_string(input):
    """Coerce *input* to ``float`` when possible, else pass strings through.

    Intended for config values (e.g. a temperature) that may be either a
    numeric literal or a symbolic string setting.

    Args:
        input: Value to interpret. (Name shadows the builtin but is kept
            for backward compatibility with keyword callers.)

    Returns:
        ``float(input)`` if the value is float-convertible, otherwise the
        original string unchanged.

    Raises:
        TypeError: If *input* is neither float-convertible nor a string.
    """
    try:
        return float(input)
    # float() raises ValueError for non-numeric strings and TypeError for
    # non-convertible objects; a bare `except:` here would also swallow
    # KeyboardInterrupt/SystemExit, so catch only what float() can raise.
    except (ValueError, TypeError):
        if isinstance(input, str):
            return input
        raise TypeError(
            f"Type for temperature must be either string or float, "
            f"but got {type(input)} instead."
        )
def save_checkpoint(model, cfg, epoch, optimizer, lr_scheduler, logging_dir, logger, best=False):
    """Serialize training state to disk under *logging_dir*.

    The checkpoint bundles the model weights (via the wrapped module's
    ``get_state_dict_to_save`` — model is assumed DDP/DataParallel-wrapped,
    hence ``model.module``), the config, the epoch counter, and the
    optimizer/scheduler states, then logs the destination path.

    Args:
        model: Wrapped model exposing ``model.module.get_state_dict_to_save()``.
        cfg: Configuration object stored alongside the weights.
        epoch: Current epoch number (used in the filename when not best).
        optimizer: Optimizer whose ``state_dict()`` is saved.
        lr_scheduler: LR scheduler whose ``state_dict()`` is saved.
        logging_dir: Directory the checkpoint file is written into.
        logger: Logger used to report the save path.
        best: If True, write to ``best_ckpt.pt`` instead of an epoch file.
    """
    payload = {
        "state_dict": model.module.get_state_dict_to_save(),
        "cfg": cfg,
        "epoch": epoch,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
    }
    if best:
        ckpt_name = "best_ckpt.pt"
    else:
        ckpt_name = f"{epoch}_epoch.pt"
    destination = os.path.join(logging_dir, ckpt_name)
    torch.save(payload, destination)
    logger.info(f"(unknown) saved at {destination}")
def save_idx_checkpoint(model, cfg, epoch, optimizer, lr_scheduler, logging_dir, logger, idx):
    """Serialize training state to ``{epoch}_{idx}.pt`` under *logging_dir*.

    Mid-epoch variant of the checkpoint saver: identical payload to the
    per-epoch save, but the filename carries both the epoch and an
    intra-epoch index (e.g. iteration number).

    Args:
        model: Wrapped model exposing ``model.module.get_state_dict_to_save()``.
        cfg: Configuration object stored alongside the weights.
        epoch: Current epoch number (part of the filename).
        optimizer: Optimizer whose ``state_dict()`` is saved.
        lr_scheduler: LR scheduler whose ``state_dict()`` is saved.
        logging_dir: Directory the checkpoint file is written into.
        logger: Logger used to report the save path.
        idx: Intra-epoch index appended to the filename.
    """
    payload = {
        "state_dict": model.module.get_state_dict_to_save(),
        "cfg": cfg,
        "epoch": epoch,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
    }
    destination = os.path.join(logging_dir, f"{epoch}_{idx}.pt")
    torch.save(payload, destination)
    logger.info(f"(unknown) saved at {destination}")