-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathmain.py
More file actions
126 lines (111 loc) · 4.26 KB
/
main.py
File metadata and controls
126 lines (111 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Main script for training and testing."""
import argparse
import os
from pathlib import Path
import sys
import torch
from datasets import fetch_dataset_class
from modeling.policy import fetch_model_class
from utils.common_utils import str2bool, str_none
from utils.trainers import fetch_train_tester
def parse_arguments():
    """Build the CLI parser for main.py and parse ``sys.argv``.

    All options are declared as ``(name, type, default)`` triples and exposed
    as ``--name`` flags; ``str2bool`` / ``str_none`` (project helpers) convert
    string flags to bool / optional-str.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    # FIX: the original passed the string positionally, which sets argparse's
    # `prog` (the program name in usage output), not the description.
    parser = argparse.ArgumentParser(description="Parse arguments for main.py")
    # Tuples: (name, type, default)
    arguments = [
        # Dataset/loader arguments
        ('train_data_dir', Path, ''),
        ('eval_data_dir', Path, ''),
        ('train_instructions', Path, ''),
        ('val_instructions', Path, ''),
        ('dataset', str, "Peract"),
        ('num_workers', int, 4),
        ('batch_size', int, 64),
        ('batch_size_val', int, 64),
        ('chunk_size', int, 1),
        ('memory_limit', float, 8),  # cache limit in GB
        # Logging arguments
        ('base_log_dir', Path, Path(__file__).parent / "train_logs"),
        ('exp_log_dir', Path, "exp"),
        ('run_log_dir', Path, "run"),
        # Training and testing arguments
        ('checkpoint', str_none, None),
        ('val_freq', int, 4000),
        ('interm_ckpt_freq', int, 1000000),
        ('eval_only', str2bool, False),
        ('lr', float, 1e-4),
        ('backbone_lr', float, 1e-4),
        ('lr_scheduler', str, "constant"),
        ('wd', float, 5e-3),
        ('train_iters', int, 600000),
        ('use_compile', str2bool, False),
        ('use_ema', str2bool, False),
        ('lv2_batch_size', int, 1),
        # Model arguments: general policy type
        ('model_type', str, 'denoise3d'),
        ('bimanual', str2bool, False),
        ('keypose_only', str2bool, True),
        ('pre_tokenize', str2bool, True),
        ('custom_img_size', int, None),
        ('workspace_normalizer_buffer', float, 0.04),
        # Model arguments: encoder
        ('backbone', str, "clip"),
        ('finetune_backbone', str2bool, False),
        ('finetune_text_encoder', str2bool, False),
        ('fps_subsampling_factor', int, 5),
        # Model arguments: encoder and head
        ('embedding_dim', int, 120),  # divisible by num_attn_heads
        ('num_attn_heads', int, 8),
        ('num_vis_instr_attn_layers', int, 3),
        ('num_history', int, 1),
        # Model arguments: head
        ('num_shared_attn_layers', int, 4),
        ('relative_action', str2bool, False),
        ('rotation_format', str, 'quat_xyzw'),
        ('denoise_timesteps', int, 10),
        ('denoise_model', str, "rectified_flow")
    ]
    # Unpack by name instead of indexing arg[0]/arg[1]/arg[2].
    for name, arg_type, default in arguments:
        parser.add_argument(f'--{name}', type=arg_type, default=default)
    return parser.parse_args()
def suppress_output_on_non_main():
    """Silence stdout/stderr on every non-main process (RANK != 0)."""
    rank = int(os.environ.get("RANK", 0))
    if rank == 0:
        # Main process keeps its streams untouched.
        return
    # Workers write to the null device so only rank 0 prints; the handles
    # are intentionally left open for the lifetime of the process.
    devnull_out = open(os.devnull, "w")
    devnull_err = open(os.devnull, "w")
    sys.stdout = devnull_out
    sys.stderr = devnull_err
if __name__ == '__main__':
    # Disable HuggingFace tokenizer parallelism (avoids fork-related warnings
    # when DataLoader workers are spawned) and let the CUDA caching allocator
    # use expandable segments to reduce fragmentation.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    # Arguments
    args = parse_arguments()
    print("Arguments:")
    print(args)
    print("-" * 100)
    # Compose the run's log directory from the three CLI path components and
    # attach it to `args` so downstream code (the TrainTester) can find it.
    log_dir = args.base_log_dir / args.exp_log_dir / args.run_log_dir
    args.log_dir = log_dir
    log_dir.mkdir(exist_ok=True, parents=True)
    print("Logging:", log_dir)
    print(
        "Available devices (CUDA_VISIBLE_DEVICES):",
        os.environ.get("CUDA_VISIBLE_DEVICES")
    )
    print("Device count:", torch.cuda.device_count())
    # NOTE(review): LOCAL_RANK is read without a default, so this script
    # requires a torchrun/torch.distributed launcher — it raises KeyError if
    # run as a plain `python main.py`. Presumably intentional; confirm.
    args.local_rank = int(os.environ["LOCAL_RANK"])
    # Mute stdout/stderr on non-main ranks so logs are printed once.
    suppress_output_on_non_main()
    # DDP initialization: bind this process to its GPU *before* joining the
    # NCCL process group (env:// reads MASTER_ADDR/MASTER_PORT etc.).
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    # Favor throughput over reproducibility: benchmark kernel selection on,
    # determinism off, TF32 enabled for matmul and cuDNN convolutions.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Select dataset and model classes from their registries by CLI name.
    dataset_class = fetch_dataset_class(args.dataset)
    model_class = fetch_model_class(args.model_type)
    # Run: the TrainTester owns the full train/eval loop.
    TrainTester = fetch_train_tester(args.dataset)
    train_tester = TrainTester(args, dataset_class, model_class)
    train_tester.main()
    # Safe program termination: tear down the process group (and free cached
    # CUDA memory) so workers exit cleanly instead of hanging on NCCL.
    if torch.distributed.is_initialized():
        torch.cuda.empty_cache()
        torch.distributed.destroy_process_group()