-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathsample.py
More file actions
executable file
·111 lines (96 loc) · 3.93 KB
/
sample.py
File metadata and controls
executable file
·111 lines (96 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import hydra
import torch
import random
import numpy as np
from omegaconf import DictConfig, OmegaConf
from loguru import logger
from termcolor import cprint
from utils.misc import timestamp_str, compute_model_dim
from utils.io import mkdir_if_not_exists
from datasets.base import create_dataset
from datasets.misc import collate_fn_general, collate_fn_squeeze_pcd_batch
from models.base import create_model
from models.visualizer import create_visualizer
def load_ckpt(model: torch.nn.Module, path: str) -> None:
    """ Load a checkpoint into the given model.

    Handles both plain checkpoints and checkpoints saved from a
    DistributedDataParallel-wrapped model, whose parameter keys carry a
    'module.' prefix. Keys present in the model but missing from the
    checkpoint keep their current values.

    Args:
        model: model whose parameters will be overwritten in place
        path: path to the saved checkpoint; the file must contain a dict
            with a 'model' entry holding the state dict

    Raises:
        FileNotFoundError: if `path` does not exist
    """
    ## explicit raise instead of assert: asserts are stripped under `python -O`
    if not os.path.exists(path):
        raise FileNotFoundError(f"Can't find provided ckpt: {path}")
    ## map to CPU so a GPU-saved checkpoint also loads on CPU-only machines;
    ## load_state_dict copies the tensors back onto the model's device
    saved_state_dict = torch.load(path, map_location='cpu')['model']
    model_state_dict = model.state_dict()
    for key in model_state_dict:
        if key in saved_state_dict:
            model_state_dict[key] = saved_state_dict[key]
        ## model trained with DDP: checkpoint keys are prefixed with 'module.'
        if 'module.' + key in saved_state_dict:
            model_state_dict[key] = saved_state_dict['module.' + key]
    model.load_state_dict(model_state_dict)
@hydra.main(version_base=None, config_path="./configs", config_name="default")
def main(cfg: DictConfig) -> None:
    """ Sample from a trained diffusion model and visualize the results.

    Loads the test split of the configured dataset (scenes only), restores
    the latest model checkpoint, and runs the task visualizer over the test
    dataloader. Output (log + visualizations) goes under `<exp_dir>/eval`.

    Args:
        cfg: hydra-composed configuration for the experiment

    Raises:
        FileNotFoundError: if no checkpoint can be found in `cfg.ckpt_dir`
    """
    ## compute modeling dimension according to task
    cfg.model.d_x = compute_model_dim(cfg.task)
    ## presence of the SLURM env var marks a cluster run — TODO confirm who sets it
    if os.environ.get('SLURM') is not None:
        cfg.slurm = True # update slurm config
    ## set output dir: 'series' keeps per-step denoising frames, 'final' only results
    eval_dir = os.path.join(cfg.exp_dir, 'eval')
    mkdir_if_not_exists(eval_dir)
    vis_dir = os.path.join(eval_dir,
        'series' if cfg.task.visualizer.vis_denoising else 'final', timestamp_str())
    logger.add(vis_dir + '/sample.log') # set logger file
    logger.info('Configuration: \n' + OmegaConf.to_yaml(cfg)) # record configuration
    if cfg.gpu is not None:
        device = f'cuda:{cfg.gpu}'
    else:
        device = 'cpu'
    ## prepare dataset for visual evaluation
    ## only load scene (case_only=True skips per-frame annotations)
    datasets = {
        'test': create_dataset(cfg.task.dataset, 'test', cfg.slurm, case_only=True),
    }
    for subset, dataset in datasets.items():
        logger.info(f'Load {subset} dataset size: {len(dataset)}')
    ## PointTransformer expects squeezed point-cloud batches; others take the general collate
    if cfg.model.scene_model.name == 'PointTransformer':
        collate_fn = collate_fn_squeeze_pcd_batch
    else:
        collate_fn = collate_fn_general
    dataloaders = {
        'test': datasets['test'].get_dataloader(
            batch_size=cfg.task.test.batch_size,
            collate_fn=collate_fn,
            num_workers=cfg.task.test.num_workers,
            pin_memory=True,
            shuffle=True,
        )
    }
    ## create model and load ckpt
    model = create_model(cfg, slurm=cfg.slurm, device=device)
    model.to(device=device)
    ## if your models are separately saved in each epoch, you need to change the model path manually
    ckpt_path = os.path.join(cfg.ckpt_dir, 'model.pth')
    if not os.path.exists(ckpt_path):
        ## fall back to the newest epoch checkpoint, e.g. model_120.pth
        checkpoint_files = [f for f in os.listdir(cfg.ckpt_dir)
                            if f.startswith('model_') and f.endswith('.pth')]
        ## guard: max() on an empty list raises a cryptic ValueError otherwise
        if not checkpoint_files:
            raise FileNotFoundError(
                f'No checkpoint found in {cfg.ckpt_dir}: '
                'neither model.pth nor model_<epoch>.pth exists.')
        latest_ckpt = max(checkpoint_files, key=lambda f: int(f.split('_')[1].replace('.pth', '')))
        ckpt_path = os.path.join(cfg.ckpt_dir, latest_ckpt)
        logger.info(f"Using the latest checkpoint: {ckpt_path}")
    load_ckpt(model, path=ckpt_path)
    ## NOTE(review): the messages below embed raw ANSI codes AND pass a color to
    ## cprint, so the text is double-wrapped; kept byte-identical on purpose
    if cfg.diffuser.sample.use_dpmsolver:
        cprint("\033[1;35m[INFO] Using DPMSolver++ for sampling.\033[0m", "magenta")
    else:
        cprint("\033[1;35m[INFO] Using DDPM for sampling.\033[0m", "magenta")
    ## create visualizer and visualize
    visualizer = create_visualizer(cfg.task.visualizer)
    visualizer.visualize(model, dataloaders['test'], vis_dir)
if __name__ == '__main__':
    ## fix every RNG and force deterministic cuDNN so sampling runs reproduce
    SEED = 0
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(SEED)
    main()