-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain.py
More file actions
31 lines (23 loc) · 853 Bytes
/
train.py
File metadata and controls
31 lines (23 loc) · 853 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from omegaconf import OmegaConf
from src.trainer import VideoTokenizerTrainer
from argparse import ArgumentParser
def get_config_cli():
    """Build a configuration from a YAML file plus command-line overrides.

    The command line is parsed with ``OmegaConf.from_cli()``. Its ``config``
    entry is used as the path of the YAML file to load, and the two
    configurations are then combined with ``OmegaConf.merge``.

    Returns:
        The merged OmegaConf configuration object.
    """
    overrides = OmegaConf.from_cli()
    base = OmegaConf.load(overrides.config)
    # The CLI values are passed last to merge(); per OmegaConf's merge
    # semantics, later inputs win over earlier ones on conflicting keys.
    return OmegaConf.merge(base, overrides)
def get_config(config_path):
    """Load a configuration file with OmegaConf.

    Args:
        config_path: Location of the configuration file, forwarded
            unchanged to ``OmegaConf.load``.

    Returns:
        The configuration object produced by ``OmegaConf.load``.
    """
    return OmegaConf.load(config_path)
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            the behaviour existing callers rely on. Passing an explicit list
            lets the parser be driven programmatically.

    Returns:
        An ``argparse.Namespace`` with the string attributes
        ``model_config_path`` and ``trainer_config_path``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unrecognised argument is supplied.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--model_config_path",
        type=str,
        required=True,
        help="Path to the model configuration file.",
    )
    parser.add_argument(
        "--trainer_config_path",
        type=str,
        required=True,
        help="Path to the trainer configuration file.",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def main():
    """Run training using the two config files named on the command line.

    Parses the command-line arguments, loads the model and trainer
    configurations from the given paths, constructs a
    ``VideoTokenizerTrainer`` from them and calls its ``train()`` method.
    """
    cli = parse_args()
    # Load order (model first, then trainer) matches the original call.
    model_cfg = get_config(cli.model_config_path)
    trainer_cfg = get_config(cli.trainer_config_path)
    VideoTokenizerTrainer(
        model_config=model_cfg,
        trainer_config=trainer_cfg,
    ).train()
# Only start training when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()