ProtoPathway

Interpretable-by-design multimodal cancer survival prediction

ProtoPathway fuses whole slide image (WSI) morphology with bulk transcriptomics through semantically grounded representations. A bipartite graph neural network encodes gene expression over a Reactome and MSigDB Hallmark pathway graph, a prototype-based MIL encoder compresses gigapixel slides into a fixed set of learned morphological tokens, and asymmetric cross-attention lets the prototypes query the pathways. Every component is interpretable: the gene encoder exposes gene-pathway attention, the WSI encoder exposes patch-prototype assignments, and the fusion stage exposes a prototype-pathway attention matrix.

Highlights

  • Compact and fast. 480K parameters, 3.9G FLOPs, 13.6 ms per patient. Between 28 and 50 times faster than attention-based multimodal baselines (MCAT, SurvPath, MMP) thanks to the K=16 prototype bottleneck.
  • Strong performance. A mean C-index of 0.670 across five TCGA cohorts, ahead of MCAT (0.662), SurvPath (0.660), and MMP (0.659).
  • Interpretable. Every stage exposes a structured attention signal: gene-pathway attention, pathway gates, patch-prototype assignments, and a prototype-pathway cross-modal matrix.
  • Validated on five TCGA cohorts. BRCA (N=714), BLCA (N=359), COADREAD (N=227), HNSC (N=392), and STAD (N=318), for a total of N=2,010 patients.
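As a quick arithmetic check, the per-cohort patient counts listed above add up to the stated total of 2,010:

```python
# Per-cohort patient counts as listed in the Highlights.
cohort_sizes = {
    "TCGA-BRCA": 714,
    "TCGA-BLCA": 359,
    "TCGA-COADREAD": 227,
    "TCGA-HNSC": 392,
    "TCGA-STAD": 318,
}

total = sum(cohort_sizes.values())
print(f"Total patients: {total}")  # Total patients: 2010
```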

Architecture

ProtoPathway has three components:

Gene encoder. A bipartite graph over 662 pathways and 4,574 genes (17,275 edges) drawn from Reactome and MSigDB Hallmark gene sets. Early layers use GraphSAGE with mean aggregation for stability under noisy survival supervision. The final layer is GATv2, which yields interpretable gene-pathway attention weights.
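To make the gene encoder's first stage concrete, here is a minimal NumPy sketch (the repository itself uses PyTorch) that builds gene-to-pathway edges from gene sets and applies one GraphSAGE-style mean-aggregation step to the pathway nodes. The gene-set names, feature sizes, and the helper `sage_mean_layer` are hypothetical. Messages flow only from genes to pathways in this sketch, and the final GATv2 attention layer is not shown, so the actual encoder may differ.

```python
import numpy as np

# Toy gene sets with placeholder names (hypothetical); the real graph
# links 4,574 genes to 662 Reactome and Hallmark pathways.
pathway_genes = {
    "PATHWAY_A": ["GENE_1", "GENE_2", "GENE_3"],
    "PATHWAY_B": ["GENE_2", "GENE_3", "GENE_4"],
}

# Index genes and pathways, then list one (gene, pathway) edge per membership.
genes = sorted({g for members in pathway_genes.values() for g in members})
gene_idx = {g: i for i, g in enumerate(genes)}
pathways = list(pathway_genes)
edges = np.array(
    [(gene_idx[g], p) for p, name in enumerate(pathways) for g in pathway_genes[name]]
).T  # shape (2, num_edges): row 0 = gene index, row 1 = pathway index


def sage_mean_layer(gene_feats, pathway_feats, edges, w_self, w_neigh):
    """GraphSAGE-style pathway update: h_p @ W_self + mean(h_g for g in p) @ W_neigh."""
    num_pathways = pathway_feats.shape[0]
    summed = np.zeros((num_pathways, gene_feats.shape[1]))
    np.add.at(summed, edges[1], gene_feats[edges[0]])  # sum gene messages per pathway
    counts = np.bincount(edges[1], minlength=num_pathways)[:, None]
    neigh_mean = summed / np.maximum(counts, 1)         # mean aggregation
    return pathway_feats @ w_self + neigh_mean @ w_neigh


rng = np.random.default_rng(0)
d_in, d_out = 1, 4
gene_feats = rng.normal(size=(len(genes), d_in))  # e.g. one expression value per gene
pathway_feats = np.zeros((len(pathways), d_in))   # pathways start without features here
out = sage_mean_layer(gene_feats, pathway_feats, edges,
                      rng.normal(size=(d_in, d_out)), rng.normal(size=(d_in, d_out)))
print(out.shape)  # (2, 4)
```

Mean aggregation keeps each pathway's message on the same scale regardless of how many genes it contains, which is one reason it tends to be stable under noisy supervision.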

WSI encoder. PrototypeMIL with K=16 learned morphological prototypes. Patch features are softly assigned to prototypes by cosine similarity at temperature τ=0.1, then aggregated into a fixed set of K token embeddings.
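The soft assignment can be sketched as follows, assuming a softmax over the K prototypes of the temperature-scaled cosine similarities and a weighted mean to form each prototype token. Both the aggregation rule and the helper name `prototype_pool` are assumptions for illustration, and the 64-dimensional toy features stand in for real UNI2-h embeddings.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-8):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def prototype_pool(patches, prototypes, tau=0.1):
    """Soft-assign N patches to K prototypes and pool them into K tokens.

    patches:    (N, d) patch features for one slide
    prototypes: (K, d) learned prototype vectors
    returns:    assignments (N, K) and tokens (K, d)
    """
    sim = l2_normalize(patches) @ l2_normalize(prototypes).T  # cosine similarity, (N, K)
    assign = softmax(sim / tau, axis=1)                         # each patch spreads mass over K
    # Assumed aggregation: weighted mean of the patches assigned to each prototype.
    weights = assign / (assign.sum(axis=0, keepdims=True) + 1e-8)
    tokens = weights.T @ patches                                 # (K, d)
    return assign, tokens


rng = np.random.default_rng(0)
patches = rng.normal(size=(5000, 64))    # a toy slide with 5,000 patches
prototypes = rng.normal(size=(16, 64))   # K = 16
assign, tokens = prototype_pool(patches, prototypes, tau=0.1)
print(assign.shape, tokens.shape)        # (5000, 16) (16, 64)
```

However many patches a slide has, the output is always K tokens, which is the bottleneck credited for the speed figures in the Highlights. A low temperature such as τ=0.1 sharpens the softmax so each patch commits mostly to its nearest prototypes.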

Fusion. Asymmetric cross-attention where the prototypes query the pathways (A ∈ R^{K×P}). A three-gate combination merges a pathway-only stream, a raw-prototype stream, and a cross-attended-prototype stream before the survival head.
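A single-head sketch of the fusion stage, assuming standard scaled dot-product attention with the K prototypes as queries and the P pathway embeddings as keys and values, so each row of A is a distribution over pathways. The gating rule in `three_gate_fusion` (mean-pool each stream, then mix them with softmax weights) is a hypothetical stand-in for the model's three-gate combination, not its documented form.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def prototype_pathway_attention(protos, pathways, w_q, w_k, w_v):
    """Prototypes (K, d) query pathways (P, d); returns A (K, P) and attended prototypes (K, d)."""
    q, k, v = protos @ w_q, pathways @ w_k, pathways @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)  # A in R^{K x P}; rows sum to 1
    return attn, attn @ v


def three_gate_fusion(pathway_stream, proto_stream, cross_stream, w_gate):
    """Hypothetical gating: mean-pool each stream, then mix with softmax gates."""
    pooled = np.stack([pathway_stream.mean(0), proto_stream.mean(0), cross_stream.mean(0)])  # (3, d)
    gates = softmax(pooled.reshape(-1) @ w_gate)  # (3,), one weight per stream
    return gates, gates @ pooled                  # fused (d,) vector for the survival head


rng = np.random.default_rng(0)
K, P, d = 16, 662, 32
protos, pathways = rng.normal(size=(K, d)), rng.normal(size=(P, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
A, cross = prototype_pathway_attention(protos, pathways, w_q, w_k, w_v)
gates, fused = three_gate_fusion(pathways, protos, cross, rng.normal(size=(3 * d, 3)) * 0.01)
print(A.shape, fused.shape)  # (16, 662) (32,)
```

Because only K=16 queries attend over P pathways, the attention matrix is small (K×P rather than patches×pathways), and each of its rows can be read directly as "which pathways this morphological prototype looks at".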

Figure: ProtoPathway pipeline.

Installation

First clone the repository to the desired location and enter the directory:

# clone project to desired location
git clone https://github.com/AmayaGS/ProtoPathway
cd ProtoPathway

Then create a virtual environment and install the dependencies listed in requirements.txt.

General Requirements

  • Python 3.11.7
  • PyTorch 2.5
  • NVIDIA GPU with CUDA 12.4
# Virtual Environment
python -m venv protopath
source protopath/bin/activate

# PyTorch built for CUDA 12.4
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124

pip install -r requirements.txt

Data

All preprocessed cohort data, trained checkpoints, the curated pathway graph, and the raw inputs used for preprocessing are mirrored on Hugging Face:

AmayaGS/ProtoPathway

pathways/pathways_base_*.pkl           curated Reactome + Hallmark pathway graph
raw_inputs/                            raw files for re-running preprocessing
    Reactome/                          Reactome hierarchy files
    Hallmark/                          MSigDB Hallmark gene sets
    {cohort}/                          rna_clean.csv, clinical CSV, SurvPath splits
cohorts/{cohort}/                      preprocessed data and trained models
    gene_expression.csv                preprocessed expression matrix
    bipartite_graph.pt                 cohort-specific gene-pathway graph
    labels.csv                         survival times, events, and discretized bins
    data_splits.pkl                    5-fold CV splits (SurvPath-compatible)
    checkpoints/best_fold_{0..4}.pt    trained model weights
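The discretized bins in labels.csv turn survival prediction into classification over time intervals. One common scheme, assumed here for illustration, places the cut points at quantiles of the survival times of patients with an observed event. The number of bins and the helper `discretize_survival` are hypothetical, and the published labels should be used as provided.

```python
import numpy as np


def discretize_survival(times, events, n_bins=4):
    """Assign each patient a 0-based bin index (assumed quantile scheme).

    Interior cut points are quantiles of survival times among patients
    with an observed event; censored patients are binned with the same cuts.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    qs = np.linspace(0, 1, n_bins + 1)[1:-1]  # e.g. [0.25, 0.5, 0.75] for 4 bins
    cuts = np.quantile(times[events], qs)
    # np.digitize maps t < cuts[0] -> 0, ..., t >= cuts[-1] -> n_bins - 1.
    return np.digitize(times, cuts)


times = [1, 2, 3, 4, 5, 6, 7, 8, 100]  # toy survival times (e.g. months)
events = [1, 1, 1, 1, 1, 1, 1, 1, 0]   # 1 = event observed, 0 = censored
bins = discretize_survival(times, events, n_bins=4)
print(bins.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3, 3]
```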

Quick download (one cohort)

from huggingface_hub import snapshot_download
 
snapshot_download(
    repo_id="AmayaGS/ProtoPathway",
    local_dir="./assets",
    allow_patterns=["cohorts/TCGA-BLCA/*", "pathways/pathways_base_*"],
)

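This downloads everything needed to evaluate the BLCA checkpoints except the WSI patch features (see WSI patch features below). For another cohort, replace TCGA-BLCA in allow_patterns with any of TCGA-BRCA, TCGA-COADREAD, TCGA-HNSC, or TCGA-STAD.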
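To fetch several cohorts in one call, the same `snapshot_download` arguments can be generated per cohort. The helpers `allow_patterns_for` and `download_cohorts` are hypothetical conveniences, not part of the repository:

```python
COHORTS = ["TCGA-BRCA", "TCGA-BLCA", "TCGA-COADREAD", "TCGA-HNSC", "TCGA-STAD"]


def allow_patterns_for(cohorts):
    """Build snapshot_download patterns: each cohort folder plus the shared pathway graph."""
    unknown = set(cohorts) - set(COHORTS)
    if unknown:
        raise ValueError(f"Unknown cohort(s): {sorted(unknown)}")
    return [f"cohorts/{c}/*" for c in cohorts] + ["pathways/pathways_base_*"]


def download_cohorts(cohorts, local_dir="./assets"):
    # Imported lazily so the pattern helper works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id="AmayaGS/ProtoPathway",
        local_dir=local_dir,
        allow_patterns=allow_patterns_for(cohorts),
    )
```

For example, `download_cohorts(["TCGA-BRCA", "TCGA-STAD"])` would download both cohort folders plus the shared pathway graph into ./assets.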

WSI patch features

UNI2-h patch features are the only assets not redistributed here. Request access to the UNI2-h model from the Mahmood Lab and use it to extract per-WSI patch features for the TCGA slides in your chosen cohorts. The slide-level feature files (one .pt per WSI) are located through the manifest path set in configs/data.yaml.

Optional: re-running preprocessing from scratch

If you want different gene-set filters, alternative pathway sources, or custom splits, you can rebuild everything from the raw inputs. All raw files used for the published checkpoints are mirrored on the Hub under raw_inputs/, in the layout the preprocessing configs expect:

huggingface-cli download AmayaGS/ProtoPathway \
    --include "raw_inputs/Reactome/*" "raw_inputs/Hallmark/*" "raw_inputs/TCGA-BLCA/*" \
    --local-dir ./data

Set paths.base_data_dir in configs/preprocessing/*.yaml to ./data/raw_inputs, then run:

python main.py preprocess all

Quick start

All commands are run through the main.py entry point.

Use pretrained checkpoints

To evaluate or visualize the published model without re-running preprocessing or training:

  1. Download the cohort and checkpoints from the Hub:
   from huggingface_hub import snapshot_download

   snapshot_download(
       repo_id="AmayaGS/ProtoPathway",
       local_dir="./assets",
       allow_patterns=["cohorts/TCGA-BLCA/*", "pathways/pathways_base_*"],
   )
  2. Extract UNI2-h patch features for the TCGA-BLCA slides (see WSI patch features above). This is the only step that cannot be skipped, since the patch features are not redistributed.

  3. Point configs/experiments/experiment.yaml at your local copies of ./assets/cohorts/TCGA-BLCA/, ./assets/pathways/, and the UNI2-h features.

  4. Evaluate and visualize:

   python main.py evaluate --checkpoint-dir ./assets/cohorts/TCGA-BLCA/checkpoints
   python main.py visualize --eval-dir ./assets/cohorts/TCGA-BLCA/checkpoints/evaluation
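Before running evaluation, it can help to confirm that the download matches the Hub layout described under Data. The helper `missing_assets` below is a hypothetical sketch that only checks file names from that layout (five folds, best_fold_0.pt to best_fold_4.pt). It does not check the UNI2-h features, whose location depends on your configuration.

```python
from pathlib import Path


def missing_assets(assets_dir="./assets", cohort="TCGA-BLCA", n_folds=5):
    """Return expected files (per the Hub layout) that are not present locally."""
    root = Path(assets_dir)
    cohort_dir = root / "cohorts" / cohort
    expected = [
        cohort_dir / "gene_expression.csv",
        cohort_dir / "bipartite_graph.pt",
        cohort_dir / "labels.csv",
        cohort_dir / "data_splits.pkl",
    ] + [cohort_dir / "checkpoints" / f"best_fold_{i}.pt" for i in range(n_folds)]
    missing = [str(p) for p in expected if not p.exists()]

    # The pathway graph file name varies, so match it with a glob pattern.
    pathway_dir = root / "pathways"
    if not (pathway_dir.is_dir() and any(pathway_dir.glob("pathways_base_*.pkl"))):
        missing.append(str(pathway_dir / "pathways_base_*.pkl"))
    return missing


missing = missing_assets("./assets", "TCGA-BLCA")
print("All assets present" if not missing else "\n".join(missing))
```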

The numbered steps below cover the full pipeline if you want to retrain or change the preprocessing.

1. Edit the configs

The paths block at the top of each YAML file in configs/ controls where data is read from and where outputs are written:

paths:
  base_data_dir: /path/to/data       # raw or preprocessed cohort data
  output_dir:    /path/to/results    # training and evaluation outputs

The preprocessing configs (configs/preprocessing/*.yaml) read raw inputs from base_data_dir. The experiment config (configs/experiments/experiment.yaml) reads the preprocessed inputs and writes checkpoints, predictions, and attention exports under output_dir/{cohort}/{experiment_name}/.

2. Preprocess

# One-time: curate Reactome + Hallmark pathways
python main.py preprocess pathways --config configs/preprocessing/preprocess_pathways.yaml
 
# Per cohort: gene expression, WSI features, and splits
python main.py preprocess genes  --config configs/preprocessing/preprocess_genes.yaml  dataset=TCGA-BLCA
python main.py preprocess wsi    --config configs/preprocessing/preprocess_wsi.yaml    dataset=TCGA-BLCA
python main.py preprocess splits --config configs/preprocessing/create_splits.yaml     dataset=TCGA-BLCA

Or run all four steps at once with the default config locations:

python main.py preprocess all

3. Train

python main.py train --config configs/experiments/experiment.yaml

CLI overrides use OmegaConf dot syntax:

python main.py train --config configs/experiments/experiment.yaml \
    dataset=TCGA-BRCA \
    model.fusion.type=cross_attention \
    training.learning_rate=1e-5
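The dot syntax addresses nested keys: `model.fusion.type=cross_attention` sets `type` inside `fusion` inside `model`. OmegaConf provides this through `OmegaConf.from_dotlist` combined with `OmegaConf.merge`; the pure-Python sketch below only mimics that behaviour to show the idea. The base dictionary holds placeholder values, not the contents of experiment.yaml, and the value parsing is simplified compared to OmegaConf's.

```python
import copy


def parse_value(text):
    """Rough stand-in for OmegaConf's value parsing: bool, int, float, else string."""
    if text.lower() in ("true", "false"):
        return text.lower() == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' overrides to a nested dict, returning a new dict."""
    cfg = copy.deepcopy(config)  # leave the base config untouched
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for key in parents:      # walk (or create) the nested sections
            node = node.setdefault(key, {})
        node[leaf] = parse_value(raw)
    return cfg


# Placeholder defaults (hypothetical); the real experiment.yaml will differ.
base = {
    "dataset": "TCGA-BLCA",
    "model": {"fusion": {"type": "placeholder"}},
    "training": {"learning_rate": 1e-4},
}
cfg = apply_overrides(base, [
    "dataset=TCGA-BRCA",
    "model.fusion.type=cross_attention",
    "training.learning_rate=1e-5",
])
print(cfg)
```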

Baselines share the same entry point:

# Unimodal WSI
python main.py train --config configs/experiments/experiment.yaml model.name=abmil
python main.py train --config configs/experiments/experiment.yaml model.name=transmil
 
# Unimodal gene expression
python main.py train --config configs/experiments/experiment.yaml model.name=snn
 
# Multimodal baselines
python main.py train --config configs/experiments/experiment.yaml model.name=mcat
python main.py train --config configs/experiments/experiment.yaml model.name=survpath
python main.py train --config configs/experiments/experiment.yaml model.name=mmp

4. Evaluate

python main.py evaluate --checkpoint-dir results/TCGA-BLCA/<experiment_name>

This loads every fold checkpoint, computes the C-index, saves patient-level predictions, and exports attention weights for interpretability.
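For reference, a minimal implementation of Harrell's concordance index, the C-index reported above. It assumes a higher risk score means a worse prognosis and treats pairs tied in time as non-comparable. The repository may use a library implementation with different tie handling, so small numerical differences are possible.

```python
def harrell_c_index(times, events, risks):
    """Harrell's concordance index.

    A pair (i, j) is comparable when patient i has an observed event and
    times[i] < times[j]. It is concordant when i has the higher predicted
    risk; tied risks count as half-concordant.
    """
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # censored patients cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    return concordant / comparable if comparable else float("nan")


times, events = [2, 4, 6, 8], [1, 1, 0, 1]
print(harrell_c_index(times, events, [0.9, 0.7, 0.4, 0.1]))  # 1.0
print(harrell_c_index(times, events, [0.9, 0.4, 0.7, 0.1]))  # 0.8
```

A C-index of 0.5 corresponds to random ordering and 1.0 to perfect ranking, which puts the 0.659 to 0.670 range in the Highlights in context.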

5. Visualize

python main.py visualize --eval-dir results/TCGA-BLCA/<experiment_name>/evaluation
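The visualization utilities include Kaplan-Meier curves. As a sketch of what such a curve encodes, `kaplan_meier` below computes the survival step function, and `median_split` shows one common way (assumed here, not confirmed by the repository) of stratifying patients into high- and low-risk groups by predicted risk before plotting one curve per group:

```python
import statistics


def kaplan_meier(times, events):
    """Return (event_times, survival_probabilities) for the Kaplan-Meier estimator."""
    data = sorted(zip(times, events))
    n_at_risk = len(data)
    surv, curve_t, curve_s = 1.0, [], []
    i = 0
    while i < len(data):
        t = data[i][0]
        deaths = removed = 0
        # Group every patient sharing this time point (events and censorings).
        while i < len(data) and data[i][0] == t:
            deaths += data[i][1]
            removed += 1
            i += 1
        if deaths:
            surv *= 1.0 - deaths / n_at_risk  # product-limit step
            curve_t.append(t)
            curve_s.append(surv)
        n_at_risk -= removed
    return curve_t, curve_s


def median_split(risks):
    """Label each patient 'high' (risk above the median) or 'low' risk."""
    cutoff = statistics.median(risks)
    return ["high" if r > cutoff else "low" for r in risks]


t, s = kaplan_meier([1, 2, 2, 3, 4, 5], [1, 1, 0, 1, 0, 1])
print(t, [round(x, 3) for x in s])  # [1, 2, 3, 5] [0.833, 0.667, 0.444, 0.0]
```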

To include spatial overlays on WSIs:

python main.py visualize \
    --eval-dir results/TCGA-BLCA/<experiment_name>/evaluation \
    --wsi-features-dir processed/TCGA-BLCA/wsi_features_per_patient \
    --wsi-dir /path/to/svs_files \
    --fold 1

For a single patient:

python main.py visualize \
    --eval-dir results/TCGA-BLCA/<experiment_name>/evaluation \
    --wsi-features-dir processed/TCGA-BLCA/wsi_features_per_patient \
    --patient TCGA-FD-A3B4 \
    --fold 1

6. Profile efficiency

python main.py profile --config configs/experiments/experiment.yaml --num-patients 30

Reports parameter count, FLOPs, peak VRAM, and training and inference time per patient.
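If you want to time a model outside the profile command, a simple per-patient latency harness looks like the sketch below. `time_per_patient` and `dummy_model` are hypothetical. For a CUDA model, call `torch.cuda.synchronize()` around each timed call, because GPU kernels run asynchronously and the timer would otherwise stop before the work finishes.

```python
import statistics
import time


def time_per_patient(predict, patients, warmup=3):
    """Return mean and standard deviation of per-patient latency in milliseconds."""
    for patient in patients[:warmup]:  # warm-up runs are not timed
        predict(patient)
    latencies = []
    for patient in patients:
        # For a CUDA model, call torch.cuda.synchronize() before reading the clock here...
        start = time.perf_counter()
        predict(patient)
        # ...and again here, so queued GPU work is included in the measurement.
        latencies.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(latencies), statistics.stdev(latencies)


def dummy_model(patient):
    # Hypothetical stand-in for one forward pass; replace with the real model call.
    return sum(i * i for i in range(10_000))


mean_ms, sd_ms = time_per_patient(dummy_model, list(range(30)))
print(f"{mean_ms:.3f} ± {sd_ms:.3f} ms per patient")
```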


Project structure

ProtoPathway/
├── configs/
│   ├── preprocessing/         pathway, gene, WSI, and split preprocessing
│   └── experiments/           training and evaluation
├── preprocessing/             data preparation pipelines
├── models/
│   ├── protopath.py           the main model
│   ├── protopath_components/  gene encoder, WSI encoder, fusion
│   ├── baselines/             ABMIL, TransMIL, SNN, MCAT, MMP, PIBD, ...
│   └── factory.py             unified model builder
├── experiments/
│   ├── train.py
│   ├── evaluate.py
│   └── visualize.py           interpretability pipeline
├── utils/
│   ├── analysis/              cross-fold pooling, rank-based statistics
│   ├── visualization/         KM curves, heatmaps, spatial overlays
│   └── upload_to_hf.py        push checkpoints and assets to the Hub
└── main.py                    single CLI entry point

Citation

If you use this code in your research, please cite:

@article{protopathway2026,
  title   = {ProtoPathway: Biologically Structured Prototype-Pathway Fusion for Multimodal Cancer Survival Prediction},
  author  = {Amaya Gallagher-Syed and Costantino Pitzalis and Myles J. Lewis and Michael R. Barnes and Gregory Slabaugh},
  journal = {arXiv preprint arXiv:2605.21454},
  year    = {2026},
}

Acknowledgments

License

This project is released under the MIT License. See LICENSE for details.
