David Dong, Phuc Duong, Steven Zhou
CPSC 4710/5710: Trustworthy Deep Learning
Yale University, Department of Computer Science
The increasing realism of AI-generated images has become a major challenge for modern media, and it has become increasingly difficult for companies to verify the authenticity of content passing through their systems. As a result, recent work has developed models that detect AI-generated content in an effort to distinguish real images from fake ones. However, these methods often lack explainability, which limits their usefulness in real-world settings. We propose using Grad-CAM as a feature gate. Specifically, we apply Grad-CAM to our base detection model (SuSy) to generate heatmaps that highlight synthetic artifacts. We then concatenate each heatmap with its original image patch and train a lightweight convolutional network to make the final prediction. Our results show improved performance over the SuSy baseline while adding a layer of explainability.
Main Dependencies
- torch 2.9.1
- torchvision 0.24.1
- datasets 4.4.1
- scikit-learn 1.7.2
- scipy 1.16.3
- pillow 12.0.0
- tensorboard 2.20.0
A full list of the dependencies can be found in requirements.txt.
To install the dependencies, run the following commands from the repository's root directory, g-cage:
pip install -e .
pip install -r requirements.txt
We used the WildFake dataset and ImageNet. Scripts to organize and format the data can be found in scripts/data_loading. Our post-processed dataset can be found here.
Relevant Files
- utils/gcage_dataset.py: Contains the data loading code for GCageTrainDataset and GCageEvalDataset. GCageTrainDataset has an extra column, heatmap_path, for the heatmap.
- utils/susy_dataset.py: Contains the data loading code for evaluating the baseline SuSy model.
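For illustration, a dataset built around a heatmap_path column might return each image patch and its heatmap stacked into one 4-channel tensor. This is a minimal sketch, not the repository's GCageTrainDataset: the class name HeatmapPatchDataset, the record keys (image_path, heatmap_path, label), the .npy heatmap format, and the label convention (0 = real, 1 = fake) are all assumptions.

```python
from typing import Dict, List

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class HeatmapPatchDataset(Dataset):
    """Hypothetical stand-in for GCageTrainDataset.

    Each record is assumed to hold an image path, a heatmap path (a .npy file
    with one float per pixel, same H x W as the image), and an integer label.
    """

    def __init__(self, records: List[Dict]):
        self.records = records

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx]
        # RGB patch -> float tensor in [0, 1] with shape (3, H, W)
        img = np.asarray(Image.open(rec["image_path"]).convert("RGB"), dtype=np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1)
        # Grad-CAM heatmap -> float tensor with shape (1, H, W)
        heat = torch.from_numpy(np.load(rec["heatmap_path"]).astype(np.float32)).unsqueeze(0)
        # Stack the heatmap as a fourth input channel: (4, H, W)
        x = torch.cat([img, heat], dim=0)
        return x, torch.tensor(rec["label"], dtype=torch.long)
```

Wrapping it in torch.utils.data.DataLoader then yields batches of shape (N, 4, H, W).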
For our baseline, we used SuSy. We loaded the SuSy model from a JIT (TorchScript) checkpoint provided by the authors into a regular nn.Module, which is required to attach the Grad-CAM hooks. The implementation can be found in synthethic_detectors/susy.
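One way to move a TorchScript checkpoint into an eager nn.Module is to copy its state_dict. The sketch below assumes an eager module with the same architecture and parameter names already exists; load_jit_weights_into_module is a hypothetical helper, not the code in synthethic_detectors/susy.

```python
import torch
import torch.nn as nn


def load_jit_weights_into_module(jit_path: str, module: nn.Module) -> nn.Module:
    """Copy weights from a TorchScript checkpoint into an eager nn.Module.

    Assumes `module` mirrors the scripted model's architecture; load_state_dict
    raises a RuntimeError if the parameter names or shapes do not match.
    """
    scripted = torch.jit.load(jit_path, map_location="cpu")
    module.load_state_dict(scripted.state_dict())
    return module.eval()
```

Because load_state_dict is strict by default, a mismatch between the two architectures surfaces immediately instead of silently leaving weights uninitialized.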
For our GCage model, we implemented a CNN with 4 convolutional layers, a global average pooling layer, and a fully connected layer, as described in the paper. The implementation can be found in synthethic_detectors/cnn.
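A minimal PyTorch sketch of a network with this layout, taking the RGB patch with its Grad-CAM heatmap concatenated as a fourth channel (as described in the abstract), could look like the following. The class name, channel widths, stride-2 convolutions, batch normalization, and two-class output are assumptions for illustration; the actual configuration is in synthethic_detectors/cnn.

```python
import torch
import torch.nn as nn


class HeatmapGatedCNN(nn.Module):
    """Hypothetical sketch: 4 conv layers -> global average pooling -> FC."""

    def __init__(self, in_channels: int = 4, num_classes: int = 2, widths=(32, 64, 128, 256)):
        super().__init__()
        blocks = []
        prev = in_channels
        for w in widths:  # one conv layer per width, four in total
            blocks += [
                nn.Conv2d(prev, w, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(w),
                nn.ReLU(),
            ]
            prev = w
        self.features = nn.Sequential(*blocks)
        self.gap = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Linear(prev, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)        # (N, widths[-1], H', W')
        x = self.gap(x).flatten(1)  # (N, widths[-1])
        return self.fc(x)           # (N, num_classes) logits
```

The concatenation itself is a single torch.cat([rgb, heatmap], dim=1) on tensors of shape (N, 3, H, W) and (N, 1, H, W). Global average pooling keeps the classifier independent of the patch size.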
Our SynthethicDetector class is a wrapper around the baseline and GCage models for easy evaluation and training. Relevant files are in synthethic_detectors/synthethic_detector.
The Grad-CAM implementation can be found in gradcam/.
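For reference, standard Grad-CAM weights each channel of a target layer's activations by the spatially averaged gradient of the class score, sums over channels, and applies a ReLU. The sketch below implements that recipe with a forward hook plus a tensor gradient hook. The function name grad_cam, the bilinear upsampling, and the per-image min-max normalization are illustrative choices and may differ from gradcam/.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model: nn.Module, target_layer: nn.Module, x: torch.Tensor, class_idx: int) -> torch.Tensor:
    """Return Grad-CAM maps of shape (N, H, W), min-max scaled to [0, 1] per image."""
    store = {}

    def fwd_hook(_module, _inputs, output):
        store["acts"] = output
        # Record d(score)/d(activations) when backward reaches this tensor
        output.register_hook(lambda g: store.__setitem__("grads", g))

    handle = target_layer.register_forward_hook(fwd_hook)
    try:
        model.zero_grad()
        logits = model(x)
        logits[:, class_idx].sum().backward()
    finally:
        handle.remove()

    acts, grads = store["acts"].detach(), store["grads"].detach()
    weights = grads.mean(dim=(2, 3), keepdim=True)           # alpha_k: GAP of gradients
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))  # ReLU(sum_k alpha_k * A^k)
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False).squeeze(1)
    flat = cam.flatten(1)
    lo = flat.min(dim=1).values.view(-1, 1, 1)
    hi = flat.max(dim=1).values.view(-1, 1, 1)
    return (cam - lo) / (hi - lo + 1e-8)
```

If the target layer is followed by an in-place activation (for example nn.ReLU(inplace=True)), the recorded activations can be overwritten before they are read, so non-in-place activations are the safer choice here.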
Visualizing the heatmap overlay for a single image
python gradcam/visualize.py \
--image_path <path_to_image> \
--model_path models/baseline_susy.pt \
--output <path_to_output>
The baseline SuSy model is saved in models/baseline_susy.pt.
Running the baseline evaluation with the SuSy model
python eval/baseline_eval.py \
--model_path models/baseline_susy.pt \
--data_path <path_to_data> \
--output_file <path_to_output_file>
Running our hyperparameter search for top_k patches
python eval/hyperparam_search.py \
--model_path models/baseline_susy.pt \
--data_path <path_to_data> \
--output_file <path_to_output_file>
Generating Grad-CAM heatmaps for all images in the dataset
python gradcam/generate_gradcam_maps.py \
--model_path models/baseline_susy.pt \
--data_dir <path_to_data_dir> \
--output_dir <path_to_output_dir>
Each image is divided into 50x50 patches, and a heatmap is generated for each patch.
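The patching step can be sketched with Tensor.unfold. This minimal sketch assumes non-overlapping 50x50 tiles and drops any leftover border that does not fill a whole patch; that border handling is an assumption, and split_into_patches is a hypothetical name. Each resulting patch could then be passed to a Grad-CAM routine on its own.

```python
import torch


def split_into_patches(img: torch.Tensor, patch_size: int = 50) -> torch.Tensor:
    """Split a (C, H, W) image into non-overlapping (num_patches, C, p, p) tiles.

    Rows and columns that do not fill a complete patch are dropped (an assumption).
    """
    c, h, w = img.shape
    rows, cols = h // patch_size, w // patch_size
    img = img[:, : rows * patch_size, : cols * patch_size]
    # unfold appends a window dimension: (C, rows, cols, p, p)
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # -> (rows * cols, C, p, p), ordered row by row
    return patches.permute(1, 2, 0, 3, 4).reshape(rows * cols, c, patch_size, patch_size)
```

For a 3x120x110 image this yields four patches (two rows by two columns); the remaining 20 rows and 10 columns are discarded.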
Training the GCage model
# the save path must end in .pt
python scripts/train_gcage.py \
--data_path <path_to_data> \
--val_path <path_to_val_data> \
--save_path <path_to_save_path>.pt \
--log_dir <path_to_log_dir> \
--epochs <num_epochs> \
--batch_size <batch_size> \
--lr <learning_rate>
The log_dir directory will contain the TensorBoard logs for the training run.
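A training loop that writes TensorBoard logs to log_dir might look like the following. It is a sketch under stated assumptions (Adam optimizer, cross-entropy over two classes, per-step training loss and per-epoch validation accuracy as the logged scalars), not the contents of scripts/train_gcage.py.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
          epochs: int, lr: float, log_dir: str, save_path: str, device: str = "cpu") -> float:
    """Train a classifier, log to TensorBoard, save the weights, return the last val accuracy."""
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # assumed optimizer
    loss_fn = nn.CrossEntropyLoss()
    writer = SummaryWriter(log_dir=log_dir)
    step, val_acc = 0, 0.0
    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            writer.add_scalar("train/loss", loss.item(), step)
            step += 1
        # Validation accuracy, logged once per epoch
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                pred = model(x.to(device)).argmax(dim=1).cpu()
                correct += (pred == y).sum().item()
                total += y.numel()
        val_acc = correct / max(total, 1)
        writer.add_scalar("val/accuracy", val_acc, epoch)
    writer.close()
    torch.save(model.state_dict(), save_path)  # save_path is expected to end in .pt
    return val_acc
```

The logged curves can then be viewed with tensorboard --logdir <path_to_log_dir>.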
Evaluating the GCage model
python eval/gcage_eval.py \
--model_path <path_to_model> \
--data_path <path_to_data> \
--output_file <path_to_output_file>
All results discussed in the paper can be found in the results directory. Our final trained GCage model is saved in models/gcage_final.pt.
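As an illustration of evaluating a binary real-vs-fake classifier, the sketch below collects softmax scores and reports accuracy, F1, and ROC AUC with scikit-learn. The metric choice, the 0.5 threshold, the class convention (index 1 = fake), and the evaluate function name are assumptions and may not match eval/gcage_eval.py or the numbers in the results directory.

```python
import torch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


@torch.no_grad()
def evaluate(model: torch.nn.Module, loader, device: str = "cpu") -> dict:
    """Return accuracy, F1, and ROC AUC for a 2-class model over a DataLoader."""
    model.eval().to(device)
    probs, labels = [], []
    for x, y in loader:
        logits = model(x.to(device))
        probs.append(torch.softmax(logits, dim=1)[:, 1].cpu())  # P(fake), assuming class 1 = fake
        labels.append(y)
    p = torch.cat(probs).numpy()
    t = torch.cat(labels).numpy()
    pred = (p >= 0.5).astype(int)  # assumed decision threshold
    return {
        "accuracy": accuracy_score(t, pred),
        "f1": f1_score(t, pred),
        "auroc": roc_auc_score(t, p),  # threshold-free; needs both classes present
    }
```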
