G-Cage: GradCAM-Augmented Generation Examiner

phucd5/gcage

GradCAM-Augmented Generation Examiner (G-CAGE)

Authors

David Dong, Phuc Duong, Steven Zhou

CPSC 4710/5710: Trustworthy Deep Learning

Yale University, Department of Computer Science

Overview

The increasing realism of AI-generated images has become a major challenge for modern media, and it is increasingly difficult for companies to verify the authenticity of content passing through their systems. Recent work has produced models that detect AI-generated content in order to distinguish real images from synthetic ones. However, these detectors often lack explainability, which limits their usefulness in real-world settings. We propose using Grad-CAM as a feature gate. Specifically, we apply Grad-CAM to our base detection model (SuSy) to generate heatmaps that highlight synthetic artifacts. We then concatenate each heatmap with its original image patch and train a lightweight convolutional network to make the final prediction. Our results show improved performance over the baseline while adding a layer of explainability.
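The concatenation step described above can be sketched as follows. This is a minimal illustration, assuming RGB patches of shape (B, 3, H, W) and single-channel heatmaps of shape (B, H, W) at the same resolution, normalized to [0, 1]. The function name build_gcage_input is hypothetical and not taken from the repository.

```python
import torch


def build_gcage_input(patch: torch.Tensor, heatmap: torch.Tensor) -> torch.Tensor:
    """Stack RGB patches (B, 3, H, W) with Grad-CAM heatmaps (B, H, W) or (B, 1, H, W)
    into a 4-channel tensor (B, 4, H, W) for the downstream CNN (hypothetical helper)."""
    if heatmap.dim() == 3:
        heatmap = heatmap.unsqueeze(1)  # add a channel dimension -> (B, 1, H, W)
    if heatmap.shape[-2:] != patch.shape[-2:]:
        raise ValueError("heatmap and patch must have the same spatial size")
    return torch.cat([patch, heatmap], dim=1)  # heatmap becomes the 4th channel


patches = torch.rand(8, 3, 50, 50)
heatmaps = torch.rand(8, 50, 50)
x = build_gcage_input(patches, heatmaps)
print(x.shape)  # torch.Size([8, 4, 50, 50])
```

Concatenating along the channel dimension lets the CNN learn how much weight to give the heatmap, rather than hard-masking the image with it.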

[Figure: G-Cage architecture]

Setup

Main Dependencies

  • torch 2.9.1
  • torchvision 0.24.1
  • datasets 4.4.1
  • scikit-learn 1.7.2
  • scipy 1.16.3
  • pillow 12.0.0
  • tensorboard 2.20.0

A full list of the dependencies can be found in requirements.txt.

To install the dependencies, run the following commands from the root directory, g-cage:

pip install -e .
pip install -r requirements.txt

Datasets

We used the WildFake dataset and ImageNet. Scripts to organize and format the data can be found in scripts/data_loading. Our post-processed dataset can be found here.

Relevant Files

  • utils/gcage_dataset.py: Contains the data loaders GCageTrainDataset and GCageEvalDataset. GCageTrainDataset has an extra column, heatmap_path, that holds the path to each image's heatmap.
  • utils/susy_dataset.py: Contains the data loader for evaluating the baseline SuSy model.
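For illustration, a dataset that pairs each image with its heatmap via a heatmap_path field might look like the sketch below. It is not the code in utils/gcage_dataset.py: the class name HeatmapPatchDataset, the record keys image_path and label, and the assumption that heatmaps are stored as .npy arrays with the same height and width as the image are all hypothetical.

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class HeatmapPatchDataset(Dataset):
    """Hypothetical dataset. Each record is a dict with image_path, heatmap_path, label.

    Assumes heatmaps are saved as .npy arrays with the same (H, W) as the image.
    """

    def __init__(self, records):
        self.records = list(records)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx]
        img = Image.open(rec["image_path"]).convert("RGB")
        # (H, W, 3) uint8 -> (3, H, W) float32 in [0, 1]
        image = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
        heat = torch.from_numpy(np.load(rec["heatmap_path"]).astype(np.float32))
        x = torch.cat([image, heat.unsqueeze(0)], dim=0)  # (4, H, W): RGB + heatmap
        y = torch.tensor(rec["label"], dtype=torch.float32)
        return x, y
```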

Model

For our baseline we used SuSy. We loaded the SuSy model from a JIT (TorchScript) checkpoint provided by its authors into a regular nn.Module, which is required for the GradCAM hooks. The implementation can be found in synthethic_detectors/susy.
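One common way to move weights from a TorchScript checkpoint into an eager module is to copy its state_dict. The sketch below assumes the eager module defines the same parameter and buffer names as the scripted model. The helper name load_jit_into_module is hypothetical, and synthethic_detectors/susy may do this differently.

```python
import torch
import torch.nn as nn


def load_jit_into_module(jit_path: str, module: nn.Module) -> nn.Module:
    """Copy weights from a TorchScript checkpoint into an eager nn.Module.

    Assumes `module` has exactly the same state_dict keys as the scripted model.
    """
    scripted = torch.jit.load(jit_path, map_location="cpu")
    # strict=True fails loudly if any parameter or buffer name does not line up
    module.load_state_dict(scripted.state_dict(), strict=True)
    return module.eval()
```

Once the weights live in a plain nn.Module, ordinary forward and backward hooks can be attached to any submodule, which is what Grad-CAM needs.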

For our GCage model we implemented a CNN with 4 conv layers, a global average pooling layer, and a fully connected layer as described in the paper. The implementation can be found in synthethic_detectors/cnn.
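A network with that layout (4 conv layers, global average pooling, one fully connected layer) could be written as below. This is a sketch, not the model in synthethic_detectors/cnn: the class name GCageCNN, the channel widths, the 3x3 kernels, the ReLU activations, and the single-logit output are illustrative assumptions. in_channels=4 assumes RGB plus one heatmap channel.

```python
import torch
import torch.nn as nn


class GCageCNN(nn.Module):
    """Hypothetical sketch: 4 conv layers -> global average pooling -> 1 FC layer.

    Channel widths are illustrative, not the values used in models/gcage_final.pt.
    """

    def __init__(self, in_channels: int = 4, widths=(32, 64, 128, 128)):
        super().__init__()
        layers = []
        prev = in_channels
        for w in widths:  # one conv + ReLU per entry -> 4 conv layers
            layers += [nn.Conv2d(prev, w, kernel_size=3, padding=1), nn.ReLU()]
            prev = w
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Linear(prev, 1)         # one logit: synthetic vs. real

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x)).flatten(1)  # (B, C)
        return self.fc(x).squeeze(1)                # (B,)
```

Global average pooling makes the classifier independent of the input's spatial size and keeps the parameter count small, which fits the "lightweight" goal.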

Our SynthethicDetector class is a wrapper around the baseline and GCage models for easy evaluation and training. Relevant files are in synthethic_detectors/synthethic_detector.

GradCAM

GradCAM implementation can be found in gradcam/.
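For reference, the standard Grad-CAM computation works like this. Weight each channel of a target layer's activations by the spatial mean of its gradient with respect to a class score, sum over channels, apply ReLU, and upsample to the input size. The sketch below uses a forward hook plus a tensor gradient hook. The function name grad_cam, the choice of target layer, and which class index means "synthetic" are assumptions, not the code in gradcam/.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model: nn.Module, target_layer: nn.Module, x: torch.Tensor,
             class_idx: int = 0) -> torch.Tensor:
    """Return Grad-CAM maps of shape (B, H, W), min-max scaled to [0, 1] per image."""
    store = {}

    def fwd_hook(_module, _inp, out):
        store["act"] = out
        # capture d(score)/d(activation) when backward reaches this tensor
        out.register_hook(lambda g: store.__setitem__("grad", g))

    handle = target_layer.register_forward_hook(fwd_hook)
    try:
        model.zero_grad()
        logits = model(x)
        if logits.dim() == 1:  # single-logit models: treat as one column
            logits = logits.unsqueeze(1)
        logits[:, class_idx].sum().backward()
    finally:
        handle.remove()

    act, grad = store["act"], store["grad"]            # (B, C, h, w)
    weights = grad.mean(dim=(2, 3), keepdim=True)       # per-channel importance
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam.squeeze(1)                                # (B, H, W)
    flat = cam.flatten(1)
    lo = flat.min(dim=1).values.view(-1, 1, 1)
    hi = flat.max(dim=1).values.view(-1, 1, 1)
    return ((cam - lo) / (hi - lo + 1e-8)).detach()
```

Avoid attaching this to a layer whose output is immediately modified by an in-place op such as nn.ReLU(inplace=True), since that can interfere with the gradient hook.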

Visualizing the heatmap overlay for a single image

python gradcam/visualize.py \
    --image_path <path_to_image> \
    --model_path models/baseline_susy.pt \
    --output <path_to_output>
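An overlay like the one visualize.py produces can be approximated by alpha-blending the heatmap onto the image. The sketch below is a simplification: it maps heat to a red tint, whereas visualization tools often use a colormap such as jet. The function name overlay_heatmap and the default alpha are hypothetical. It assumes a heatmap in [0, 1] with the same (H, W) as the image.

```python
import numpy as np
from PIL import Image


def overlay_heatmap(image: Image.Image, heatmap: np.ndarray, alpha: float = 0.4) -> Image.Image:
    """Blend a [0, 1] heatmap of shape (H, W) onto an RGB image as a red tint."""
    img = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    if heatmap.shape != img.shape[:2]:
        raise ValueError("heatmap must match the image's (H, W)")
    heat_rgb = np.zeros_like(img)
    heat_rgb[..., 0] = np.clip(heatmap, 0.0, 1.0)  # hot regions -> red channel
    blended = (1.0 - alpha) * img + alpha * heat_rgb
    return Image.fromarray((blended * 255.0).round().astype(np.uint8))
```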

Baseline

The baseline SuSy model is saved in models/baseline_susy.pt.

Running the baseline evaluation with the SuSy model

python eval/baseline_eval.py \
    --model_path models/baseline_susy.pt \
    --data_path <path_to_data> \
    --output_file <path_to_output_file>

Running our hyperparameter search for top_k patches

python eval/hyperparam_search.py \
    --model_path models/baseline_susy.pt \
    --data_path <path_to_data> \
    --output_file <path_to_output_file>
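A top_k patch search can be pictured as follows. Rank an image's patches, average the detector's probabilities over the k best, and measure accuracy for each candidate k. This README does not say how patches are ranked in eval/hyperparam_search.py, so the ranking is left as an input (patch_scores). The function names, the mean aggregation, and the 0.5 threshold are assumptions.

```python
import torch


def aggregate_top_k(patch_probs: torch.Tensor, patch_scores: torch.Tensor, k: int) -> torch.Tensor:
    """Average synthetic-probabilities over the k highest-ranked patches of one image.

    patch_probs, patch_scores: (N,) tensors for the image's N patches.
    """
    k = min(k, patch_probs.numel())  # never ask for more patches than exist
    idx = torch.topk(patch_scores, k).indices
    return patch_probs[idx].mean()


def sweep_top_k(images, labels, ks, threshold: float = 0.5) -> dict:
    """Accuracy for each k. `images` is a list of (patch_probs, patch_scores) pairs."""
    results = {}
    for k in ks:
        preds = [float(aggregate_top_k(p, s, k)) >= threshold for p, s in images]
        correct = sum(int(pred == bool(y)) for pred, y in zip(preds, labels))
        results[k] = correct / len(labels)
    return results
```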

GCage

Retrieving Heatmaps

To generate GradCAM heatmaps for all images in the dataset:

python gradcam/generate_gradcam_maps.py \
    --model_path models/baseline_susy.pt \
    --data_dir <path_to_data_dir> \
    --output_dir <path_to_output_dir>

Each image is divided into 50x50 patches, and a heatmap is generated for each patch.
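Splitting an image into non-overlapping 50x50 patches can be done with Tensor.unfold, as in the sketch below. It assumes a (C, H, W) tensor and drops trailing rows or columns that do not fill a whole patch. That edge handling is an assumption; generate_gradcam_maps.py may pad or resize instead. The function name to_patches is hypothetical.

```python
import torch


def to_patches(image: torch.Tensor, size: int = 50) -> torch.Tensor:
    """Split a (C, H, W) tensor into non-overlapping size x size patches.

    Returns (N, C, size, size) in row-major order; leftover border pixels are dropped.
    """
    c = image.shape[0]
    # unfold H then W: (C, nH, nW, size, size)
    patches = image.unfold(1, size, size).unfold(2, size, size)
    return patches.permute(1, 2, 0, 3, 4).reshape(-1, c, size, size)
```

For example, a 3x100x150 image yields 2 x 3 = 6 patches, and a 3x120x120 image yields 4 (the last 20 rows and columns are dropped).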

Training

Training the GCage model

# the save path must end in .pt
python scripts/train_gcage.py \
    --data_path <path_to_data> \
    --val_path <path_to_val_data> \
    --save_path <path_to_save_path>.pt \
    --log_dir <path_to_log_dir> \
    --epochs <num_epochs> \
    --batch_size <batch_size> \
    --lr <learning_rate>

The --log_dir directory will contain the TensorBoard logs for the training run. They can be viewed with tensorboard --logdir <path_to_log_dir>.
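A training epoch that also writes TensorBoard scalars could be sketched as below. It assumes a model that returns one logit per sample, binary labels (1 = synthetic, 0 = real), and BCEWithLogitsLoss. None of these choices is confirmed for scripts/train_gcage.py. The writer argument is meant to be a torch.utils.tensorboard.SummaryWriter(log_dir=...), whose add_scalar(tag, value, global_step) method is used here. The tag name "train/loss" and the function name are hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train_one_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer,
                    epoch: int, writer=None, device: str = "cpu") -> float:
    """One pass over `loader`; returns the sample-weighted mean training loss."""
    criterion = nn.BCEWithLogitsLoss()  # assumed loss for a single-logit detector
    model.train()
    total, n = 0.0, 0
    for step, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device).float()
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total += loss.item() * x.size(0)
        n += x.size(0)
        if writer is not None:  # e.g. SummaryWriter(log_dir=...)
            writer.add_scalar("train/loss", loss.item(), epoch * len(loader) + step)
    return total / n
```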

Evaluation

Evaluating the GCage model

python eval/gcage_eval.py \
    --model_path <path_to_model> \
    --data_path <path_to_data> \
    --output_file <path_to_output_file>
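Given per-image labels and predicted synthetic-probabilities, standard binary-detection metrics can be computed with scikit-learn (a listed dependency). Which metrics eval/gcage_eval.py actually reports is not stated here. The selection below (accuracy, F1, ROC AUC), the 0.5 threshold, and the function name summarize are assumptions.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def summarize(labels, probs, threshold: float = 0.5) -> dict:
    """Binary-detection metrics from per-image probabilities (1 = synthetic)."""
    labels = np.asarray(labels).astype(int)
    probs = np.asarray(probs, dtype=float)
    preds = (probs >= threshold).astype(int)  # hard decisions for accuracy / F1
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds),
        "auc": roc_auc_score(labels, probs),  # threshold-free ranking quality
    }
```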

Results

All results discussed in the paper can be found in the results directory. Our final trained GCage model is saved in models/gcage_final.pt.
