Real-time background removal for video calls and images using deep learning. This project provides fast and accurate alpha matte extraction with U-Net and MobileNet-based architectures.
The aim was for the network to be lightweight enough to run in real-time on most mobile phones.
This project uses uv for fast dependency management. Python 3.13+ is required.
```shell
uv sync --all-groups
```

```
BackgroundSegmentationUNet/
├── pyproject.toml
├── src/
│   ├── main.ipynb         # Training pipeline
│   ├── inference.ipynb    # ONNX inference & live webcam demo
│   ├── loss/
│   │   └── loss.py        # CombinedLoss (BCE + IoU + L1)
│   └── models/
│       ├── unet.py        # Vanilla U-Net (31M params)
│       ├── mobilenetv2.py # MobileNetV2 encoder + decoder (3.5M params)
│       ├── mobilenetv4.py # MobileNetV4 encoder + decoder
│       └── autoencoder.py # Restricted encoder-decoder (ONNX-friendly)
└── input/                 # Dataset root (not tracked)
```
```mermaid
flowchart LR
    A[RGB Image] --> B[Augmentation<br/>Albumentations]
    B --> C[Encoder]
    C --> D[Decoder]
    D --> E[Sigmoid]
    E --> F[Alpha Matte]
    F --> G[Composite<br/>over background]
    style A fill:#4a90d9,color:#fff
    style F fill:#7cb342,color:#fff
    style G fill:#7cb342,color:#fff
```
Classic encoder-decoder with skip connections. Each encoder stage halves spatial resolution while doubling channels; the decoder reverses this while fusing skip features.
```mermaid
flowchart TD
    Input["Input (3 x 256 x 256)"] --> Inc["DoubleConv → 64"]
    Inc --> D1["Down1: MaxPool → 128"]
    D1 --> D2["Down2: MaxPool → 256"]
    D2 --> D3["Down3: MaxPool → 512"]
    D3 --> D4["Down4: MaxPool → 512<br/>(bottleneck)"]
    D4 --> U1["Up1: Upsample → 256"]
    U1 --> U2["Up2: Upsample → 128"]
    U2 --> U3["Up3: Upsample → 64"]
    U3 --> U4["Up4: Upsample → 64"]
    U4 --> Out["Conv 1x1 → 1 channel<br/>(logits)"]
    D3 -- "skip" --> U1
    D2 -- "skip" --> U2
    D1 -- "skip" --> U3
    Inc -- "skip" --> U4
    style D4 fill:#e57373,color:#fff
    style Out fill:#7cb342,color:#fff
```
DoubleConv block, the fundamental building unit used in every stage:

```
Conv2d(3x3) → BatchNorm → ReLU → Conv2d(3x3) → BatchNorm → ReLU
```
Key design choices:
- Bilinear upsampling (default) instead of transposed convolutions — avoids checkerboard artifacts
- Skip connections concatenate encoder features to preserve fine spatial detail
- No bias in convolutions (BatchNorm absorbs the bias term)
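In PyTorch, the block above might be sketched as follows (a sketch with illustrative names; the actual `unet.py` may differ in details):

```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """(Conv 3x3 -> BatchNorm -> ReLU) twice; bias disabled since BatchNorm absorbs it."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
```

Padding of 1 keeps spatial resolution unchanged, so only the pooling and upsampling stages change it.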
Replaces the U-Net encoder with a pretrained MobileNetV2 backbone. The decoder is a lightweight stack of transposed convolutions.
```mermaid
flowchart TD
    Input["Input (3 x 256 x 256)"] --> Enc
    subgraph Enc["MobileNetV2 Encoder (pretrained, ImageNet)"]
        direction TB
        E1["Depthwise Separable Convs"] --> E2["Inverted Residual Blocks"]
        E2 --> E3["Output: 1280 x 8 x 8"]
    end
    Enc --> Dec
    subgraph Dec["Custom Decoder"]
        direction TB
        U1["ConvTranspose 1280 → 96"] --> U2["ConvTranspose 96 → 32"]
        U2 --> U3["ConvTranspose 32 → 24"]
        U3 --> U4["ConvTranspose 24 → 16"]
    end
    Dec --> Final["Conv 1x1 → 1 channel"]
    Final --> Interp["Bilinear interpolate<br/>to input size"]
    style Enc fill:#4a90d9,color:#fff
    style Dec fill:#ff9800,color:#fff
    style Interp fill:#7cb342,color:#fff
```
Each decoder block: `ConvTranspose2d (2x upsample) → BN → ReLU → Conv2d → BN → ReLU`
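A minimal sketch of one such decoder block, assuming the layer order above (names are illustrative, not taken from `mobilenetv2.py`):

```python
import torch
import torch.nn as nn

def decoder_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Upsample 2x via transposed conv, then refine with a 3x3 conv."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),  # doubles H and W
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )
```

Stacking four of these (1280 → 96 → 32 → 24 → 16) takes the 8x8 encoder output to 128x128; the final bilinear interpolation recovers the full 256x256 input size.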
Same decoder design as V2, but uses MobileNetV4 Conv Small from timm (pretrained on ImageNet). Encoder outputs 960 channels instead of 1280.
```mermaid
flowchart LR
    subgraph Encoder
        A["MobileNetV4<br/>Conv Small<br/>(timm)"] --> B["960 x 8 x 8"]
    end
    subgraph Decoder
        B --> C["960 → 96"]
        C --> D["96 → 32"]
        D --> E["32 → 24"]
        E --> F["24 → 16"]
    end
    F --> G["Conv 1x1 → 1"]
    G --> H["Bilinear interp"]
    style Encoder fill:#4a90d9,color:#fff
    style Decoder fill:#ff9800,color:#fff
```
A pure encoder-decoder with no skip connections and no residuals, designed around a minimal ONNX operator set so the exported model can run in a custom, hand-written inference engine.
```mermaid
flowchart TD
    Input["Input (3 x 256 x 256)"] --> EB1
    subgraph Encoder
        EB1["Conv Block: 3 → 32, MaxPool"] --> EB2["Conv Block: 32 → 64, MaxPool"]
        EB2 --> EB3["Conv Block: 64 → 128, MaxPool"]
        EB3 --> EB4["Conv Block: 128 → 256, MaxPool"]
    end
    EB4 --> BN["Bottleneck<br/>256 → 512 → 256"]
    BN --> DB4
    subgraph Decoder
        DB4["ConvTranspose 256 → 128"] --> DB3["ConvTranspose 128 → 64"]
        DB3 --> DB2["ConvTranspose 64 → 32"]
        DB2 --> DB1["ConvTranspose 32 → 16"]
    end
    DB1 --> Out["Conv 1x1 → 1"]
    style BN fill:#e57373,color:#fff
    style Out fill:#7cb342,color:#fff
```
Weighted combination of three complementary losses:
```mermaid
flowchart LR
    Logits["Raw Logits"] --> BCE["BCE Loss<br/>weight: 1.0"]
    Logits --> Sig["Sigmoid"]
    Sig --> IoU["IoU Loss<br/>weight: 0.5"]
    Sig --> L1["L1 Loss<br/>weight: 0.5"]
    BCE --> Sum(("+"))
    IoU --> Sum
    L1 --> Sum
    Sum --> Total["Total Loss"]
    style Total fill:#e57373,color:#fff
```
| Component | Operates On | Purpose |
|---|---|---|
| BCE | Raw logits | Pixel-wise classification accuracy (numerically stable) |
| IoU | Sigmoid probs | Penalises poor overlap — sharpens boundaries |
| L1 | Sigmoid probs | Encourages smooth, artifact-free gradients |
```
Loss = 1.0 * BCE + 0.5 * IoU + 0.5 * L1
```
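A sketch of the combined loss, assuming a soft-IoU formulation on sigmoid probabilities (the repository's `loss.py` may differ in details):

```python
import torch
import torch.nn as nn

class CombinedLoss(nn.Module):
    """Weighted BCE + soft IoU + L1 for alpha matte prediction."""
    def __init__(self, w_bce=1.0, w_iou=0.5, w_l1=0.5, eps=1e-6):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()  # numerically stable: takes raw logits
        self.w_bce, self.w_iou, self.w_l1, self.eps = w_bce, w_iou, w_l1, eps

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits)
        # Soft IoU over each sample's pixels, averaged across the batch
        inter = (probs * target).sum(dim=(1, 2, 3))
        union = (probs + target - probs * target).sum(dim=(1, 2, 3))
        iou_loss = 1.0 - ((inter + self.eps) / (union + self.eps)).mean()
        l1_loss = (probs - target).abs().mean()
        return (self.w_bce * self.bce(logits, target)
                + self.w_iou * iou_loss
                + self.w_l1 * l1_loss)
```

BCE is computed on raw logits via `BCEWithLogitsLoss` to avoid the log-of-sigmoid instability; IoU and L1 see probabilities, as in the table above.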
```mermaid
flowchart TD
    Data["Dataset<br/>clip_img + matting"] --> Split["90/10 Train/Val Split<br/>(seed=42)"]
    Split --> TL["Train Loader<br/>batch=16, workers=4"]
    Split --> VL["Val Loader<br/>batch=16, workers=4"]
    TL --> Aug["Augmentation<br/>(Albumentations)"]
    Aug --> FP["Forward Pass<br/>(AMP mixed precision)"]
    FP --> Loss["CombinedLoss"]
    Loss --> BP["Backward + GradScaler"]
    BP --> Opt["AdamW<br/>lr=1e-3, wd=1e-5"]
    VL --> VFP["Val Forward Pass"]
    VFP --> VLoss["Val Loss"]
    VLoss --> Sched["ReduceLROnPlateau<br/>patience=15, factor=0.2"]
    VLoss --> Check{"val_loss<br/>improved?"}
    Check -- "Yes" --> Save["Save checkpoint<br/>+ ONNX export"]
    Check -- "No" --> Next["Next epoch"]
    Save --> Next
    Loss --> WB["W&B Logging<br/>(bce, iou, l1, total)"]
    style Data fill:#4a90d9,color:#fff
    style Save fill:#7cb342,color:#fff
    style WB fill:#ff9800,color:#fff
```
When using pretrained backbones (MobileNetV2/V4), the encoder is frozen for the first N epochs (freeze_epochs=10 by default). This lets the decoder warm up before fine-tuning the encoder, preventing the pretrained features from being destroyed early in training.
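The freezing schedule can be sketched as a `requires_grad` toggle (assuming the model exposes an `encoder` attribute; names are illustrative):

```python
import torch.nn as nn

def set_encoder_frozen(model: nn.Module, frozen: bool) -> None:
    """Enable or disable gradient flow through the pretrained backbone."""
    for p in model.encoder.parameters():
        p.requires_grad = not frozen

# In the training loop (sketch):
#   set_encoder_frozen(model, True)          # before epoch 0
#   if epoch == freeze_epochs:
#       set_encoder_frozen(model, False)     # start fine-tuning the backbone
```

Because a large random-decoder gradient would otherwise flow back into the pretrained weights, the toggle only flips after the decoder has converged to a reasonable starting point.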
Training uses the Albumentations library with synchronized image+mask transforms:
| Transform | Parameters |
|---|---|
| ShiftScaleRotate | shift=0.1, scale=0.15, rotate=10° |
| HorizontalFlip | p=0.5 |
| RandomBrightnessContrast | brightness=0.2, contrast=0.2 |
| SmallestMaxSize + RandomCrop | Aspect-preserving resize to 256x256 |
| Normalize | Scale to [0, 1] |
```
input/
├── clip_img/{session_id}/clip_{xxxxx}/*.jpg    # RGB frames
└── matting/{session_id}/matting_{xxxxx}/*.png  # RGBA (alpha channel = GT mask)
```

- Frames are paired by filename stem (e.g. `1803151818-00000134`)
- Alpha channel: `1.0` = foreground (person), `0.0` = background
- Format is compatible with the Background Matting dataset (Sengupta et al., 2020)
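Pairing by filename stem can be sketched with pathlib (the helper name is hypothetical):

```python
from pathlib import Path

def pair_frames(root):
    """Match each RGB frame to its matte by shared filename stem."""
    root = Path(root)
    mattes = {p.stem: p for p in root.glob("matting/*/*/*.png")}
    pairs = []
    for img in sorted(root.glob("clip_img/*/*/*.jpg")):
        matte = mattes.get(img.stem)
        if matte is not None:          # skip frames with no ground-truth matte
            pairs.append((img, matte))
    return pairs
```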
Inference runs via ONNX Runtime for maximum portability:
```mermaid
flowchart LR
    Cam["Webcam /<br/>Image"] --> Crop["Center Crop<br/>(square)"]
    Crop --> Resize["Resize<br/>256x256"]
    Resize --> ONNX["ONNX Runtime<br/>(CPU)"]
    ONNX --> Sigmoid["Sigmoid"]
    Sigmoid --> Matte["Alpha Matte"]
    Matte --> Comp["Composite over<br/>background color"]
    Comp --> Out["Display<br/>with FPS"]
    style ONNX fill:#4a90d9,color:#fff
    style Out fill:#7cb342,color:#fff
```
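The preprocessing steps (center crop, resize, scale to [0, 1], NCHW layout) can be sketched in NumPy; the nearest-neighbor resize below is a stand-in for the notebook's actual resize:

```python
import numpy as np

def center_crop_square(frame: np.ndarray) -> np.ndarray:
    """Crop the largest centered square from an H x W x 3 frame."""
    h, w = frame.shape[:2]
    s = min(h, w)
    top, left = (h - s) // 2, (w - s) // 2
    return frame[top:top + s, left:left + s]

def resize_nearest(img: np.ndarray, size: int = 256) -> np.ndarray:
    """Naive nearest-neighbor resize of a square image via index sampling."""
    idx = np.arange(size) * img.shape[0] // size
    return img[idx][:, idx]

def preprocess(frame: np.ndarray, size: int = 256) -> np.ndarray:
    img = resize_nearest(center_crop_square(frame), size)
    x = img.astype(np.float32) / 255.0       # scale to [0, 1]
    return np.transpose(x, (2, 0, 1))[None]  # HWC -> NCHW, batch of 1
```

The resulting `(1, 3, 256, 256)` float32 tensor is what an ONNX Runtime session would consume; the model's logits then pass through a sigmoid to become the alpha matte.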
The live demo (`src/inference.ipynb`) supports:
- Real-time webcam feed with FPS overlay
- Background color cycling (press `b`)
- FPS benchmarking (avg / min / max)
| Category | Tools |
|---|---|
| Framework | PyTorch 2.9+, torchvision, timm |
| Augmentation | Albumentations |
| Experiment tracking | Weights & Biases |
| Inference export | ONNX, ONNX Runtime |
| Dev tooling | uv, ruff |
| Visualization | matplotlib, OpenCV |