-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
119 lines (101 loc) · 3.41 KB
/
train.py
File metadata and controls
119 lines (101 loc) · 3.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import json
import os
from ast import literal_eval
from collections import defaultdict
from keras.callbacks import TensorBoard
from keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
import model as SegModel
from generator import CarlaBatchGenerator
from utils import load_dataset
from utils import BASE_DIR
def train(**kwargs):
    """Train a segmentation model using the settings passed as keyword args.

    Required keys:
        model_name: stem of the saved checkpoint file.
        input_images, target_images: directories handed to ``load_dataset``
            and ``CarlaBatchGenerator``.
        input_shape: dict with ``height``, ``width`` and ``channels``.
        num_classes: number of output classes for the model.
        image_size: dict with ``width`` and ``height``.
        encoding: target encoding passed through to the generators.

    Optional keys (default):
        model_dir (``BASE_DIR/models``), architecture ('fpn'),
        backbone ('vgg16'), freeze_encoder (False), learning_rate (1e-5),
        weights (None; when set it is joined onto ``model_dir``),
        train_val_ratio (0.1; fraction of the dataset held out for
        validation), random_state (99; seed for the train/validation split),
        batch_size (16), epochs (10), workers (1), multiprocessing (False).

    The checkpoint is written to ``<model_dir>/<model_name>.h5`` with
    ``save_best_only=True``. ``model_dir`` is created if it does not exist.

    Raises:
        KeyError: if a required key is missing.
    """
    model_name = kwargs['model_name']
    model_dir = kwargs.get('model_dir', os.path.join(BASE_DIR, 'models'))
    input_images = kwargs['input_images']
    target_images = kwargs['target_images']
    architecture = kwargs.get('architecture', 'fpn')
    backbone = kwargs.get('backbone', 'vgg16')
    freeze_encoder = kwargs.get('freeze_encoder', False)
    input_shape = (
        kwargs['input_shape']['height'],
        kwargs['input_shape']['width'],
        kwargs['input_shape']['channels'],
    )
    learning_rate = kwargs.get('learning_rate', 1e-5)
    weights = kwargs.get('weights', None)
    if weights:
        weights = os.path.join(model_dir, weights)
    train_val_ratio = kwargs.get('train_val_ratio', 0.1)
    # Kept at 99 by default so existing configs reproduce the same split.
    random_state = kwargs.get('random_state', 99)
    batch_size = kwargs.get('batch_size', 16)
    epochs = kwargs.get('epochs', 10)
    num_classes = kwargs['num_classes']
    # NOTE(review): packed as (width, height), unlike input_shape which is
    # (height, width, channels); presumably the order CarlaBatchGenerator
    # resizes with — confirm.
    image_size = (kwargs['image_size']['width'], kwargs['image_size']['height'])
    encoding = kwargs['encoding']
    workers = kwargs.get('workers', 1)
    multiprocessing = kwargs.get('multiprocessing', False)

    # Make sure the checkpoint directory exists before training starts, so a
    # missing directory cannot make the checkpoint save fail mid-run.
    os.makedirs(model_dir, exist_ok=True)

    dataset = load_dataset(input_images)
    train_set, val_set = train_test_split(
        dataset, test_size=train_val_ratio, random_state=random_state)

    # Both generators share every setting except the subset they iterate.
    generator_kwargs = dict(
        input_dir=input_images,
        target_dir=target_images,
        batch_size=batch_size,
        image_size=image_size,
        encoding=encoding,
    )
    train_gen = CarlaBatchGenerator(dataset=train_set, **generator_kwargs)
    val_gen = CarlaBatchGenerator(dataset=val_set, **generator_kwargs)

    model = SegModel.build(
        architecture=architecture,
        backbone=backbone,
        weights=weights,
        freeze_encoder=freeze_encoder,
        input_shape=input_shape,
        num_classes=num_classes,
        learning_rate=learning_rate,
    )

    filepath = os.path.join(model_dir, f"{model_name}.h5")
    model.fit_generator(
        generator=train_gen,
        validation_data=val_gen,
        epochs=epochs,
        workers=workers,
        use_multiprocessing=multiprocessing,
        shuffle=True,
        verbose=1,
        callbacks=[
            # NOTE(review): standalone Keras documents an integer update_freq
            # as a sample count, so this presumably logs every ~20 batches;
            # verify for the installed Keras version.
            TensorBoard(
                batch_size=batch_size,
                update_freq=20 * batch_size,
            ),
            # Only overwrite the checkpoint when the callback's monitored
            # metric improves.
            ModelCheckpoint(
                filepath=filepath,
                save_best_only=True,
            ),
        ],
    )
def _default_class_encoding():
    """Return the fallback encoding for keys missing from the config.

    Defined at module level (rather than as a lambda) so that a defaultdict
    using it as ``default_factory`` can be pickled.
    """
    return [1, 0, 0]


def build_carla_encoding(raw_encoding):
    """Build the encoding lookup table from the config's ``carla_encoding``.

    Args:
        raw_encoding: mapping whose keys are strings holding Python literals
            (parsed with ``ast.literal_eval``; presumably pixel colour tuples —
            confirm against the config) and whose values are stored as-is.

    Returns:
        A ``defaultdict`` in which any key absent from ``raw_encoding``
        yields (and, as with any defaultdict, inserts) a fresh ``[1, 0, 0]``.
    """
    encoding = defaultdict(_default_class_encoding)
    for key, value in raw_encoding.items():
        encoding[literal_eval(key)] = value
    return encoding


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description='Road segmentation module training routine')
    arg_parser.add_argument(
        '-c',
        '--conf',
        help='Path to the configuration file',
        default='config.json',
    )
    args = arg_parser.parse_args()
    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(args.conf, 'r', encoding='utf-8') as f:
        config = json.load(f)
    config['training']['encoding'] = build_carla_encoding(
        config['common']['carla_encoding'])
    train(**config['training'])