Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 38 additions & 58 deletions src/adverse_weather_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from mock import Mock
import matplotlib.pyplot as plt


class TrainHyperParameters:
def __init__(self, input_shape: Tuple[int, int, int] = (256, 256, 3), number_of_classes: int = 2,
learning_rate: float = 0.001, batch_size: int = 32, number_of_epochs: int = 3) -> None:
Expand All @@ -27,7 +26,6 @@ def __init__(self, data_dir: str, checkpoint_dir: str = 'output/checkpoints') ->
super().__init__()
self.model = None

# Set the seed for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

Expand All @@ -38,7 +36,6 @@ def form_data_generator(self) -> Tuple[ImageDataGenerator, ImageDataGenerator]:
train_dir = os.path.join(self.data_dir, 'train')
test_dir = os.path.join(self.data_dir, 'test')

# Define the data generators for training and validation
train_datagen = ImageDataGenerator(rescale=1. / 255)
test_datagen = ImageDataGenerator(rescale=1. / 255)

Expand Down Expand Up @@ -71,63 +68,46 @@ def model_builder(self):
])

def train(self, train_generator, test_generator):
# Define the optimizer and loss function
optimizer = keras.optimizers.Adam(lr=self.hyperparameters.learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy()

# Compile the model
if self.model is not None:
self.model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])
else:
raise ValueError('Model is not defined. Please call model_builder() first.')

# Set up a checkpoint to save the best model weights
if os.path.exists(self.checkpoint_dir) is False:
os.makedirs(self.checkpoint_dir)
checkpoint_path = os.path.join(self.checkpoint_dir, 'best_model.h5')

checkpoint_cb = ModelCheckpoint(checkpoint_path, monitor='val_accuracy', verbose=1, save_best_only=True,
mode='max')

# Train the model
history = self.model.fit(train_generator,
epochs=self.hyperparameters.number_of_epochs,
validation_data=test_generator,
callbacks=[checkpoint_cb])

# Save the model architecture
model_dir = os.path.join(self.checkpoint_dir, 'model')
if os.path.exists(model_dir) is False:
os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'model.json')
model_json = self.model.to_json()
with open(model_path, 'w') as json_file:
json_file.write(model_json)
# plot loss and accuracy on train and validation set
self.plot_history(history)

def plot_history(self, history):
matplotlib.use('Agg')
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.title('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='train')
plt.plot(history.history['val_accuracy'], label='test')
plt.title('Accuracy')
plt.legend()
plt.savefig(os.path.join(self.checkpoint_dir, 'loss_accuracy.png'), dpi=300)

def exec(self):
train_generator, test_generator = self.form_data_generator()
self.model_builder()
self.train(train_generator, test_generator)


if __name__ == '__main__':
data_dir_ = '/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/root_dir'
train_custom_cnn = TrainCustomCNN(data_dir_)
train_custom_cnn.exec()
raise ValueError("Model is not built.")

checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(self.checkpoint_dir, "model_checkpoint"),
monitor='val_loss',
save_best_only=True,
save_weights_only=False,
verbose=1)
history = self.model.fit(
train_generator,
epochs=self.hyperparameters.number_of_epochs,
validation_data=test_generator,
callbacks=[checkpoint_callback])

return history


if __name__ == "__main__":

data_directory = '/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/data_directory'
checkpoint_directory = '/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/checkpoint_directory"

trainer = TrainCustomCNN(data_dir=data_directory, checkpoint_dir=checkpoint_directory)

trainer.model_builder()

train_generator, test_generator = trainer.form_data_generator()

history = trainer.train(train_generator, test_generator)

# Plot the training and validation accuracy over epochs
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()