This project aims to mitigate demographic bias (e.g., race, gender) in medical image diagnosis using causal modeling and adversarial training.
It supports datasets including:
- MIMIC-CXR
- CheXpert
- TCGA-LUAD
- adv.py: Main script for training and testing fairness-aware models with adversarial perturbation
- lib/: Contains model architectures (ResNet, ViT, Generator, Discriminator)
- import_datasets.py: Data loading utilities for supported datasets
- grad_rollout.py: GradCAM-based visualization and mask generation
- images/: Stores masked and perturbed image visualizations
conda create -n fair-diagnosis python=3.8
conda activate fair-diagnosis
Make sure you also have pytorch, torchvision, and pytorch-grad-cam installed.
Place your datasets in the datasets/ directory with the following structure:
datasets/
├── mimic-cxr/
│ └── metadata.csv
├── CheXpert-v1.0-small/
│ └── train.csv / valid.csv
├── TCGA-LUAD/
│ └── metadata.csv
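Before training, it can help to confirm that the files are where the loaders expect them. The following is a minimal sketch, not part of the repository: `EXPECTED_FILES` and `missing_dataset_files` are hypothetical names, and the sketch assumes that "train.csv / valid.csv" means both files should be present.

```python
from pathlib import Path

# Assumed layout, copied from the directory tree above.
# Assumption: CheXpert needs both train.csv and valid.csv.
EXPECTED_FILES = {
    "mimic-cxr": ["metadata.csv"],
    "CheXpert-v1.0-small": ["train.csv", "valid.csv"],
    "TCGA-LUAD": ["metadata.csv"],
}


def missing_dataset_files(root="datasets"):
    """Return the expected files (relative to root) that are not present."""
    root = Path(root)
    missing = []
    for folder, files in EXPECTED_FILES.items():
        for name in files:
            path = root / folder / name
            if not path.is_file():
                missing.append(path.relative_to(root).as_posix())
    return missing


if __name__ == "__main__":
    for rel_path in missing_dataset_files():
        print(f"missing: datasets/{rel_path}")
```

Running the script from the project root prints one line per missing file and nothing if the layout is complete.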
Example: Train with adversarial fairness on MIMIC-CXR
python adv.py --dataset mimic-cxr --model resnet18 --feature mimic_exp \
--alpha 0.8 --beta 1.0 --noise_strength 0.1 \
--num_epochs 10 --visualize
Optional flags:
- --debug: Use a small subset of data
- --no_adversarial: Disable adversarial training
- --no_cam: Disable GradCAM mask guidance
- --visualize: Save masked output images to images/
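For readers who want to see how such a command line might be parsed, here is a hypothetical argparse sketch that mirrors the flags used above. It is not the actual parser in adv.py: the argument types and the defaults (taken from the example command) are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical mirror of the adv.py options; types and defaults are assumed."""
    parser = argparse.ArgumentParser(description="Fairness-aware training (sketch)")
    parser.add_argument("--dataset", default="mimic-cxr")
    parser.add_argument("--model", default="resnet18")
    parser.add_argument("--feature", default="mimic_exp")
    parser.add_argument("--alpha", type=float, default=0.8)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--noise_strength", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=10)
    # Boolean switches: absent means False, present means True.
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--no_adversarial", action="store_true")
    parser.add_argument("--no_cam", action="store_true")
    parser.add_argument("--visualize", action="store_true")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

Modeling the optional flags as `store_true` switches matches how they appear in the example command, where `--visualize` is passed without a value.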
The framework logs both performance and fairness metrics:
- Accuracy, AUC, F1 Score
- DP (Demographic Parity) and EO (Equalized Odds)
- ADF (Approximate Diagnosis Fairness)
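As a rough illustration of the two standard fairness metrics, the sketch below computes DP and EO gaps for binary predictions and a binary sensitive attribute. It uses common textbook definitions (absolute difference in positive prediction rates for DP; the larger of the true-positive-rate and false-positive-rate gaps for EO), which may differ from what the framework logs. The function names are hypothetical, and ADF is left out because its definition is not given here.

```python
def _positive_rate(preds, mask):
    """Fraction of positive predictions among the samples selected by mask."""
    selected = [p for p, keep in zip(preds, mask) if keep]
    return sum(selected) / len(selected) if selected else 0.0


def demographic_parity_gap(y_pred, s):
    """|P(Y_hat=1 | S=0) - P(Y_hat=1 | S=1)| for binary y_pred and s."""
    rate0 = _positive_rate(y_pred, [g == 0 for g in s])
    rate1 = _positive_rate(y_pred, [g == 1 for g in s])
    return abs(rate0 - rate1)


def equalized_odds_gap(y_true, y_pred, s):
    """Larger of the between-group gaps in FPR (label 0) and TPR (label 1)."""
    gaps = []
    for label in (0, 1):
        rate0 = _positive_rate(
            y_pred, [g == 0 and y == label for g, y in zip(s, y_true)]
        )
        rate1 = _positive_rate(
            y_pred, [g == 1 and y == label for g, y in zip(s, y_true)]
        )
        gaps.append(abs(rate0 - rate1))
    return max(gaps)


if __name__ == "__main__":
    y_true = [1, 1, 0, 0, 1, 1, 0, 0]
    y_pred = [1, 1, 1, 0, 1, 0, 0, 0]
    s = [0, 0, 0, 0, 1, 1, 1, 1]
    print(demographic_parity_gap(y_pred, s))      # 0.5
    print(equalized_odds_gap(y_true, y_pred, s))  # 0.5
```

A gap of 0 means the two groups are treated identically under that criterion; larger values indicate more disparity.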
Adversarial masks are guided using GradCAM to focus on zones important for:
- Diagnosis prediction (Ŷ)
- Sensitive attribute prediction (Ŝ)
The resulting heatmaps help interpret which regions are being masked or preserved.
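To make the heatmap-to-mask step concrete, here is a minimal sketch of one way a GradCAM heatmap could be turned into a binary mask: min-max normalization followed by a fixed threshold. This is an assumption for illustration, not the logic in grad_rollout.py. The function name `cam_to_mask` and the 0.5 threshold are hypothetical, and plain nested lists stand in for the arrays a real pipeline would use.

```python
def cam_to_mask(cam, threshold=0.5):
    """Min-max normalize a 2D heatmap and return a binary mask (1 = salient)."""
    flat = [value for row in cam for value in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        # A constant heatmap carries no saliency information.
        return [[0 for _ in row] for row in cam]
    return [
        [1 if (value - lo) / span >= threshold else 0 for value in row]
        for row in cam
    ]


if __name__ == "__main__":
    # Toy 2x2 heatmap standing in for a GradCAM output for Ŷ or Ŝ.
    heatmap = [[0.0, 0.2], [0.6, 1.0]]
    print(cam_to_mask(heatmap))  # [[0, 0], [1, 1]]
```

The same function could be applied separately to the heatmaps for Ŷ and Ŝ; how the two masks are then combined is left open here.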