-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_TMRec_multiGPU.py
More file actions
441 lines (391 loc) · 15.8 KB
/
train_TMRec_multiGPU.py
File metadata and controls
441 lines (391 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
# This file is based on Marigold training script (https://github.com/prs-eth/marigold)
# with modifications by Tianmouc, 2025.
# These modifications are part of the work "Diffusion-Based Extreme High-speed Scenes Reconstruction
# with the Complementary Vision Sensor" published in ICCV 2025.
# Project repository: https://github.com/Tianmouc/GenRec
#
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import argparse
import logging
import os
import shutil
from datetime import datetime, timedelta
from typing import List
import torch
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from diffusers import DDPMScheduler, DDIMScheduler
from CBRDM import get_pipeline_cls
from CBRDM.unet2d_btchw import UNet2DModelBTCHW # used in stage1
from diffusers import UNet2DModel, DiffusionPipeline # used in stage2
from src.trainer import get_trainer_cls
from src.dataset import get_rec_dataset
from src.dataset.mixed_sampler import MixedBatchSampler
from src.util.config_util import (
find_value_in_omegaconf,
recursive_load_config,
)
from src.util.logging_util import (
config_logging,
init_wandb,
load_wandb_job_id,
log_slurm_job_id,
save_wandb_job_id,
tb_logger,
)
from src.util.slurm_util import get_local_scratch_dir, is_on_slurm
import requests
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
def setup(rank, world_size):
    """Join the NCCL process group for this rank and bind it to GPU ``rank``.

    ``init_method='env://'`` reads the rendezvous address from the
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables. ``mp.spawn``
    (unlike ``torchrun``) does not set them, so single-node defaults are
    filled in here; values already present in the environment take precedence.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=world_size, rank=rank
    )
    torch.cuda.set_device(rank)
def cleanup():
    """Destroy the default process group if this process has one.

    Safe to call from a process that never joined a group (for example the
    parent process of ``mp.spawn``); in that case it is a no-op instead of
    raising and masking whatever error triggered the cleanup.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def _load_unet_weights(unet, ckpt_path, device):
    """Load a state dict into ``unet``, keeping only shape-compatible entries.

    Keys absent from the model, or whose tensor shape differs from the model's,
    are skipped and logged. Loading uses ``strict=False`` so missing and
    unexpected keys are reported instead of raising.

    Args:
        unet: Module to update in place.
        ckpt_path: File readable by ``torch.load`` containing a state dict.
        device: ``map_location`` passed to ``torch.load``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise; the caller decides
        how to handle the failure.
    """
    state_dict = torch.load(ckpt_path, map_location=device)
    # Keys to drop before loading (e.g. "conv_in.weight" when the number of
    # input channels changed). Empty by default.
    keys_to_remove = set()
    model_state_dict = unet.state_dict()
    mismatched_keys = []
    valid_state_dict = {}
    for k, v in state_dict.items():
        if k in keys_to_remove:
            continue
        if k in model_state_dict and v.shape == model_state_dict[k].shape:
            valid_state_dict[k] = v
        else:
            mismatched_keys.append(k)
    if mismatched_keys:
        logging.info(f"Mismatched keys (size mismatch): {mismatched_keys}")
    missing_keys, unexpected_keys = unet.load_state_dict(valid_state_dict, strict=False)
    if missing_keys:
        logging.info(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        logging.info(f"Unexpected keys: {unexpected_keys}")


def main_worker(rank, world_size, args):
    """Run the full training job in one process (one GPU).

    Called directly for single-GPU runs (``rank=0, world_size=1``) and through
    ``torch.multiprocessing.spawn`` for multi-GPU runs, which supplies ``rank``.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of training processes.
        args: Parsed command-line arguments. May carry ``t_start`` (a
            ``datetime``) so that every rank agrees on the run start time.

    Raises:
        Re-raises any exception from ``trainer.train`` after logging it, so the
        process exits non-zero.
    """
    if world_size > 1:
        setup(rank, world_size)

    # The global `t_start` is created in the __main__ block, which does not run
    # in processes started by mp.spawn (they import this file as __mp_main__).
    # Prefer the value shipped in `args`, then the global, then "now".
    t_start = (
        getattr(args, "t_start", None)
        or globals().get("t_start")
        or datetime.now()
    )

    resume_run = args.resume_run
    output_dir = args.output_dir
    base_ckpt_dir = (
        args.base_ckpt_dir
        if args.base_ckpt_dir is not None
        else os.environ["BASE_CKPT_DIR"]
    )

    # -------------------- Initialization --------------------
    if resume_run is not None:
        # Resume previous run: the run directory is taken to be two levels
        # above the checkpoint path (presumably <run>/checkpoint/<name>).
        print(f"Resume run: {resume_run}")
        out_dir_run = os.path.dirname(os.path.dirname(resume_run))
        job_name = os.path.basename(out_dir_run)
        cfg = OmegaConf.load(os.path.join(out_dir_run, "config.yaml"))
    else:
        # Run from start
        cfg = recursive_load_config(args.config)
        pure_job_name = os.path.basename(args.config).split(".")[0]
        if args.add_datetime_prefix:
            job_name = f"{t_start.strftime('%y_%m_%d-%H_%M_%S')}-{pure_job_name}"
        else:
            job_name = pure_job_name
        if output_dir is not None:
            out_dir_run = os.path.join(output_dir, job_name)
        else:
            out_dir_run = os.path.join("./output", job_name)
        os.makedirs(out_dir_run, exist_ok=True)
    cfg_data = cfg.dataset

    # Other directories; exist_ok=True keeps this safe when ranks race.
    out_dir_ckpt = os.path.join(out_dir_run, "checkpoint")
    out_dir_tb = os.path.join(out_dir_run, "tensorboard")
    out_dir_eval = os.path.join(out_dir_run, "evaluation")
    out_dir_vis = os.path.join(out_dir_run, "visualization")
    for _dir in (out_dir_ckpt, out_dir_tb, out_dir_eval, out_dir_vis):
        os.makedirs(_dir, exist_ok=True)

    # -------------------- Logging settings --------------------
    config_logging(cfg.logging, out_dir=out_dir_run)
    logging.debug(f"config: {cfg}")

    # Only rank 0 talks to wandb; otherwise every rank would create (or, when
    # resuming, concurrently re-attach to) its own run and overwrite the saved
    # job id. Other ranks behave as with --no_wandb.
    if not args.no_wandb and rank == 0:
        if resume_run is not None:
            wandb_id = load_wandb_job_id(out_dir_run)
            wandb_cfg_dic = {
                "id": wandb_id,
                "resume": "must",
                **cfg.wandb,
            }
        else:
            wandb_cfg_dic = {
                "config": dict(cfg),
                "name": job_name,
                "mode": "online",
                **cfg.wandb,
            }
        wandb_cfg_dic.update({"dir": out_dir_run})
        wandb_run = init_wandb(enable=True, **wandb_cfg_dic)
        save_wandb_job_id(wandb_run, out_dir_run)
    else:
        init_wandb(enable=False)

    # Tensorboard (should be initialized after wandb)
    tb_logger.set_dir(out_dir_tb)
    log_slurm_job_id(step=0)

    # -------------------- Device --------------------
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    logging.info(f"Rank {rank} is using device {device}")

    # -------------------- Snapshot of config --------------------
    # Written once, by rank 0, to avoid concurrent writes to the same file.
    if resume_run is None and rank == 0:
        _output_path = os.path.join(out_dir_run, "config.yaml")
        with open(_output_path, "w+") as f:
            OmegaConf.save(config=cfg, f=f)
        logging.info(f"Config saved to {_output_path}")

    # -------------------- Gradient accumulation steps --------------------
    # Each optimizer step sees max_train_batch_size samples per rank per
    # micro-step, so the effective batch must be an exact multiple of that.
    eff_bs = cfg.dataloader.effective_batch_size
    per_step_bs = cfg.dataloader.max_train_batch_size * world_size
    if eff_bs < per_step_bs or eff_bs % per_step_bs != 0:
        raise ValueError(
            f"eff_bs must be a positive multiple of (max_train_batch_size * world_size) \n"
            f" but Effective batch size: {eff_bs}, max_train_batch_size: "
            f"{cfg.dataloader.max_train_batch_size}, world_size: {world_size}"
        )
    accumulation_steps = int(eff_bs // per_step_bs)
    logging.info(
        f"Effective batch size: {eff_bs}, accumulation steps: {accumulation_steps}, world_size: {world_size}"
    )

    # -------------------- Data --------------------
    loader_seed = cfg.dataloader.seed
    loader_generator = (
        None if loader_seed is None else torch.Generator().manual_seed(loader_seed)
    )

    # Training dataset
    train_dataset = get_rec_dataset(cfg_data.train)
    if "mixed" == cfg_data.train.name:
        dataset_ls = train_dataset
        assert len(cfg_data.train.prob_ls) == len(
            dataset_ls
        ), "Lengths don't match: `prob_ls` and `dataset_list`"
        if world_size > 1:
            # NOTE(review): MixedBatchSampler receives no rank/world_size here,
            # so ranks may draw identical batches under DDP — confirm.
            logging.warning(
                "MixedBatchSampler is not rank-aware; with world_size > 1 ranks may see duplicate batches"
            )
        concat_dataset = ConcatDataset(dataset_ls)
        mixed_sampler = MixedBatchSampler(
            src_dataset_ls=dataset_ls,
            batch_size_per_gpu=cfg.dataloader.max_train_batch_size,
            drop_last=True,
            prob=cfg_data.train.prob_ls,
            shuffle=True,
            generator=loader_generator,
        )
        train_loader = DataLoader(
            concat_dataset,
            batch_sampler=mixed_sampler,
            num_workers=cfg.dataloader.num_workers,
        )
    else:
        # DistributedSampler shuffles with its own seed (not the DataLoader
        # generator), so forward the configured seed to it. It is passed
        # explicit num_replicas/rank, so it also works when world_size == 1
        # and no process group exists.
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            seed=loader_seed if loader_seed is not None else 0,
        )
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=cfg.dataloader.max_train_batch_size,
            sampler=train_sampler,
            num_workers=cfg.dataloader.num_workers,
            shuffle=False,  # DistributedSampler already handles shuffling
            generator=loader_generator,
        )

    # Validation / visualization loaders are built on rank 0 only.
    if rank == 0:
        val_loaders: List[DataLoader] = []
        for _val_dic in cfg_data.val:
            _val_dataset = get_rec_dataset(_val_dic)
            _val_loader = DataLoader(
                dataset=_val_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=cfg.dataloader.num_workers,
            )
            val_loaders.append(_val_loader)

        vis_loaders: List[DataLoader] = []
        if hasattr(cfg_data, "vis"):
            for _vis_dic in cfg_data.vis:
                _vis_dataset = get_rec_dataset(_vis_dic)
                _vis_loader = DataLoader(
                    dataset=_vis_dataset,
                    batch_size=8,
                    shuffle=False,
                    num_workers=cfg.dataloader.num_workers,
                )
                vis_loaders.append(_vis_loader)
        else:
            vis_loaders = None
    else:
        val_loaders = None
        vis_loaders = None

    # -------------------- Model --------------------
    _pipeline_kwargs = (
        cfg.pipeline.kwargs
        if hasattr(cfg.pipeline, "kwargs") and cfg.pipeline.kwargs is not None
        else {}
    )
    pipeline_cls = get_pipeline_cls(cfg.pipeline.name)
    if cfg.pipeline.name != "TianmoucSingleStageReconstructionPipeline":
        raise NotImplementedError(f"Pipeline {cfg.pipeline.name} not implemented")

    logging.info(f"Initializing {cfg.pipeline.name} from {cfg.model.pretrained_path}")
    pretrained_dir = os.path.join(base_ckpt_dir, cfg.model.pretrained_path)
    scheduler = DDIMScheduler.from_pretrained(os.path.join(pretrained_dir, "scheduler"))

    # Select UNet class
    if cfg.model.name == "TianmoucRec_BRDM":
        # multi-frame bi-directional recurrent reconstruction
        unet_cls = UNet2DModelBTCHW
    elif cfg.model.name in ("TianmoucRec_Base", "TianmoucRec_SR"):
        # single-frame reconstruction
        unet_cls = UNet2DModel
    else:
        raise NotImplementedError(f"Model {cfg.model.name} not implemented")

    # Initialize UNet; fall back to config-only (randomly initialized) weights.
    unet_dir = os.path.join(pretrained_dir, "unet")
    try:
        unet = unet_cls.from_pretrained(unet_dir)
    except Exception:
        logging.warning(
            f"Initializing {cfg.pipeline.name} from pretrained {cfg.model.pretrained_path} error, only use config and init weights!!!"
        )
        unet = unet_cls.from_config(unet_dir)

    # Optionally overlay weights from a separate state-dict file. Failure is
    # deliberately non-fatal: training continues with the current weights.
    unet_pretrained_path = getattr(cfg.trainer, "unet_pretrained_path", None)
    if unet_pretrained_path is not None:
        logging.info(f"Load net parameters from {unet_pretrained_path}")
        try:
            _load_unet_weights(unet, unet_pretrained_path, device)
        except Exception as e:
            logging.warning(f"Error loading state_dict: {e}")

    model = pipeline_cls(unet, scheduler, **_pipeline_kwargs)

    # -------------------- Trainer --------------------
    # Exit time
    if args.exit_after > 0:
        t_end = t_start + timedelta(minutes=args.exit_after)
        logging.info(f"Will exit at {t_end}")
    else:
        t_end = None

    _trainer_kwargs = (
        cfg.trainer.kwargs
        if hasattr(cfg.trainer, "kwargs") and cfg.trainer.kwargs is not None
        else {}
    )
    trainer_cls = get_trainer_cls(cfg.trainer.name)
    logging.debug(f"Trainer: {trainer_cls}")
    logging.debug(f"_trainer_kwargs: {_trainer_kwargs}")
    trainer = trainer_cls(
        cfg=cfg,
        model=model,
        train_dataloader=train_loader,
        device=device,
        base_ckpt_dir=base_ckpt_dir,
        out_dir_ckpt=out_dir_ckpt,
        out_dir_eval=out_dir_eval,
        out_dir_vis=out_dir_vis,
        accumulation_steps=accumulation_steps,
        val_dataloaders=val_loaders,
        vis_dataloaders=vis_loaders,
        rank=rank,
        world_size=world_size,
        **_trainer_kwargs,
    )
    # make DDP
    if world_size > 1:
        trainer.makeDDPmodel()

    # -------------------- Checkpoint --------------------
    if resume_run is not None:
        try:
            trainer.load_checkpoint(
                resume_run, load_trainer_state=True, resume_lr_scheduler=True
            )
        except Exception:
            logging.warning(
                "WARNING optimizer state mismatch, try to use `load_trainer_state=False` ",
                exc_info=True,
            )
            trainer.load_checkpoint(
                resume_run, load_trainer_state=False, resume_lr_scheduler=True
            )

    # -------------------- Training & Evaluation Loop --------------------
    try:
        trainer.train(t_end=t_end)
    except BaseException as e:
        logging.exception(e)
        # Re-raise so the process exits non-zero; under mp.spawn this lets the
        # parent stop the remaining ranks instead of leaving them waiting on
        # this one.
        raise
    if world_size > 1:
        cleanup()
if "__main__" == __name__:
    t_start = datetime.now()
    print(f"start at {t_start}")

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(description="Train your cute model!")
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to config file.",
    )
    parser.add_argument(
        "--resume_run",
        action="store",
        default=None,
        help="Path of checkpoint to be resumed. If given, will ignore --config, and checkpoint in the config",
    )
    parser.add_argument(
        "--output_dir", type=str, default=None, help="directory to save checkpoints"
    )
    parser.add_argument("--no_cuda", action="store_true", help="Do not use cuda.")
    parser.add_argument(
        "--exit_after",
        type=int,
        default=-1,
        help="Save checkpoint and exit after X minutes.",
    )
    parser.add_argument("--no_wandb", action="store_true", help="run without wandb")
    parser.add_argument(
        "--base_ckpt_dir",
        type=str,
        default=None,
        help="directory of pretrained checkpoint, if None, use os.environ[`BASE_CKPT_DIR`]",
    )
    parser.add_argument(
        "--add_datetime_prefix",
        action="store_true",
        help="Add datetime to the output folder name, can reuse the same config file and avoid dictionary existed",
    )
    args = parser.parse_args()

    # Processes started by mp.spawn re-import this file without running this
    # block, so they cannot see the global `t_start`; ship it inside `args` so
    # every rank uses the same start time.
    args.t_start = t_start

    world_size = torch.cuda.device_count()
    # Logging is not configured yet in this (parent) process, so logging.info
    # would be dropped at the root logger's default WARNING level; use print.
    if world_size > 1:
        print(f"Total GPUs available: {world_size}")
        # Each worker joins and tears down its own process group; the parent
        # never joins one, so there is nothing to clean up here and any worker
        # failure propagates unchanged.
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size)
    else:
        print("Single GPU training")
        main_worker(0, 1, args)