MSENSIS MLE Task - Solution Report

1. Executive Summary

This project implements a full-stack Machine Learning application for classifying images of cats and dogs. The solution leverages state-of-the-art architectures (Vision Transformers) for high-performance inference and a custom Convolutional Neural Network (CNN) trained with PyTorch Lightning to demonstrate deep learning engineering capabilities. The system is containerized with Docker and exposes a FastAPI backend consumed by a user-friendly Streamlit frontend.

2. Technical Architecture

2.1 Model Development

Two distinct modeling approaches were implemented to satisfy project requirements:

  • Pre-trained Model (Vision Transformer):

    • Architecture: google/vit-base-patch16-224 from Hugging Face.
    • Rationale: ViTs use self-attention to capture global context more effectively than traditional CNNs. Using a pre-trained model delivers production-grade accuracy (likely >99% on this task) out of the box, with zero training cost.
    • Implementation: Wrapped in a HuggingFaceViT class for easy inference (a sketch follows this list).
  • Custom Model (CNN):

    • Framework: PyTorch & PyTorch Lightning.
    • Architecture: A ResNet-like architecture (SimpleCNN) defined in cnn_model.py. It uses 3 convolutional blocks with MaxPool and ReLU, followed by a fully connected classifier.
    • Training Loop: Managed by CatDogClassifier (a LightningModule). It handles the training steps, validation logic, logging (loss/accuracy), and optimization (AdamW + Cosine Annealing); a sketch of this module follows the list.
    • Data Pipeline: A custom CatDogDataModule handles data loading. It processes a Pandas DataFrame of image paths and labels, splits the data into train/validation/test sets, and applies the preprocessing transforms (Resize to $256 \times 256$, ImageNet Normalization).
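As referenced in the list above, a minimal sketch of how the HuggingFaceViT wrapper might perform inference with the Hugging Face transformers library; the class layout and the predict method are illustrative assumptions, not the project's exact code:

```python
# Hypothetical sketch of the pre-trained ViT wrapper; the real
# HuggingFaceViT implementation may differ in detail.
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor


class HuggingFaceViT:
    def __init__(self, model_name: str = "google/vit-base-patch16-224") -> None:
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.model = ViTForImageClassification.from_pretrained(model_name)
        self.model.eval()

    @torch.no_grad()
    def predict(self, image: Image.Image) -> str:
        # Preprocess to the 224x224 input the checkpoint expects.
        inputs = self.processor(images=image, return_tensors="pt")
        logits = self.model(**inputs).logits
        label_id = int(logits.argmax(dim=-1))
        # The checkpoint predicts ImageNet classes; a cat/dog decision
        # would map the predicted class down to the two labels.
        return self.model.config.id2label[label_id]
```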
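Likewise, a minimal sketch of the CatDogClassifier LightningModule, assuming it wraps the SimpleCNN defined in cnn_model.py; hyperparameter names and default values here are illustrative assumptions:

```python
# Illustrative sketch of the Lightning training wrapper; defaults are assumed.
import pytorch_lightning as pl
import torch
from torch import nn


class CatDogClassifier(pl.LightningModule):
    def __init__(self, model: nn.Module, lr: float = 1e-3, max_epochs: int = 10) -> None:
        super().__init__()
        self.model = model
        self.lr = lr
        self.max_epochs = max_epochs
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.model(images)
        loss = self.criterion(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()
        # Log loss/accuracy so Lightning surfaces them during training.
        self.log_dict({"train_loss": loss, "train_acc": acc}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.model(images)
        self.log("val_loss", self.criterion(logits, labels), prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
```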

2.2 Backend Engineering (FastAPI)

The API is designed for scalability and modularity:

  • Framework: FastAPI was chosen for its high performance (Starlette) and automatic documentation (Swagger UI).
  • Endpoints:
    • /predict/pretrained: routes requests to the ViT model.
    • /predict/custom: routes requests to the custom CNN.
  • Design Choice: Models are loaded into the app state on startup (via a lifespan handler) so they are not reloaded on every request; a sketch of this pattern follows.
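A minimal sketch of the startup pattern described above, assuming the pre-trained model is served through the transformers pipeline API; the custom-CNN endpoint would mirror the same shape, and the response format shown is an assumption:

```python
# Illustrative sketch of loading a model once at startup via the lifespan handler.
from contextlib import asynccontextmanager
from io import BytesIO

from fastapi import FastAPI, UploadFile
from PIL import Image
from transformers import pipeline


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once so requests never pay the loading cost.
    app.state.vit = pipeline("image-classification", model="google/vit-base-patch16-224")
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/predict/pretrained")
async def predict_pretrained(file: UploadFile):
    image = Image.open(BytesIO(await file.read())).convert("RGB")
    preds = app.state.vit(image)
    # Return the top prediction; the real endpoint maps it to cat/dog.
    return {"label": preds[0]["label"], "score": preds[0]["score"]}
```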

2.3 Frontend (Streamlit)

A functional UI allows users to upload an image and select their preferred model (a minimal sketch follows). This decouples the ML logic from user interaction, making the system easy to test and use.
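A minimal sketch of this flow, assuming the API from section 2.2 is reachable at localhost:8000; the endpoint naming matches the routes above, while the response shape is an assumption:

```python
# Illustrative Streamlit front end calling the FastAPI backend.
import requests
import streamlit as st

st.title("Cat vs. Dog Classifier")

model_choice = st.radio("Model", ["pretrained", "custom"])
uploaded = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded is not None and st.button("Predict"):
    response = requests.post(
        f"http://localhost:8000/predict/{model_choice}",
        files={"file": (uploaded.name, uploaded.getvalue(), uploaded.type)},
        timeout=30,
    )
    st.image(uploaded)
    st.json(response.json())
```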

2.4 DevOps & Deployment

  • Dockerisation: A multi-stage Dockerfile is provided. It installs the project dependencies declared in pyproject.toml using uv (a fast pip replacement) or standard pip, ensuring a reproducible environment.
  • Project Structure: Refactored from a complex template to a clean, domain-centric structure:
    src/msensis_mle/
    ├── data/       # DataModules
    ├── models/     # Torch & Lightning definitions
    deployment/
    ├── api/        # FastAPI app
    ├── docker/     # Container config
    notebooks/      # EDA & Training experiments
    

3. Methodological Notes

3.1 Data Normalization

We implemented standard Normalization ($\text{mean}=[0.485, 0.456, 0.406]$, $\text{std}=[0.229, 0.224, 0.225]$). Decision rationale:

  • Standardization (Z-score) is generally preferred over MinMax scaling for deep learning because it centers the inputs around zero, which helps mitigate vanishing/exploding gradients in the initial layers and speeds up convergence.
  • ImageNet Stats: The ViT was pre-trained on ImageNet, and the custom CNN also benefits from these statistics (even when trained from scratch, they keep later transfer learning straightforward), so sticking to the industry-standard ImageNet values is the robust choice; a preprocessing sketch follows.
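A sketch of the resulting preprocessing pipeline with torchvision, using the values above; the variable names are illustrative:

```python
# Resize + z-score normalization with ImageNet statistics.
from torchvision import transforms

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),                      # fixed input size for the custom CNN
    transforms.ToTensor(),                              # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),  # center around zero per channel
])
```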

3.2 Code Quality & Refactoring

  • Type Hinting: All logic uses strict Python type hints (str, pd.DataFrame, etc.) for better maintainability.
  • Modular Design: The separation of cnn_model.py (pure PyTorch) and lightning_module.py (Wrapper) allows the model to be used independently of the training framework if needed.

4. Future Improvements

  • Data Augmentation: Currently, no augmentation beyond resizing is applied. Adding RandomHorizontalFlip or ColorJitter would improve the custom CNN's robustness.
  • Hyperparameter Tuning: The learning rate is fixed at $1e-3$. Automated hyperparameter search could find a better value.
  • TensorRT/ONNX: For a production deployment, exporting the models to ONNX Runtime or TensorRT would significantly reduce inference latency; a minimal export sketch follows.
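A minimal sketch of such an ONNX export for the custom CNN; the import path and constructor signature are assumptions based on the project structure described in section 2.4:

```python
# Illustrative export of the custom CNN to ONNX; names are assumed.
import torch

from msensis_mle.models.cnn_model import SimpleCNN  # assumed import path

model = SimpleCNN(num_classes=2)  # assumed constructor signature
model.eval()

dummy_input = torch.randn(1, 3, 256, 256)  # matches the 256x256 preprocessing
torch.onnx.export(
    model,
    dummy_input,
    "cat_dog_cnn.onnx",
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}},
)
```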