From b49305a8785d55831787c8dac2e446ccb6ce4fad Mon Sep 17 00:00:00 2001 From: PooriyaSanaie <88898162+ELECEngineer@users.noreply.github.com> Date: Tue, 20 Jun 2023 23:51:15 +0330 Subject: [PATCH] A3_PooriyaSanaie_Fatemeh Aghabarari_Seyed Javad Hosseini --- src/adverse_weather_classification/train.py | 96 ++++++++------------- 1 file changed, 38 insertions(+), 58 deletions(-) diff --git a/src/adverse_weather_classification/train.py b/src/adverse_weather_classification/train.py index 8822b6b..26660ba 100644 --- a/src/adverse_weather_classification/train.py +++ b/src/adverse_weather_classification/train.py @@ -10,7 +10,6 @@ from mock import Mock import matplotlib.pyplot as plt - class TrainHyperParameters: def __init__(self, input_shape: Tuple[int, int, int] = (256, 256, 3), number_of_classes: int = 2, learning_rate: float = 0.001, batch_size: int = 32, number_of_epochs: int = 3) -> None: @@ -27,7 +26,6 @@ def __init__(self, data_dir: str, checkpoint_dir: str = 'output/checkpoints') -> super().__init__() self.model = None - # Set the seed for reproducibility np.random.seed(42) tf.random.set_seed(42) @@ -38,7 +36,6 @@ def form_data_generator(self) -> Tuple[ImageDataGenerator, ImageDataGenerator]: train_dir = os.path.join(self.data_dir, 'train') test_dir = os.path.join(self.data_dir, 'test') - # Define the data generators for training and validation train_datagen = ImageDataGenerator(rescale=1. / 255) test_datagen = ImageDataGenerator(rescale=1. / 255) @@ -71,63 +68,46 @@ def model_builder(self): ]) def train(self, train_generator, test_generator): - # Define the optimizer and loss function optimizer = keras.optimizers.Adam(lr=self.hyperparameters.learning_rate) loss_fn = keras.losses.CategoricalCrossentropy() - - # Compile the model if self.model is not None: self.model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy']) else: - raise ValueError('Model is not defined. Please call model_builder() first.') - - # Set up a checkpoint to save the best model weights - if os.path.exists(self.checkpoint_dir) is False: - os.makedirs(self.checkpoint_dir) - checkpoint_path = os.path.join(self.checkpoint_dir, 'best_model.h5') - - checkpoint_cb = ModelCheckpoint(checkpoint_path, monitor='val_accuracy', verbose=1, save_best_only=True, - mode='max') - - # Train the model - history = self.model.fit(train_generator, - epochs=self.hyperparameters.number_of_epochs, - validation_data=test_generator, - callbacks=[checkpoint_cb]) - - # Save the model architecture - model_dir = os.path.join(self.checkpoint_dir, 'model') - if os.path.exists(model_dir) is False: - os.makedirs(model_dir) - model_path = os.path.join(model_dir, 'model.json') - model_json = self.model.to_json() - with open(model_path, 'w') as json_file: - json_file.write(model_json) - # plot loss and accuracy on train and validation set - self.plot_history(history) - - def plot_history(self, history): - matplotlib.use('Agg') - plt.figure(figsize=(10, 5)) - plt.subplot(1, 2, 1) - plt.plot(history.history['loss'], label='train') - plt.plot(history.history['val_loss'], label='test') - plt.title('Loss') - plt.legend() - plt.subplot(1, 2, 2) - plt.plot(history.history['accuracy'], label='train') - plt.plot(history.history['val_accuracy'], label='test') - plt.title('Accuracy') - plt.legend() - plt.savefig(os.path.join(self.checkpoint_dir, 'loss_accuracy.png'), dpi=300) - - def exec(self): - train_generator, test_generator = self.form_data_generator() - self.model_builder() - self.train(train_generator, test_generator) - - -if __name__ == '__main__': - data_dir_ = '/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/root_dir' - train_custom_cnn = TrainCustomCNN(data_dir_) - train_custom_cnn.exec() + raise ValueError("Model is not built.") + + checkpoint_callback = ModelCheckpoint( + filepath=os.path.join(self.checkpoint_dir, "model_checkpoint"), + monitor='val_loss', + save_best_only=True, + save_weights_only=False, + verbose=1) + history = self.model.fit( + train_generator, + epochs=self.hyperparameters.number_of_epochs, + validation_data=test_generator, + callbacks=[checkpoint_callback]) + + return history + + +if __name__ == "__main__": + + data_directory = '/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/data_directory' + checkpoint_directory = '/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/checkpoint_directory" + + trainer = TrainCustomCNN(data_dir=data_directory, checkpoint_dir=checkpoint_directory) + + trainer.model_builder() + + train_generator, test_generator = trainer.form_data_generator() + + history = trainer.train(train_generator, test_generator) + + # Plot the training and validation accuracy over epochs + plt.plot(history.history['accuracy']) + plt.plot(history.history['val_accuracy']) + plt.title('Model Accuracy') + plt.xlabel('Epoch') + plt.ylabel('Accuracy') + plt.legend(['Train', 'Validation'], loc='upper left') + plt.show()