
BackgroundSegmentationUNet

Real-time background removal for video calls and images using deep learning. This project provides fast and accurate alpha matte extraction with U-Net and MobileNet-based architectures.

The aim was for the network to be lightweight enough to run in real-time on most mobile phones.


Quick Install

This project uses uv for fast dependency management. Python 3.13+ is required.

uv sync --all-groups

Project Structure

BackgroundSegmentationUNet/
├── pyproject.toml
├── src/
│   ├── main.ipynb           # Training pipeline
│   ├── inference.ipynb      # ONNX inference & live webcam demo
│   ├── loss/
│   │   └── loss.py          # CombinedLoss (BCE + IoU + L1)
│   └── models/
│       ├── unet.py          # Vanilla U-Net (31M params)
│       ├── mobilenetv2.py   # MobileNetV2 encoder + decoder (3.5M params)
│       ├── mobilenetv4.py   # MobileNetV4 encoder + decoder
│       └── autoencoder.py   # Restricted encoder-decoder (ONNX-friendly)
└── input/                   # Dataset root (not tracked)

High-Level Pipeline

flowchart LR
    A[RGB Image] --> B[Augmentation<br/>Albumentations]
    B --> C[Encoder]
    C --> D[Decoder]
    D --> E[Sigmoid]
    E --> F[Alpha Matte]
    F --> G[Composite<br/>over background]

    style A fill:#4a90d9,color:#fff
    style F fill:#7cb342,color:#fff
    style G fill:#7cb342,color:#fff

Model Architectures

U-Net (31M params) — Best Accuracy

Classic encoder-decoder with skip connections. Each encoder stage halves spatial resolution while doubling channels; the decoder reverses this while fusing skip features.

flowchart TD
    Input["Input (3 x 256 x 256)"] --> Inc["DoubleConv → 64"]

    Inc --> D1["Down1: MaxPool → 128"]
    D1 --> D2["Down2: MaxPool → 256"]
    D2 --> D3["Down3: MaxPool → 512"]
    D3 --> D4["Down4: MaxPool → 512<br/>(bottleneck)"]

    D4 --> U1["Up1: Upsample → 256"]
    U1 --> U2["Up2: Upsample → 128"]
    U2 --> U3["Up3: Upsample → 64"]
    U3 --> U4["Up4: Upsample → 64"]
    U4 --> Out["Conv 1x1 → 1 channel<br/>(logits)"]

    D3 -- "skip" --> U1
    D2 -- "skip" --> U2
    D1 -- "skip" --> U3
    Inc -- "skip" --> U4

    style D4 fill:#e57373,color:#fff
    style Out fill:#7cb342,color:#fff

DoubleConv block — the fundamental building unit, used in every stage:

Conv2d(3x3) → BatchNorm → ReLU → Conv2d(3x3) → BatchNorm → ReLU

Key design choices:

  • Bilinear upsampling (default) instead of transposed convolutions — avoids checkerboard artifacts
  • Skip connections concatenate encoder features to preserve fine spatial detail
  • No bias in convolutions (BatchNorm absorbs the bias term)
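Putting these choices together, the block can be sketched in PyTorch roughly as follows (a minimal sketch; the actual `src/models/unet.py` may differ in details):

```python
import torch
from torch import nn

class DoubleConv(nn.Module):
    """Conv(3x3) -> BN -> ReLU, applied twice.

    bias=False on the convolutions because BatchNorm absorbs the bias term.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
```

With `padding=1` the 3x3 convolutions preserve spatial size, so only the MaxPool/Upsample steps change resolution.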

MobileNetV2-UNet (3.5M params) — Best Speed

Replaces the U-Net encoder with a pretrained MobileNetV2 backbone. The decoder is a lightweight stack of transposed convolutions.

flowchart TD
    Input["Input (3 x 256 x 256)"] --> Enc

    subgraph Enc["MobileNetV2 Encoder (pretrained, ImageNet)"]
        direction TB
        E1["Depthwise Separable Convs"] --> E2["Inverted Residual Blocks"]
        E2 --> E3["Output: 1280 x 8 x 8"]
    end

    Enc --> Dec

    subgraph Dec["Custom Decoder"]
        direction TB
        U1["ConvTranspose 1280 → 96"] --> U2["ConvTranspose 96 → 32"]
        U2 --> U3["ConvTranspose 32 → 24"]
        U3 --> U4["ConvTranspose 24 → 16"]
    end

    Dec --> Final["Conv 1x1 → 1 channel"]
    Final --> Interp["Bilinear interpolate<br/>to input size"]

    style Enc fill:#4a90d9,color:#fff
    style Dec fill:#ff9800,color:#fff
    style Interp fill:#7cb342,color:#fff

Each decoder block: ConvTranspose2d(2x) → BN → ReLU → Conv2d → BN → ReLU
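A minimal PyTorch sketch of one such decoder block (channel sizes follow the diagram; the exact kernel sizes are assumptions):

```python
import torch
from torch import nn

def up_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # ConvTranspose2d with stride 2 doubles spatial resolution;
    # the trailing 3x3 conv refines the upsampled features.
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

# first decoder stage: 1280 x 8 x 8 -> 96 x 16 x 16
stage1 = up_block(1280, 96)
```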


MobileNetV4-UNet — Cutting-Edge Encoder

Same decoder design as the V2 variant, but with a MobileNetV4 Conv Small encoder from timm (pretrained on ImageNet). The encoder outputs 960 channels instead of 1280.

flowchart LR
    subgraph Encoder
        A["MobileNetV4<br/>Conv Small<br/>(timm)"] --> B["960 x 8 x 8"]
    end

    subgraph Decoder
        B --> C["960 → 96"]
        C --> D["96 → 32"]
        D --> E["32 → 24"]
        E --> F["24 → 16"]
    end

    F --> G["Conv 1x1 → 1"]
    G --> H["Bilinear interp"]

    style Encoder fill:#4a90d9,color:#fff
    style Decoder fill:#ff9800,color:#fff

Restricted Autoencoder — ONNX Friendly

A pure encoder-decoder with no skip connections and no residuals, designed around a minimal ONNX operator set so that it can run in a custom inference engine.

flowchart TD
    Input["Input (3 x 256 x 256)"] --> EB1

    subgraph Encoder
        EB1["Conv Block: 3 → 32, MaxPool"] --> EB2["Conv Block: 32 → 64, MaxPool"]
        EB2 --> EB3["Conv Block: 64 → 128, MaxPool"]
        EB3 --> EB4["Conv Block: 128 → 256, MaxPool"]
    end

    EB4 --> BN["Bottleneck<br/>256 → 512 → 256"]

    BN --> DB4

    subgraph Decoder
        DB4["ConvTranspose 256 → 128"] --> DB3["ConvTranspose 128 → 64"]
        DB3 --> DB2["ConvTranspose 64 → 32"]
        DB2 --> DB1["ConvTranspose 32 → 16"]
    end

    DB1 --> Out["Conv 1x1 → 1"]

    style BN fill:#e57373,color:#fff
    style Out fill:#7cb342,color:#fff

Loss Function

Weighted combination of three complementary losses:

flowchart LR
    Logits["Raw Logits"] --> BCE["BCE Loss<br/>weight: 1.0"]
    Logits --> Sig["Sigmoid"]
    Sig --> IoU["IoU Loss<br/>weight: 0.5"]
    Sig --> L1["L1 Loss<br/>weight: 0.5"]

    BCE --> Sum(("+"))
    IoU --> Sum
    L1 --> Sum
    Sum --> Total["Total Loss"]

    style Total fill:#e57373,color:#fff
| Component | Operates on | Purpose |
|-----------|-------------|---------|
| BCE | Raw logits | Pixel-wise classification accuracy (numerically stable) |
| IoU | Sigmoid probs | Penalises poor overlap, sharpening boundaries |
| L1 | Sigmoid probs | Encourages smooth, artifact-free gradients |
Loss = 1.0 * BCE + 0.5 * IoU + 0.5 * L1
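A sketch of how the three terms might combine, using the weights above (the internals of `src/loss/loss.py`, such as reduction and epsilon choices, are assumptions):

```python
import torch
from torch import nn

class CombinedLoss(nn.Module):
    """BCE on raw logits plus IoU and L1 on sigmoid probabilities."""
    def __init__(self, w_bce: float = 1.0, w_iou: float = 0.5, w_l1: float = 0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()  # numerically stable: takes logits
        self.w_bce, self.w_iou, self.w_l1 = w_bce, w_iou, w_l1

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits)
        # soft IoU: intersection / union over all pixels
        inter = (probs * target).sum()
        union = (probs + target - probs * target).sum()
        iou_loss = 1.0 - inter / union.clamp(min=1e-6)
        l1_loss = (probs - target).abs().mean()
        return (self.w_bce * self.bce(logits, target)
                + self.w_iou * iou_loss
                + self.w_l1 * l1_loss)
```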

Training Pipeline

flowchart TD
    Data["Dataset<br/>clip_img + matting"] --> Split["90/10 Train/Val Split<br/>(seed=42)"]
    Split --> TL["Train Loader<br/>batch=16, workers=4"]
    Split --> VL["Val Loader<br/>batch=16, workers=4"]

    TL --> Aug["Augmentation<br/>(Albumentations)"]
    Aug --> FP["Forward Pass<br/>(AMP mixed precision)"]
    FP --> Loss["CombinedLoss"]
    Loss --> BP["Backward + GradScaler"]
    BP --> Opt["AdamW<br/>lr=1e-3, wd=1e-5"]

    VL --> VFP["Val Forward Pass"]
    VFP --> VLoss["Val Loss"]
    VLoss --> Sched["ReduceLROnPlateau<br/>patience=15, factor=0.2"]
    VLoss --> Check{"val_loss<br/>improved?"}
    Check -- "Yes" --> Save["Save checkpoint<br/>+ ONNX export"]
    Check -- "No" --> Next["Next epoch"]
    Save --> Next

    Loss --> WB["W&B Logging<br/>(bce, iou, l1, total)"]

    style Data fill:#4a90d9,color:#fff
    style Save fill:#7cb342,color:#fff
    style WB fill:#ff9800,color:#fff

Encoder Freezing

When using pretrained backbones (MobileNetV2/V4), the encoder is frozen for the first N epochs (freeze_epochs=10 by default). This lets the decoder warm up before fine-tuning the encoder, preventing the pretrained features from being destroyed early in training.
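A minimal sketch of the freezing logic, assuming the backbone is exposed as `model.encoder` (the attribute name is hypothetical):

```python
from torch import nn

def set_encoder_frozen(model: nn.Module, frozen: bool) -> None:
    # Toggle gradients only on the pretrained backbone; the decoder
    # always trains. Frozen parameters are skipped by the optimizer step.
    for p in model.encoder.parameters():
        p.requires_grad = not frozen

# Inside the training loop (freeze_epochs=10 per the default above):
# set_encoder_frozen(model, frozen=epoch < freeze_epochs)
```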

Data Augmentation

Training uses the Albumentations library with synchronized image+mask transforms:

| Transform | Parameters |
|-----------|------------|
| ShiftScaleRotate | shift=0.1, scale=0.15, rotate=10° |
| HorizontalFlip | p=0.5 |
| RandomBrightnessContrast | brightness=0.2, contrast=0.2 |
| SmallestMaxSize + RandomCrop | Aspect-preserving resize to 256x256 |
| Normalize | Scale to [0, 1] |

Dataset Format

input/
  clip_img/{session_id}/clip_{xxxxx}/*.jpg    # RGB frames
  matting/{session_id}/matting_{xxxxx}/*.png   # RGBA (alpha channel = GT mask)
  • Frames are paired by filename stem (e.g. 1803151818-00000134)
  • Alpha channel: 1.0 = foreground (person), 0.0 = background
  • Format is compatible with the Background Matting dataset (Sengupta et al., 2020)
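Pairing by filename stem can be sketched with the standard library (the helper name `pair_frames` is hypothetical):

```python
from pathlib import Path

def pair_frames(clip_dir: str, matting_dir: str) -> list[tuple[Path, Path]]:
    """Pair RGB frames with their RGBA mattes by filename stem,
    e.g. '1803151818-00000134.jpg' <-> '1803151818-00000134.png'."""
    mattes = {p.stem: p for p in Path(matting_dir).rglob("*.png")}
    # Keep only frames that have a matching matte; sort for determinism.
    return [(img, mattes[img.stem])
            for img in sorted(Path(clip_dir).rglob("*.jpg"))
            if img.stem in mattes]
```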

Inference

Inference runs via ONNX Runtime for maximum portability:

flowchart LR
    Cam["Webcam /<br/>Image"] --> Crop["Center Crop<br/>(square)"]
    Crop --> Resize["Resize<br/>256x256"]
    Resize --> ONNX["ONNX Runtime<br/>(CPU)"]
    ONNX --> Sigmoid["Sigmoid"]
    Sigmoid --> Matte["Alpha Matte"]
    Matte --> Comp["Composite over<br/>background color"]
    Comp --> Out["Display<br/>with FPS"]

    style ONNX fill:#4a90d9,color:#fff
    style Out fill:#7cb342,color:#fff

The live demo (src/inference.ipynb) supports:

  • Real-time webcam feed with FPS overlay
  • Background color cycling (press b)
  • FPS benchmarking (avg / min / max)
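The crop-and-normalize steps of the pipeline can be sketched in NumPy as follows (the ONNX Runtime call is shown commented because it needs an exported model file; helper names are hypothetical):

```python
import numpy as np

def center_crop_square(frame: np.ndarray) -> np.ndarray:
    # Crop the largest centered square so the aspect ratio is preserved.
    h, w = frame.shape[:2]
    s = min(h, w)
    y, x = (h - s) // 2, (w - s) // 2
    return frame[y:y + s, x:x + s]

def preprocess(frame: np.ndarray) -> np.ndarray:
    # HWC uint8 -> NCHW float32 in [0, 1]. The 256x256 resize is omitted
    # here; the notebook uses OpenCV for that step.
    x = center_crop_square(frame).astype(np.float32) / 255.0
    return x.transpose(2, 0, 1)[None]

# Hypothetical usage with an exported model:
# import onnxruntime as ort
# sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
# logits = sess.run(None, {sess.get_inputs()[0].name: preprocess(frame)})[0]
# matte = 1.0 / (1.0 + np.exp(-logits))  # sigmoid -> alpha matte
```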

Tech Stack

| Category | Tools |
|----------|-------|
| Framework | PyTorch 2.9+, torchvision, timm |
| Augmentation | Albumentations |
| Experiment tracking | Weights & Biases |
| Inference export | ONNX, ONNX Runtime |
| Dev tooling | uv, ruff |
| Visualization | matplotlib, OpenCV |

About

Collection of .onnx neural networks used for real-time background segmentation
