This project implements a full-stack Machine Learning application for classifying images of cats and dogs. The solution leverages state-of-the-art architectures (Vision Transformers) for high-performance inference and a custom Convolutional Neural Network (CNN) trained with PyTorch Lightning to demonstrate deep learning engineering capabilities. The system is containerized with Docker and exposes a FastAPI backend consumed by a user-friendly Streamlit frontend.
Two distinct modeling approaches were implemented to satisfy project requirements:

1. **Pre-trained Model (Vision Transformer)**
   - **Architecture:** `google/vit-base-patch16-224` from Hugging Face.
   - **Rationale:** ViTs capture global context better than traditional CNNs through their self-attention mechanism. Using a pre-trained model delivers production-grade accuracy (likely >99% on this task) out of the box with zero training cost.
   - **Implementation:** Wrapped in `HuggingFaceViT` for easy inference.
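A minimal sketch of this inference path, assuming the `transformers` library; the repo's actual `HuggingFaceViT` wrapper is not shown in this document, so the function names here are illustrative:

```python
# Sketch of ViT inference via Hugging Face transformers; names and the
# exact wrapper API are assumptions, not the repo's actual code.
import torch
from PIL import Image

def load_vit(model_name: str = "google/vit-base-patch16-224"):
    """Load the pre-trained ViT and its image processor (downloads weights)."""
    from transformers import ViTForImageClassification, ViTImageProcessor
    model = ViTForImageClassification.from_pretrained(model_name)
    model.eval()
    return model, ViTImageProcessor.from_pretrained(model_name)

@torch.no_grad()
def classify(model, processor, image: Image.Image) -> str:
    """Run one image through the model and map the argmax logit to a label."""
    inputs = processor(images=image, return_tensors="pt")
    logits = model(**inputs).logits
    return model.config.id2label[logits.argmax(-1).item()]
```

Keeping `classify` separate from model loading makes it trivial to unit-test with stub models and to reuse the same code path behind the API.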
2. **Custom Model (CNN)**
   - **Framework:** PyTorch & PyTorch Lightning.
   - **Architecture:** A ResNet-like network (`SimpleCNN`) defined in `cnn_model.py`. It uses three convolutional blocks with MaxPool and ReLU, followed by a fully connected classifier.
   - **Training Loop:** Managed by `CatDogClassifier` (a `LightningModule`). It handles the training steps, validation logic, logging (loss/accuracy), and optimization (AdamW + cosine annealing).
   - **Data Pipeline:** A custom `CatDogDataModule` handles data loading. It processes a Pandas DataFrame containing image paths and labels, performs the Train/Val/Test split, and applies preprocessing transforms (resize to $256 \times 256$, ImageNet normalization).
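The architecture and optimizer setup above can be sketched as follows; channel counts and the helper names are assumptions, and the real `cnn_model.py` may differ in details:

```python
# Illustrative sketch of SimpleCNN and the AdamW + cosine annealing setup;
# layer sizes here are assumptions, not the repo's exact configuration.
import torch
import torch.nn as nn

def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """One block: Conv -> ReLU -> MaxPool (halves spatial resolution)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )

class SimpleCNN(nn.Module):
    """Three conv blocks followed by a fully connected classifier."""
    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.features = nn.Sequential(
            conv_block(3, 16), conv_block(16, 32), conv_block(32, 64)
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # pool to 1x1 so any input size works
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

def configure_optimizers(model: nn.Module, lr: float = 1e-3, t_max: int = 50):
    """AdamW + cosine annealing, mirroring the training setup described above."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    return optimizer, scheduler
```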
The API is designed for scalability and modularity:
- **Framework:** FastAPI was chosen for its high performance (built on Starlette) and automatic documentation (Swagger UI).
- **Endpoints:**
  - `/predict/pretrained`: routes the request to the ViT model.
  - `/predict/custom`: routes the request to the custom CNN.
- **Design Choice:** Models are loaded into the app state on startup (via the `lifespan` handler) to avoid reload latency on every request.
A functional UI allows users to upload images and select their preferred model. This decouples the ML logic from user interaction, enabling easy testing and usage.
- **Dockerisation:** A multi-stage `Dockerfile` is provided. It installs the project dependencies from `pyproject.toml` using `uv` (a fast pip replacement) or standard pip, ensuring a reproducible environment.
- **Project Structure:** Refactored from a complex template to a clean, domain-centric structure:

```
src/msensis_mle/
├── data/          # DataModules
├── models/        # Torch & Lightning definitions
deployment/
├── api/           # FastAPI app
├── docker/        # Container config
notebooks/         # EDA & Training experiments
```
We implemented standard normalization (Z-score with ImageNet statistics):

- **Standardization (Z-score)** is generally preferred over Min-Max scaling for deep learning because it centers the input data around zero, which helps prevent vanishing/exploding gradients in the initial layers and speeds up convergence.
- **ImageNet Stats:** Since we use a ViT pre-trained on ImageNet, and the custom CNN often benefits from these statistics even when trained from scratch (transfer-learning potential), sticking to the industry-standard ImageNet statistics is the robust choice.
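Concretely, the Z-score step with ImageNet statistics amounts to the following (a minimal sketch; the repo's DataModule presumably uses torchvision's `Normalize` for the same computation):

```python
# Z-score normalization with ImageNet channel statistics.
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])

def normalize(img: torch.Tensor) -> torch.Tensor:
    """Standardize a (3, H, W) float image in [0, 1]: (x - mean) / std per channel."""
    return (img - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]
```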
- **Type Hinting:** All logic uses strict Python type hints (`str`, `pd.DataFrame`, etc.) for better maintainability.
- **Modular Design:** Separating `cnn_model.py` (pure PyTorch) from `lightning_module.py` (the Lightning wrapper) allows the model to be used independently of the training framework if needed.
- **Data Augmentation:** Currently only resizing is applied. Adding `RandomHorizontalFlip` or `ColorJitter` would improve the custom CNN's robustness.
- **Hyperparameter Tuning:** We use a static initial learning rate ($10^{-3}$). Automated hyperparameter search could optimize this choice.
- **TensorRT/ONNX:** For production deployment, exporting the models to ONNX Runtime or TensorRT would significantly reduce inference latency.