diff --git a/src/adverse_weather_classification/weather_classification.py b/src/adverse_weather_classification/weather_classification.py index 777ddfe..9db16c9 100644 --- a/src/adverse_weather_classification/weather_classification.py +++ b/src/adverse_weather_classification/weather_classification.py @@ -8,50 +8,48 @@ class AdverseWeatherClassifier: - def __init__(self, model_path, model_input_size: Tuple[int, int] = (256, 256)) -> None: + def __init__(self, model_path: str, model_input_size: Tuple[int, int] = (256, 256)): self.model = None self.model_path = model_path self.model_input_size = model_input_size self.class_labels = ['day', 'night'] - def load(self): + def load_model(self): start_time = time.time() self.model = keras.models.load_model(self.model_path) - print("Model loaded in {} seconds.".format(time.time() - start_time)) + print("Model loaded in {:.2f} seconds.".format(time.time() - start_time)) def __preprocess_frame__(self, frame: np.ndarray) -> np.ndarray: - frame = cv2.resize(frame, self.model_input_size) - frame = frame.astype('float32') / 255.0 + frame = cv2.resize(frame, self.model_input_size).astype('float32') / 255.0 frame = np.expand_dims(frame, axis=0) return frame def predict(self, preprocessed_frame: np.ndarray) -> str: - if self.model is not None: - predictions = self.model.predict(preprocessed_frame) - else: - raise ValueError("Model is not loaded. Please call load() method first.") + if self.model is None: + raise ValueError("Model is not loaded. Please call load_model() method first.") + predictions = self.model.predict(preprocessed_frame) predicted_class = self.class_labels[np.argmax(predictions)] return predicted_class - def exec(self, frame: np.ndarray) -> str: - preprocessed_frame = self.__preprocess_frame__(frame) + def execute(self, frame: np.ndarray) -> str: + preprocessed_frame = self.preprocess_frame(frame) predicted_class = self.predict(preprocessed_frame) return predicted_class if __name__ == "__main__": - img_dir = "/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/root_dir/testing_imgs" - model_path_ = "/src/adverse_weather_classification/output/checkpoints/best_model.h5" - adverse_weather_classifier = AdverseWeatherClassifier(model_path_) - adverse_weather_classifier.load() + img_dir = "/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/testing_imgs" + model_path = "/home/ahv/PycharmProjects/Visual-Inertial-Odometry/simulation/CARLA/output/checkpoint_directory" + adverse_weather_classifier = AdverseWeatherClassifier(model_path) + adverse_weather_classifier.load_model() for root, dirs, files in os.walk(img_dir): - # shuffle the files + # Shuffle the files np.random.shuffle(files) for file in files: img = cv2.imread(os.path.join(root, file)) visualization_img = img.copy() - predicted_class_ = adverse_weather_classifier.exec(img) - cv2.putText(visualization_img, predicted_class_, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + predicted_class = adverse_weather_classifier.execute(img) + cv2.putText(visualization_img, predicted_class, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) cv2.imshow("Image", visualization_img) cv2.waitKey(0) cv2.destroyAllWindows()