Automated wafer-map defect pattern classification using a pragmatic data pipeline + a compact DenseNet variant (“WaferNet”) with channel attention (ECA) and mask-aware dual pooling (GAP ∥ GMP).
This repo is designed to be reproducible and deployment-friendly: fixed input size, explicit wafer geometry masking, class-imbalance handling, and leakage-aware evaluation.
Given a wafer map (die-grid) from the WM-811K / LSWMD dataset, we:

- Clean/standardize labels into 9 classes:
  Center, Donut, Edge-Loc, Edge-Ring, Loc, Random, Scratch, Near-Full, None
- Convert each sample into a 2-channel tensor of shape `[2, 96, 96]`:
  - Channel 0 (wafer map): values `{0, 1, 2}` → `{0.0, 0.5, 1.0}`
    (0 = background / no die, 1 = pass, 2 = fail)
  - Channel 1 (mask): binary wafer geometry mask `{0, 1}`
    (valid die sites vs. background)
- Train WaferNet to output a 9-logit vector per wafer; apply softmax for probabilities.
- Optionally generate Grad-CAM explanations constrained to the wafer region.
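The encoding above can be sketched in a few lines. This is illustrative, not the repo's actual function; in particular, centering the wafer inside the padded canvas is an assumption (the text only specifies zero-padding, not placement).

```python
import numpy as np
import torch

def encode_wafer(die_grid: np.ndarray, size: int = 96) -> torch.Tensor:
    """Encode a raw die grid (values 0/1/2) as a [2, size, size] tensor.

    Channel 0: wafer map, {0, 1, 2} -> {0.0, 0.5, 1.0}
    Channel 1: binary geometry mask (1 = valid die site, 0 = background)
    """
    h, w = die_grid.shape
    if h > size or w > size:
        raise ValueError("wafer larger than target canvas; excluded upstream")

    out = torch.zeros(2, size, size)
    # Zero-pad, never warp; centering is an illustrative choice.
    top, left = (size - h) // 2, (size - w) // 2
    grid = torch.from_numpy(die_grid.astype(np.float32))
    out[0, top:top + h, left:left + w] = grid / 2.0          # {0,1,2} -> {0,.5,1}
    out[1, top:top + h, left:left + w] = (grid > 0).float()  # valid die sites
    return out
```

The mask channel is what later lets pooling and Grad-CAM ignore the padded background.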
Wafer maps have:
- variable shapes/sizes (different die grids),
- lots of “empty” background,
- defects whose geometry matters.
Instead of warping wafers via resizing/cropping, we standardize to 96×96 and provide an explicit mask so the model can ignore padded/background regions during pooling and explanation. This is a central design choice in WaferNet.
- Backbone: DenseNet-121 (trained from scratch; first conv modified to accept 2 input channels).
- Attention: Efficient Channel Attention (ECA) after the final DenseNet feature map's BN+ReLU.
- Head: mask-aware pooling on the downsampled mask:
  - masked Global Average Pooling (GAP)
  - masked Global Max Pooling (GMP)
  - concatenate → 2048-d, then Linear → 9 logits

This keeps the model lightweight (~7M params) and robust to background/padding.
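The mask-aware dual-pooling head can be sketched as follows. This is a minimal illustration, assuming 1024-channel DenseNet-121 features; `MaskedDualPoolHead` is a hypothetical name, not the class in `src/models/wafernet.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedDualPoolHead(nn.Module):
    """Mask-aware GAP ∥ GMP head (illustrative sketch).

    feats: [B, C, h, w] backbone features; mask: [B, 1, H, W] wafer mask.
    The mask is downsampled to the feature resolution so both pooling ops
    ignore background/padding cells.
    """
    def __init__(self, in_channels: int = 1024, num_classes: int = 9):
        super().__init__()
        self.fc = nn.Linear(2 * in_channels, num_classes)  # GAP ∥ GMP -> 2048-d

    def forward(self, feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Max-pool the mask so any cell covering die sites stays valid.
        m = F.adaptive_max_pool2d(mask, feats.shape[-2:])
        denom = m.sum(dim=(2, 3)).clamp_min(1.0)      # valid cells per sample
        gap = (feats * m).sum(dim=(2, 3)) / denom     # masked average
        gmp = (feats.masked_fill(m == 0, float("-inf"))
                     .amax(dim=(2, 3)))               # masked max
        return self.fc(torch.cat([gap, gmp], dim=1))  # [B, num_classes]
```

Masking the max pool matters as much as the average: without it, border artifacts in the zero-padded region can dominate GMP.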
Typical training settings used in the project:

- Optimizer: AdamW (lr 3e-4, weight decay 1e-4)
- Batch size: 64
- Epochs: 30
- Loss: Logit-Adjusted Cross Entropy (helps with long-tail class imbalance)
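Logit-adjusted cross entropy shifts each logit by the log class prior before the softmax, which discounts head classes and boosts the tail. A minimal sketch, assuming per-class training counts are available (`tau` and the function name are illustrative):

```python
import torch
import torch.nn.functional as F

def logit_adjusted_ce(logits: torch.Tensor,
                      targets: torch.Tensor,
                      class_counts: torch.Tensor,
                      tau: float = 1.0) -> torch.Tensor:
    """Logit-adjusted cross entropy (sketch).

    Adds tau * log(prior) to the logits before standard CE, so frequent
    classes must win by a larger margin at training time.
    """
    prior = class_counts.float() / class_counts.sum()
    adjusted = logits + tau * prior.log().to(logits.device)
    return F.cross_entropy(adjusted, targets)
```

At inference the adjustment is dropped and plain softmax over the raw logits is used, as described above.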
Hardware support:
- CUDA (NVIDIA) or Apple Silicon MPS, with CPU fallback.
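The device fallback chain is straightforward; a sketch (the helper name is illustrative):

```python
import torch

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple Silicon MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)  # absent on older torch builds
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```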
Only labeled wafers are used (unlabeled samples are removed).

- If H ≤ 96 and W ≤ 96: zero-pad to 96×96 (no warping).
- If either dimension exceeds 96: exclude the wafer (to avoid resizing artifacts).
Wafers are associated with manufacturing lots; random splitting can leak lot-specific artifacts.
- Prefer lot-disjoint splits (no lot appears in both train and test), while keeping class ratios.
- If lot metadata is missing, fall back to standard stratified split.
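Both branches are covered by scikit-learn. A sketch: `GroupShuffleSplit` guarantees lot-disjointness but only approximately preserves class ratios (for both at once, `StratifiedGroupKFold` is an option); the fallback is a standard stratified split. The function name and signature are illustrative.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, train_test_split

def split_indices(labels, lots=None, test_frac=0.15, seed=42):
    """Lot-disjoint split when lot IDs exist, else stratified (sketch)."""
    idx = np.arange(len(labels))
    if lots is not None:
        # No lot appears in both sides; class ratios are approximate.
        gss = GroupShuffleSplit(n_splits=1, test_size=test_frac,
                                random_state=seed)
        train_idx, test_idx = next(gss.split(idx, labels, groups=lots))
    else:
        # Fallback: stratified split on labels only.
        train_idx, test_idx = train_test_split(
            idx, test_size=test_frac, stratify=labels, random_state=seed)
    return train_idx, test_idx
```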
- Cap None to ~25,000 samples.
- Up-sample minority classes to ~3,000 using label-preserving transforms
  (90° rotations and flips).

Validation/test sets are left untouched to avoid "cheating by augmentation".
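The 90° rotations and flips form the 8-element dihedral group of the square canvas. A sketch of enumerating all variants for up-sampling (for the 2-channel tensors, both channels must be transformed identically):

```python
import numpy as np

def dihedral_transforms(x: np.ndarray) -> list[np.ndarray]:
    """All 8 label-preserving variants: 4 rotations x optional horizontal flip.

    Applied only to minority-class training samples; val/test stay untouched.
    """
    out = []
    for k in range(4):
        r = np.rot90(x, k)        # rotate by k * 90 degrees
        out.append(r)
        out.append(np.fliplr(r))  # mirrored counterpart
    return out
```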
If your current code structure differs, adapt these names—this layout is the clean “canonical” version.
```text
.
├── data/
│   ├── raw/              # LSWMD.pkl (or extracted files)
│   ├── processed/        # cached tensors, manifests, splits
│   └── manifests/        # split indices, lot ids, seeds
├── configs/
│   └── wafernet.yaml
├── src/
│   ├── data/
│   │   ├── make_splits.py
│   │   ├── preprocess.py
│   │   └── dataset.py
│   ├── models/
│   │   ├── wafernet.py
│   │   └── eca.py
│   ├── train.py
│   ├── eval.py
│   └── explain.py        # Grad-CAM utilities
├── scripts/
│   ├── download_data.md
│   └── run_experiments.sh
├── outputs/
│   ├── checkpoints/
│   ├── metrics/
│   └── gradcam/
├── requirements.txt
└── README.md
```
Use Python 3.10+ (recommended).
Example (venv):

```shell
python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install -r requirements.txt
```

Key dependencies:

- torch, torchvision
- numpy, pandas
- scikit-learn (metrics/splits)
- matplotlib (plots)
- opencv-python or pillow (image ops)
- tqdm, pyyaml
- (optional) streamlit/gradio for a demo UI
A common distribution is the Kaggle dataset containing LSWMD.pkl.
Place it here:
data/raw/LSWMD.pkl
Preprocess:

```shell
python -m src.data.preprocess \
  --input data/raw/LSWMD.pkl \
  --output data/processed \
  --size 96
```

Create splits:

```shell
python -m src.data.make_splits \
  --processed data/processed \
  --out data/manifests \
  --split 0.70 0.15 0.15 \
  --seed 42 \
  --lot_disjoint true
```

Train:

```shell
python -m src.train \
  --config configs/wafernet.yaml \
  --manifests data/manifests \
  --out outputs \
  --seed 42
```

Evaluate:

```shell
python -m src.eval \
  --ckpt outputs/checkpoints/best.pt \
  --manifests data/manifests \
  --out outputs/metrics
```

Explain (Grad-CAM):

```shell
python -m src.explain \
  --ckpt outputs/checkpoints/best.pt \
  --sample_id <ID> \
  --out outputs/gradcam
```

This project's philosophy: if you can't reproduce it, it didn't happen.
Minimum requirements:

- Fix seeds across `random`, `numpy`, and `torch`
- Use deterministic settings where possible
- Save:
  - exact train/val/test indices
  - lot IDs (if available)
  - augmentation recipe
  - class caps / target counts
  - library versions + git commit hash
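The seed-fixing and determinism items can be bundled into one helper; a sketch (`set_seed` is an illustrative name):

```python
import random
import numpy as np
import torch

def set_seed(seed: int = 42) -> None:
    """Fix seeds across random, numpy, and torch; prefer deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA
    # "Where possible": some ops have no deterministic implementation,
    # hence warn_only instead of hard errors.
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
```

Note that full CUDA determinism may additionally require setting the `CUBLAS_WORKSPACE_CONFIG` environment variable.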