From fffa0ae4a8aa54d2ca0f6e57b7ed36e0e2ad99fe Mon Sep 17 00:00:00 2001 From: JudithBernett Date: Fri, 6 Mar 2026 14:54:50 +0100 Subject: [PATCH 1/4] First pitch of the PPI graph GNN. Also fixing a bug in drug gnn --- drevalpy/cli.py | 3 + .../datasets/featurizer/create_ppi_graphs.py | 158 ++++ drevalpy/models/DrugGNN/drug_gnn.py | 71 +- drevalpy/models/PPIGraphGNN/__init__.py | 5 + .../models/PPIGraphGNN/hyperparameters.yaml | 14 + drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py | 704 ++++++++++++++++++ drevalpy/models/__init__.py | 3 + 7 files changed, 931 insertions(+), 27 deletions(-) create mode 100644 drevalpy/datasets/featurizer/create_ppi_graphs.py create mode 100644 drevalpy/models/PPIGraphGNN/__init__.py create mode 100644 drevalpy/models/PPIGraphGNN/hyperparameters.yaml create mode 100644 drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py diff --git a/drevalpy/cli.py b/drevalpy/cli.py index 5602f3bd..db2bdf67 100644 --- a/drevalpy/cli.py +++ b/drevalpy/cli.py @@ -7,3 +7,6 @@ def cli_main(): """Command line interface entry point for the drug response evaluation pipeline.""" args = get_parser().parse_args() main(args) + +if __name__ == "__main__": + cli_main() \ No newline at end of file diff --git a/drevalpy/datasets/featurizer/create_ppi_graphs.py b/drevalpy/datasets/featurizer/create_ppi_graphs.py new file mode 100644 index 00000000..09e8ee80 --- /dev/null +++ b/drevalpy/datasets/featurizer/create_ppi_graphs.py @@ -0,0 +1,158 @@ +""" +Preprocesses PPI network CSV files into graph representations for PPIGraphGNN. + +This script takes a dataset name as input, reads the corresponding +PPI network CSV file, and converts it into a torch_geometric.data.Data object. +The PPI CSV should have columns: gene_id_1, gene_id_2, and optionally interaction_score. +""" + +import argparse +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from torch_geometric.data import Data + + +def _load_ppi_network(ppi_file: Path, gene_list_file: Path) -> Data: + """ + Load PPI network from CSV and create a PyTorch Geometric Data object. + + The gene order in the PPI graph will match the order in the gene list file. + This ensures consistency when gene expression features are set as node features. + + :param ppi_file: Path to the PPI network CSV file with columns [gene_id_1, gene_id_2, (optional) interaction_score] + :param gene_list_file: Path to the gene list CSV (e.g., landmark_genes_reduced.csv) that defines gene order + :return: A Data object representing the PPI network graph + """ + # Load the gene list to get the ordered list of genes (same as will be used for gene expression) + gene_list_df = pd.read_csv(gene_list_file) + if "Symbol" in gene_list_df.columns: + genes = gene_list_df["Symbol"].tolist() + elif "gene" in gene_list_df.columns: + genes = gene_list_df["gene"].tolist() + else: + # Gene expression file -> columns are genes + genes = gene_list_df.columns + genes = [g for g in genes if g not in ["cellosaurus_id", "cell_line_name"]] + + # Create a mapping from gene name to index + gene_to_idx = {gene: idx for idx, gene in enumerate(genes)} + + # Load PPI network + ppi_df = pd.read_csv(ppi_file) + + # Validate columns + required_cols = {"gene_id_1", "gene_id_2"} + if not required_cols.issubset(ppi_df.columns): + raise ValueError( + f"PPI CSV must contain columns 'gene_id_1' and 'gene_id_2'. Found: {ppi_df.columns.tolist()}" + ) + + # Build edge list (only include genes that exist in gene expression) + edge_list = [] + edge_weights = [] + + has_weights = "interaction_score" in ppi_df.columns + + for _, row in ppi_df.iterrows(): + gene1 = str(row["gene_id_1"]) + gene2 = str(row["gene_id_2"]) + + # Only add edge if both genes exist in gene expression data + if gene1 in gene_to_idx and gene2 in gene_to_idx: + idx1 = gene_to_idx[gene1] + idx2 = gene_to_idx[gene2] + + # Add both directions for undirected graph + edge_list.append([idx1, idx2]) + edge_list.append([idx2, idx1]) + + if has_weights: + weight = float(row["interaction_score"]) + edge_weights.extend([weight, weight]) + + if not edge_list: + raise ValueError("No valid edges found in PPI network (genes don't match gene expression)") + + # Convert to tensors + edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous() + + # Create node feature placeholder (will be filled with gene expression at runtime) + num_nodes = len(genes) + x = torch.zeros((num_nodes, 1), dtype=torch.float) + + # Edge attributes + if has_weights: + edge_attr = torch.tensor(edge_weights, dtype=torch.float).view(-1, 1) + else: + edge_attr = None + + # Store gene names as metadata + graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + graph.gene_names = genes # Store for reference + + return graph + + +def main(): + """Main function to run the PPI preprocessing.""" + parser = argparse.ArgumentParser(description="Preprocess PPI network to graph.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--path_data", type=str, default="data", help="Path to the data folder") + parser.add_argument( + "--ppi_file", + type=str, + default=None, + help="Path to PPI CSV file (default: {path_data}/{dataset_name}/ppi_network.csv)", + ) + parser.add_argument( + "--gene_list", + type=str, + default="gene_expression.csv", + help="Gene list name to use (default: gene_expression.csv; will take the columns)", + ) + args = parser.parse_args() + + dataset_name = args.dataset_name + data_dir = Path(args.path_data).resolve() + + # Determine PPI file path + if args.ppi_file: + ppi_file = Path(args.ppi_file) + else: + ppi_file = data_dir / dataset_name / "ppi_network.csv" + + # Gene list file + gene_list_file = data_dir / dataset_name / f"{args.gene_list}" + output_file = data_dir / dataset_name / "ppi_graph.pt" + + if not ppi_file.exists(): + print(f"Error: {ppi_file} not found.") + return + + if not gene_list_file.exists(): + print(f"Error: {gene_list_file} not found.") + print(f"Available gene lists should be in {data_dir / 'meta' / 'gene_lists'}/") + return + + print(f"Processing PPI network for dataset {dataset_name}...") + print(f"Using gene list: {args.gene_list}") + + try: + graph = _load_ppi_network(ppi_file, gene_list_file) + torch.save(graph, output_file) + print(f"PPI graph saved to {output_file}") + print(f" Nodes (genes): {graph.num_nodes}") + print(f" Edges (interactions): {graph.num_edges}") + if graph.edge_attr is not None: + print(f" Edge attributes: Yes") + print(f"\nGene order matches: {args.gene_list}") + print(f"First 5 genes: {graph.gene_names[:5]}") + except Exception as e: + print(f"Error processing PPI network: {e}") + + +if __name__ == "__main__": + main() diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index 89f0fb7d..55e06721 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -456,41 +456,58 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase return FeatureDataset(features=feature_dict) - def save_model(self, path: str | Path, drug_name=None): - """Save the model. + def save(self, directory: str) -> None: + """ + Save the trained model, hyperparameters, and gene expression scaler to the given directory. + + This enables full reconstruction of the model using `load`. + + Files saved: + - model.pt: PyTorch state_dict of the trained model + - hyperparameters.json: Dictionary containing all relevant model hyperparameters - :param path: The path to save the model to. - :param drug_name: The name of the drug. - :raises RuntimeError: If there is no model to save. + :param directory: Target directory to store all model artifacts """ - if self.model is None: - raise RuntimeError("No model to save.") - path = Path(path) + path = Path(directory) path.mkdir(parents=True, exist_ok=True) - trainer = pl.Trainer() - trainer.save_checkpoint(path / "model.ckpt", weights_only=True) + torch.save(self.model.state_dict(), path / "model.pt") # noqa: S614 + + with open(path / "hyperparameters.json", "w") as f: + json.dump(self.hyperparameters, f) - with open(path / "config.json", "w") as f: - json.dump(self.hyperparameters, f, indent=4) + @classmethod + def load(cls, directory: str) -> "DrugGNN": + """ + Load a trained DrugGNN model from the given directory. - def load_model(self, path: str | Path, drug_name=None): - """Load the model. + This includes: + - model.pt: PyTorch state_dict of the trained model + - hyperparameters.json: Dictionary containing all relevant model hyperparameters - :param path: The path to load the model from. - :param drug_name: The name of the drug. + :param directory: The path to load the model from. + :return: The loaded DrugGNN model. + :raises FileNotFoundError: If any of the required files are not found. """ - path = Path(path) + path = Path(directory) - config_path = path / "config.json" - with open(config_path) as f: - self.hyperparameters = json.load(f) + hpam_path = path / "hyperparameters.json" + model_path = path / "model.pt" + if not hpam_path.exists() or not model_path.exists(): + raise FileNotFoundError(f"Required files not found in {directory}.") - self.model = DrugGNNModule.load_from_checkpoint( - path / "model.ckpt", - num_node_features=self.hyperparameters["num_node_features"], - num_cell_features=self.hyperparameters["num_cell_features"], - hidden_dim=self.hyperparameters.get("hidden_dim", 64), - dropout=self.hyperparameters.get("dropout", 0.2), - learning_rate=self.hyperparameters.get("learning_rate", 0.001), + instance = cls() + + with open(hpam_path) as f: + instance.hyperparameters = json.load(f) + + instance.model = DrugGNNModule( + num_node_features=instance.hyperparameters["num_node_features"], + num_cell_features=instance.hyperparameters["num_cell_features"], + hidden_dim=instance.hyperparameters.get("hidden_dim", 64), + dropout=instance.hyperparameters.get("dropout", 0.2), + learning_rate=instance.hyperparameters.get("learning_rate", 0.001), ) + instance.model.load_state_dict(torch.load(model_path, weights_only=True)) + instance.model.eval() + return instance diff --git a/drevalpy/models/PPIGraphGNN/__init__.py b/drevalpy/models/PPIGraphGNN/__init__.py new file mode 100644 index 00000000..7a2de887 --- /dev/null +++ b/drevalpy/models/PPIGraphGNN/__init__.py @@ -0,0 +1,5 @@ +"""PPIGraphGNN model for drug response prediction using PPI networks.""" + +from .ppi_graph_gnn import PPIGraphGNN + +__all__ = ["PPIGraphGNN"] diff --git a/drevalpy/models/PPIGraphGNN/hyperparameters.yaml b/drevalpy/models/PPIGraphGNN/hyperparameters.yaml new file mode 100644 index 00000000..1ccef982 --- /dev/null +++ b/drevalpy/models/PPIGraphGNN/hyperparameters.yaml @@ -0,0 +1,14 @@ +PPIGraphGNN: + learning_rate: + - 0.001 + epochs: + - 10 + hidden_dim: + - 64 + num_gnn_layers: + - 2 + - 3 + dropout: + - 0.3 + batch_size: + - 32 diff --git a/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py new file mode 100644 index 00000000..5135d9be --- /dev/null +++ b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py @@ -0,0 +1,704 @@ +"""PPIGraphGNN model for drug response prediction using PPI networks and GNNExplainer.""" + +import json +from pathlib import Path +from typing import Any + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +from torch.optim import Adam +from torch.utils.data import Dataset as PytorchDataset +from torch_geometric.data import Data +from torch_geometric.explain import Explainer, GNNExplainer +from torch_geometric.loader import DataLoader +from torch_geometric.nn import GCNConv, global_mean_pool + +from ...datasets.dataset import DrugResponseDataset, FeatureDataset +from ..drp_model import DRPModel +from ..lightning_metrics_mixin import RegressionMetricsMixin +from ..utils import load_and_select_gene_features, load_drug_fingerprint_features + + +class PPIGraphNet(nn.Module): + """Graph Neural Network for processing PPI networks with gene expression and drug features.""" + + def __init__( + self, num_genes: int, num_drug_features: int, hidden_dim: int = 64, num_gnn_layers: int = 3, dropout: float = 0.2 + ): + """Initialize the network. + + :param num_genes: Number of genes (node features dimension). + :param num_drug_features: Number of drug features (e.g., fingerprint size). + :param hidden_dim: The hidden dimension size. + :param num_gnn_layers: Number of GNN layers. + :param dropout: The dropout rate. + """ + super().__init__() + self.dropout = dropout + self.num_gnn_layers = num_gnn_layers + + # GNN layers to process PPI graph with gene expression + self.gnn_layers = nn.ModuleList() + self.gnn_layers.append(GCNConv(1, hidden_dim)) + for _ in range(num_gnn_layers - 1): + self.gnn_layers.append(GCNConv(hidden_dim, hidden_dim)) + + # Drug encoder (MLP for drug fingerprints) + self.drug_fc1 = nn.Linear(num_drug_features, hidden_dim) + self.drug_fc2 = nn.Linear(hidden_dim, hidden_dim) + + # Combined prediction layers (PPI graph embedding + drug embedding) + self.combiner_fc1 = nn.Linear(hidden_dim * 2, hidden_dim) + self.combiner_fc2 = nn.Linear(hidden_dim, hidden_dim // 2) + self.output_fc = nn.Linear(hidden_dim // 2, 1) + + def forward(self, x, edge_index, batch, drug_features): + """Forward pass of the network. + + :param x: Node features (gene expression per node). + :param edge_index: Edge connectivity from PPI network. + :param batch: Batch assignment vector. + :param drug_features: Drug fingerprints or other drug features. + :return: Predicted drug response. + """ + # Process PPI graph through GNN layers + for i, gnn_layer in enumerate(self.gnn_layers): + x = gnn_layer(x, edge_index) + x = nn.functional.relu(x) + if i < len(self.gnn_layers) - 1: # Don't apply dropout after the last GNN layer + x = nn.functional.dropout(x, p=self.dropout, training=self.training) + + # Global pooling to get graph-level embedding + graph_embedding = global_mean_pool(x, batch) + + # Process drug features + drug_embedding = nn.functional.relu(self.drug_fc1(drug_features)) + drug_embedding = nn.functional.dropout(drug_embedding, p=self.dropout, training=self.training) + drug_embedding = self.drug_fc2(drug_embedding) + + # Combine graph embedding and drug embedding + combined = torch.cat([graph_embedding, drug_embedding], dim=1) + x = nn.functional.relu(self.combiner_fc1(combined)) + x = nn.functional.dropout(x, p=self.dropout, training=self.training) + x = nn.functional.relu(self.combiner_fc2(x)) + x = nn.functional.dropout(x, p=self.dropout, training=self.training) + out = self.output_fc(x) + return out.view(-1) + + +class PPIGraphGNNModule(RegressionMetricsMixin, pl.LightningModule): + """The LightningModule for the PPIGraphGNN model.""" + + def __init__( + self, + num_genes: int, + num_drug_features: int, + hidden_dim: int = 64, + num_gnn_layers: int = 3, + dropout: float = 0.2, + learning_rate: float = 0.001, + ): + """Initialize the LightningModule. + + :param num_genes: Number of genes in the gene expression data. + :param num_drug_features: Number of drug features. + :param hidden_dim: The hidden dimension size. + :param num_gnn_layers: Number of GNN layers. + :param dropout: The dropout rate. + :param learning_rate: The learning rate. + """ + super().__init__() + self.save_hyperparameters() + self.model = PPIGraphNet( + num_genes=self.hparams["num_genes"], + num_drug_features=self.hparams["num_drug_features"], + hidden_dim=self.hparams["hidden_dim"], + num_gnn_layers=self.hparams["num_gnn_layers"], + dropout=self.hparams["dropout"], + ) + self.criterion = nn.MSELoss() + + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + + def forward(self, batch): + """Forward pass of the module. + + :param batch: The batch containing graph data, drug features, and responses. + :return: The output of the model. + """ + graph, drug_features, responses = batch + return self.model(graph.x, graph.edge_index, graph.batch, drug_features) + + def training_step(self, batch, batch_idx): + """A single training step. + + :param batch: The batch. + :param batch_idx: The batch index. + :return: The loss. + """ + graph, drug_features, responses = batch + outputs = self.model(graph.x, graph.edge_index, graph.batch, drug_features) + loss = self.criterion(outputs, responses) + self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=True) + + return loss + + def validation_step(self, batch, batch_idx): + """A single validation step. + + :param batch: The batch. + :param batch_idx: The batch index. + """ + graph, drug_features, responses = batch + outputs = self.model(graph.x, graph.edge_index, graph.batch, drug_features) + loss = self.criterion(outputs, responses) + self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=False) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """A single prediction step. + + :param batch: The batch. + :param batch_idx: The batch index. + :param dataloader_idx: The dataloader index. + :return: The output of the model. + """ + return self.forward(batch) + + def configure_optimizers(self): + """Configure the optimizer. + + :return: The optimizer. + """ + return Adam(self.parameters(), lr=self.hparams.learning_rate) + + +class _PPIGraphDataset(PytorchDataset): + """A PyTorch Dataset to wrap PPI graphs with gene expression and drug features.""" + + def __init__( + self, + response: np.ndarray, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_features: FeatureDataset, + drug_features: FeatureDataset, + ppi_graph_template: Data, + ): + """Initialize the dataset. + + :param response: The drug response values. + :param cell_line_ids: The cell line IDs. + :param drug_ids: The drug IDs. + :param cell_line_features: A FeatureDataset object with cell line gene expression features. + :param drug_features: A FeatureDataset object with drug features (fingerprints). + :param ppi_graph_template: Template PPI graph (same structure for all samples). + """ + self.response = response + self.cell_line_ids = cell_line_ids + self.drug_ids = drug_ids + self.ppi_graph_template = ppi_graph_template + + # Preconvert gene expression to tensors + self.cell_features = { + cl_id: torch.tensor(features["gene_expression"], dtype=torch.float32) + for cl_id, features in cell_line_features.features.items() + } + + # Preconvert drug features to tensors + self.drug_features = { + drug_id: torch.tensor(features["fingerprints"], dtype=torch.float32) + for drug_id, features in drug_features.features.items() + } + + self.response_tensor = torch.tensor(self.response, dtype=torch.float32) + + def __len__(self): + return len(self.response) + + def __getitem__(self, idx): + cell_line_id = self.cell_line_ids[idx] + drug_id = self.drug_ids[idx] + + # Create a copy of the PPI graph and set node features to gene expression + graph = self.ppi_graph_template.clone() + gene_expr = self.cell_features[cell_line_id] + + # Set node features as gene expression values (expand dims to match expected shape) + graph.x = gene_expr.unsqueeze(1) + + # Get drug features + drug_feat = self.drug_features[drug_id] + + response = self.response_tensor[idx] + + return graph, drug_feat, response + + +class PPIGraphGNN(DRPModel): + """PPIGraphGNN model using PPI networks and gene expression with GNNExplainer support.""" + + def __init__(self): + """Initialize the PPIGraphGNN model.""" + super().__init__() + self.model: PPIGraphGNNModule | None = None + self.hyperparameters = {} + self.ppi_graph_template: Data | None = None + self.explainer: Explainer | None = None + + @classmethod + def get_model_name(cls) -> str: + """Return the name of the model. + + :return: The name of the model. + """ + return "PPIGraphGNN" + + @property + def cell_line_views(self) -> list[str]: + """Return the sources the model needs as input for describing the cell line. + + :return: The sources the model needs as input for describing the cell line. + """ + return ["gene_expression"] + + @property + def drug_views(self) -> list[str]: + """Return the sources the model needs as input for describing the drug. + + :return: The sources the model needs as input for describing the drug. + """ + return ["fingerprints"] + + def build_model(self, hyperparameters: dict[str, Any]) -> None: + """Build the model. + + :param hyperparameters: The hyperparameters. + """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + + self.hyperparameters = hyperparameters + + def _validate_gene_order(self, cell_line_input: FeatureDataset) -> None: + """ + Validate that the gene order in the PPI graph matches the gene expression feature order. + + :param cell_line_input: FeatureDataset with gene expression features + :raises ValueError: If gene order doesn't match or validation fails + """ + if self.ppi_graph_template is None: + raise RuntimeError("PPI graph template not loaded") + + # Check if the PPI graph has gene_names attributes + if not hasattr(self.ppi_graph_template, "gene_names"): + raise ValueError( + "PPI graph doesn't contain gene_names metadata. " + "Please regenerate the PPI graph using the updated create_ppi_graphs.py script." + ) + + ppi_gene_names = self.ppi_graph_template.gene_names + + # Get gene names from cell_line_input meta_info + if "gene_expression" not in cell_line_input.meta_info: + raise ValueError("cell_line_input doesn't contain gene_expression meta_info") + + expr_gene_names = list(cell_line_input.meta_info["gene_expression"]) + + # Validate number of genes matches + if len(ppi_gene_names) != len(expr_gene_names): + raise ValueError( + f"Gene count mismatch: PPI graph has {len(ppi_gene_names)} genes, " + f"but gene expression has {len(expr_gene_names)} genes. " + f"Ensure both use the same gene list (e.g., landmark_genes_reduced)." + ) + + # Validate gene order matches + for i, (ppi_gene, expr_gene) in enumerate(zip(ppi_gene_names, expr_gene_names, strict=False)): + if ppi_gene != expr_gene: + raise ValueError( + f"Gene order mismatch at position {i}: " + f"PPI graph has '{ppi_gene}' but gene expression has '{expr_gene}'. " + f"Regenerate PPI graph using: python -m drevalpy.datasets.featurizer.create_ppi_graphs" + ) + + print(f"✓ Validated: PPI graph and gene expression have matching gene order ({len(ppi_gene_names)} genes)") + + def _loader_kwargs(self) -> dict[str, Any]: + num_workers = int(self.hyperparameters.get("num_workers", 4)) + kw = { + "num_workers": num_workers, + "pin_memory": True, + } + if num_workers > 0: + kw["persistent_workers"] = True + kw["prefetch_factor"] = int(self.hyperparameters.get("prefetch_factor", 2)) + return kw + + def train( + self, + output: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, + **kwargs, + ): + """Train the model. + + :param output: The output dataset. + :param cell_line_input: The cell line input dataset. + :param drug_input: The drug input dataset (fingerprints). + :param output_earlystopping: The early stopping output dataset. + :param kwargs: Additional arguments. + :raises RuntimeError: If PPI graph template is not loaded. + :raises ValueError: If drug_input is not provided. + :raises ValueError: If gene order doesn't match between PPI graph and gene expression. + """ + if self.ppi_graph_template is None: + raise RuntimeError("PPI graph template not loaded. Call load_drug_features() first.") + + if drug_input is None: + raise ValueError("drug_input (fingerprints) is required for PPIGraphGNN.") + + # Validate gene order consistency + self._validate_gene_order(cell_line_input) + + # Determine feature sizes + num_genes = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] + num_drug_features = next(iter(drug_input.features.values()))["fingerprints"].shape[0] + + self.model = PPIGraphGNNModule( + num_genes=1, + num_drug_features=num_drug_features, + hidden_dim=self.hyperparameters.get("hidden_dim", 64), + num_gnn_layers=self.hyperparameters.get("num_gnn_layers", 3), + dropout=self.hyperparameters.get("dropout", 0.2), + learning_rate=self.hyperparameters.get("learning_rate", 0.001), + ) + + # Initialize GNNExplainer + self.explainer = Explainer( + model=self.model.model, + algorithm=GNNExplainer(epochs=200), + explanation_type="model", + node_mask_type="attributes", + edge_mask_type="object", + model_config=dict( + mode="regression", + task_level="graph", + return_type="raw", + ), + ) + + train_dataset = _PPIGraphDataset( + response=output.response, + cell_line_ids=output.cell_line_ids, + drug_ids=output.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ppi_graph_template=self.ppi_graph_template, + ) + train_loader = DataLoader( + train_dataset, + batch_size=self.hyperparameters.get("batch_size", 32), + shuffle=True, + **self._loader_kwargs(), + ) + + val_loader = None + if output_earlystopping is not None and len(output_earlystopping) > 0: + val_dataset = _PPIGraphDataset( + response=output_earlystopping.response, + cell_line_ids=output_earlystopping.cell_line_ids, + drug_ids=output_earlystopping.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ppi_graph_template=self.ppi_graph_template, + ) + val_loader = DataLoader( + val_dataset, + batch_size=self.hyperparameters.get("batch_size", 32), + **self._loader_kwargs(), + ) + + # Set up wandb logger if project is provided + loggers = [] + if self.wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(project=self.wandb_project, log_model=False) + loggers.append(logger) + + trainer = pl.Trainer( + max_epochs=self.hyperparameters.get("epochs", 100), + accelerator="auto", + devices="auto", + callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, + logger=loggers if loggers else True, + enable_progress_bar=True, + log_every_n_steps=int(self.hyperparameters.get("log_every_n_steps", 50)), + precision=self.hyperparameters.get("precision", 32), + ) + trainer.fit(self.model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + def predict( + self, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + ) -> np.ndarray: + """Predict drug response. + + :param cell_line_ids: The cell line IDs. + :param drug_ids: The drug IDs. + :param cell_line_input: The cell line input dataset. + :param drug_input: The drug input dataset (fingerprints). + :raises RuntimeError: If the model has not been trained yet. + :raises RuntimeError: If PPI graph template is not loaded. + :raises ValueError: If drug_input is not provided. + :return: The predicted drug response. + """ + if len(cell_line_ids) == 0: + print("PPIGraphGNN predict: No cell line IDs provided; returning empty array.") + return np.array([]) + if self.model is None: + raise RuntimeError("Model has not been trained yet.") + if self.ppi_graph_template is None: + raise RuntimeError("PPI graph template not loaded.") + if drug_input is None: + raise ValueError("drug_input (fingerprints) is required for PPIGraphGNN.") + + self.model.eval() + + predict_dataset = _PPIGraphDataset( + response=np.zeros(len(cell_line_ids)), + cell_line_ids=cell_line_ids, + drug_ids=drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ppi_graph_template=self.ppi_graph_template, + ) + predict_loader = DataLoader( + predict_dataset, + batch_size=self.hyperparameters.get("batch_size", 32), + **self._loader_kwargs(), + ) + + trainer = pl.Trainer(accelerator="auto", devices="auto", enable_progress_bar=False) + predictions_list = trainer.predict(self.model, dataloaders=predict_loader) + + if not predictions_list: + print("PPIGraphGNN predict: No predictions were made; returning empty array.") + return np.array([]) + + predictions_flat = [ + item for sublist in predictions_list for item in (sublist if isinstance(sublist, list) else [sublist]) + ] + + predictions = torch.cat(predictions_flat).cpu().numpy() + return predictions + + def explain( + self, + cell_line_id: str, + drug_id: str, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset, + top_k_edges: int = 20, + ) -> dict[str, Any]: + """ + Use GNNExplainer to extract important subnetwork for a specific cell line-drug pair. + + :param cell_line_id: The cell line ID to explain. + :param drug_id: The drug ID to explain. + :param cell_line_input: The cell line input dataset. + :param drug_input: The drug input dataset. + :param top_k_edges: Number of top important edges to return. + :raises RuntimeError: If model or explainer is not initialized. + :return: Dictionary containing explanation with important edges and nodes. + """ + if self.model is None: + raise RuntimeError("Model has not been trained yet.") + if self.explainer is None: + raise RuntimeError("Explainer not initialized. Train the model first.") + if self.ppi_graph_template is None: + raise RuntimeError("PPI graph template not loaded.") + + self.model.eval() + + # Create graph with gene expression for this cell line + graph = self.ppi_graph_template.clone() + gene_expr = torch.tensor( + cell_line_input.features[cell_line_id]["gene_expression"], dtype=torch.float32 + ).unsqueeze(1) + graph.x = gene_expr + + # Get drug features + drug_features = torch.tensor(drug_input.features[drug_id]["fingerprints"], dtype=torch.float32).unsqueeze(0) + + # Get explanation + with torch.no_grad(): + explanation = self.explainer( + x=graph.x, + edge_index=graph.edge_index, + batch=torch.zeros(graph.num_nodes, dtype=torch.long), + drug_features=drug_features, + ) + + # Extract important edges + edge_mask = explanation.edge_mask.cpu().numpy() + edge_index = graph.edge_index.cpu().numpy() + + # Get top-k edges + top_edge_indices = np.argsort(edge_mask)[::-1][:top_k_edges] + important_edges = [(int(edge_index[0, i]), int(edge_index[1, i])) for i in top_edge_indices] + edge_scores = [float(edge_mask[i]) for i in top_edge_indices] + + # Get gene names if available + gene_names = getattr(self.ppi_graph_template, "gene_names", None) + if gene_names is not None: + important_edges_with_names = [ + (gene_names[src], gene_names[dst], score) + for (src, dst), score in zip(important_edges, edge_scores, strict=True) + ] + else: + important_edges_with_names = [ + (src, dst, score) for (src, dst), score in zip(important_edges, edge_scores, strict=True) + ] + + return { + "cell_line_id": cell_line_id, + "drug_id": drug_id, + "important_edges": important_edges_with_names, + "edge_mask": edge_mask, + "explanation": explanation, + } + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Loads the cell line features. + + :param data_path: Path to the gene expression + :param dataset_name: name of the dataset + :return: FeatureDataset containing the cell line gene expression features. + """ + # Load PPI graph first + ppi_graph_path = Path(data_path) / dataset_name / "ppi_graph.pt" + if not ppi_graph_path.exists(): + raise FileNotFoundError( + f"PPI graph not found at {ppi_graph_path}. " + f"Please run 'python -m drevalpy.datasets.featurizer.create_ppi_graphs {dataset_name}' first." + ) + + self.ppi_graph_template = torch.load(ppi_graph_path, weights_only=False) # noqa: S614 + print( + f"Loaded PPI graph with {self.ppi_graph_template.num_nodes} nodes " + f"and {self.ppi_graph_template.num_edges} edges" + ) + + return load_and_select_gene_features( + feature_type="gene_expression", + gene_list=None, + data_path=data_path, + dataset_name=dataset_name, + ) + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Loads the drug features (fingerprints) and PPI graph. + + :param data_path: Path to the data directory. + :param dataset_name: Name of the dataset. + :return: FeatureDataset containing drug fingerprints. + """ + + # Load drug fingerprints + return load_drug_fingerprint_features(data_path, dataset_name, fill_na=True) + + def save(self, directory: str) -> None: + """ + Save the trained model, hyperparameters, and gene expression scaler to the given directory. + + This enables full reconstruction of the model using `load`. + + Files saved: + - model.pt: PyTorch state_dict of the trained model + - hyperparameters.json: Dictionary containing all relevant model hyperparameters + - ppi_graph.pt: PPI graph template + + :param directory: Target directory to store all model artifacts + """ + path = Path(directory) + path.mkdir(parents=True, exist_ok=True) + + torch.save(self.model.state_dict(), path / "model.pt") # noqa: S614 + + with open(path / "hyperparameters.json", "w") as f: + json.dump(self.hyperparameters, f) + + torch.save(self.ppi_graph_template, path / "ppi_graph.pt") + + @classmethod + def load(cls, directory: str) -> "PPIGraphGNN": + """ + Load a trained PPI Graph GNN model from the given directory. + + This includes: + - model.pt: PyTorch state_dict of the trained model + - hyperparameters.json: Dictionary containing all relevant model hyperparameters + - ppi_graph.pt: PPI graph template + + :param directory: The path to load the model from. + :return: The loaded PPIGraphGNN model. + :raises FileNotFoundError: If any of the required files are not found. + """ + path = Path(directory) + + hpam_path = path / "hyperparameters.json" + model_file = path / "model.pt" + ppi_graph_path = path / "ppi_graph.pt" + if not hpam_path.exists() or not model_file.exists() or not ppi_graph_path.exists(): + raise FileNotFoundError( + f"Missing required files in {directory}. " + f"Please make sure all files are present and try again." + ) + + instance = cls() + + with open(hpam_path) as f: + instance.hyperparameters = json.load(f) + + instance.ppi_graph_template = torch.load(ppi_graph_path, weights_only=False) # noqa: S614 + + instance.model = PPIGraphGNNModule( + num_genes=instance.hyperparameters["num_genes"], + num_drug_features=instance.hyperparameters["num_drug_features"], + hidden_dim=instance.hyperparameters.get("hidden_dim", 64), + num_gnn_layers=instance.hyperparameters.get("num_gnn_layers", 3), + dropout=instance.hyperparameters.get("dropout", 0.2), + learning_rate=instance.hyperparameters.get("learning_rate", 0.001), + ) + instance.model.load_state_dict(torch.load(model_file, weights_only=True)) + instance.model.eval() + + # Reinitialize explainer + instance.explainer = Explainer( + model=instance.model.model, + algorithm=GNNExplainer(epochs=200), + explanation_type="model", + node_mask_type="attributes", + edge_mask_type="object", + model_config=dict( + mode="regression", + task_level="graph", + return_type="raw", + ), + ) + return instance diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 5ecf2e4f..015781bf 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -30,6 +30,7 @@ "DrugGNN", "ChemBERTaNeuralNetwork", "PharmaFormerModel", + "PPIGraphGNN", ] from .baselines.multi_omics_random_forest import MultiOmicsRandomForest @@ -56,6 +57,7 @@ from .DrugGNN import DrugGNN from .MOLIR.molir import MOLIR from .PharmaFormer.pharmaformer import PharmaFormerModel +from .PPIGraphGNN import PPIGraphGNN from .SimpleNeuralNetwork.multiomics_neural_network import MultiOmicsNeuralNetwork from .SimpleNeuralNetwork.simple_neural_network import ChemBERTaNeuralNetwork, SimpleNeuralNetwork from .SRMF.srmf import SRMF @@ -93,6 +95,7 @@ "DrugGNN": DrugGNN, "ChemBERTaNeuralNetwork": ChemBERTaNeuralNetwork, "PharmaFormer": PharmaFormerModel, + "PPIGraphGNN": PPIGraphGNN, } # MODEL_FACTORY is used in the pipeline! From 8097086eb5b553bc4f6495fd0f9d912de71585f1 Mon Sep 17 00:00:00 2001 From: JudithBernett Date: Fri, 6 Mar 2026 15:03:51 +0100 Subject: [PATCH 2/4] Fixing linting --- drevalpy/cli.py | 3 +- .../datasets/featurizer/create_ppi_graphs.py | 8 +- drevalpy/models/DrugGNN/drug_gnn.py | 2 +- drevalpy/models/PPIGraphGNN/README.md | 209 ++++++++++++++++++ drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py | 14 +- 5 files changed, 224 insertions(+), 12 deletions(-) create mode 100644 drevalpy/models/PPIGraphGNN/README.md diff --git a/drevalpy/cli.py b/drevalpy/cli.py index db2bdf67..b5b286e7 100644 --- a/drevalpy/cli.py +++ b/drevalpy/cli.py @@ -8,5 +8,6 @@ def cli_main(): args = get_parser().parse_args() main(args) + if __name__ == "__main__": - cli_main() \ No newline at end of file + cli_main() diff --git a/drevalpy/datasets/featurizer/create_ppi_graphs.py b/drevalpy/datasets/featurizer/create_ppi_graphs.py index 09e8ee80..2799e03e 100644 --- a/drevalpy/datasets/featurizer/create_ppi_graphs.py +++ b/drevalpy/datasets/featurizer/create_ppi_graphs.py @@ -9,7 +9,6 @@ import argparse from pathlib import Path -import numpy as np import pandas as pd import torch from torch_geometric.data import Data @@ -24,6 +23,7 @@ def _load_ppi_network(ppi_file: Path, gene_list_file: Path) -> Data: :param ppi_file: Path to the PPI network CSV file with columns [gene_id_1, gene_id_2, (optional) interaction_score] :param gene_list_file: Path to the gene list CSV (e.g., landmark_genes_reduced.csv) that defines gene order + :raises ValueError: If the PPI CSV does not contain the required columns or if the gene list file is not found :return: A Data object representing the PPI network graph """ # Load the gene list to get the ordered list of genes (same as will be used for gene expression) @@ -46,9 +46,7 @@ def _load_ppi_network(ppi_file: Path, gene_list_file: Path) -> Data: # Validate columns required_cols = {"gene_id_1", "gene_id_2"} if not required_cols.issubset(ppi_df.columns): - raise ValueError( - f"PPI CSV must contain columns 'gene_id_1' and 'gene_id_2'. Found: {ppi_df.columns.tolist()}" - ) + raise ValueError(f"PPI CSV must contain columns 'gene_id_1' and 'gene_id_2'. Found: {ppi_df.columns.tolist()}") # Build edge list (only include genes that exist in gene expression) edge_list = [] @@ -147,7 +145,7 @@ def main(): print(f" Nodes (genes): {graph.num_nodes}") print(f" Edges (interactions): {graph.num_edges}") if graph.edge_attr is not None: - print(f" Edge attributes: Yes") + print(" Edge attributes: Yes") print(f"\nGene order matches: {args.gene_list}") print(f"First 5 genes: {graph.gene_names[:5]}") except Exception as e: diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index 55e06721..59a47194 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -456,7 +456,7 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase return FeatureDataset(features=feature_dict) - def save(self, directory: str) -> None: + def save(self, directory: str) -> None: """ Save the trained model, hyperparameters, and gene expression scaler to the given directory. diff --git a/drevalpy/models/PPIGraphGNN/README.md b/drevalpy/models/PPIGraphGNN/README.md new file mode 100644 index 00000000..cf40f50c --- /dev/null +++ b/drevalpy/models/PPIGraphGNN/README.md @@ -0,0 +1,209 @@ +# PPIGraphGNN Model + +A Graph Neural Network (GNN) model for drug response prediction that uses protein-protein interaction (PPI) networks, gene expression data, and drug fingerprints. The model includes GNNExplainer for extracting drug-cell-line-specific subnetworks for interpretability. + +## Overview + +**PPIGraphGNN** combines: + +- Gene expression vectors (cell line features) +- Drug fingerprints (drug features) +- PPI network structure (protein interactions) +- Graph Convolutional Networks (GCN) for learning from network topology +- GNNExplainer for interpretable predictions + +## Architecture + +1. **Input**: + + - Gene expression vector for each cell line + - Drug fingerprints for each drug + - PPI network as a graph (edges represent protein-protein interactions) + +2. **Model**: + + - **PPI Graph Encoder**: Node features are initialized with gene expression values, then multiple GCN layers propagate information through the PPI network + - **Drug Encoder**: MLP processes drug fingerprints + - **Combiner**: Concatenates PPI graph embedding and drug embedding + - **Predictor**: Fully connected layers predict drug response from combined features + +3. **Explainability**: + - GNNExplainer identifies important PPI subnetworks for each drug-cell line prediction + - Returns top-k edges and their importance scores specific to each drug-cell line pair + +## Usage + +### 1. Prepare PPI Network Data + +Create a CSV file with PPI network at `data/{dataset_name}/ppi_network.csv`: + +```csv +gene_id_1,gene_id_2,interaction_score +BRCA1,BRCA2,0.95 +TP53,MDM2,0.88 +EGFR,PIK3CA,0.72 +... +``` + +**Required columns:** + +- `gene_id_1`: First gene/protein identifier +- `gene_id_2`: Second gene/protein identifier +- `interaction_score` (optional): Confidence score for the interaction (0-1) + +**Important:** Gene IDs must match those in your gene expression data. + +### 2. Generate PPI Graph + +Run the preprocessing script to convert the PPI CSV to a PyTorch Geometric graph: + +```bash +python -m drevalpy.datasets.featurizer.create_ppi_graphs GDSC1 --data_path data --gene_list landmark_genes_reduced +``` + +This creates `data/GDSC1/ppi_graph.pt` containing the graph structure. + +**Important:** The `--gene_list` parameter must match the gene list used by the model (default: `landmark_genes_reduced`). This ensures the gene order in the PPI graph matches the gene expression feature order. + +### 3. Train the Model + +```python +from drevalpy.models import PPIGraphGNN +from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset + +# Initialize model +model = PPIGraphGNN() + +# Build model with hyperparameters +model.build_model({ + "hidden_dim": 64, + "num_gnn_layers": 3, + "dropout": 0.2, + "learning_rate": 0.001, + "epochs": 100, + "batch_size": 32 +}) + +# Load features (also loads PPI graph automatically in load_drug_features) +cell_line_features = model.load_cell_line_features("data", "GDSC1") +drug_features = model.load_drug_features("data", "GDSC1") # Loads PPI graph + drug fingerprints + +# Train +model.train( + output=train_dataset, + cell_line_input=cell_line_features, + drug_input=drug_features, # Required for drug fingerprints + output_earlystopping=val_dataset +) + +# Predict +predictions = model.predict( + cell_line_ids=test_cell_line_ids, + drug_ids=test_drug_ids, + cell_line_input=cell_line_features, + drug_input=drug_features # Required for drug fingerprints +) +``` + +### 4. Extract Explanations + +Use GNNExplainer to get important subnetworks for specific drug-cell line pairs: + +```python +# Get explanation for a specific drug-cell line pair +explanation = model.explain( + cell_line_id="ACH-000001", + drug_id="123456", + cell_line_input=cell_line_features, + drug_input=drug_features, + top_k_edges=20 # Number of top edges to return +) + +# Access results +print(f"Cell line: {explanation['cell_line_id']}") +print(f"Drug: {explanation['drug_id']}") +print(f"Important PPI interactions for this drug-cell line pair:") +for gene1, gene2, score in explanation['important_edges']: + print(f" {gene1} <-> {gene2}: {score:.4f}") +``` + +## Hyperparameters + +Configurable in `hyperparameters.yaml`: + +- `learning_rate`: Learning rate for optimizer (default: 0.001) +- `epochs`: Number of training epochs (default: 100) +- `hidden_dim`: Hidden dimension size (default: 64) +- `num_gnn_layers`: Number of GCN layers (default: 3) +- `dropout`: Dropout probability (default: 0.2) +- `batch_size`: Batch size for training (default: 32) + +## Requirements + +The model requires: + +- `torch_geometric` with GNNExplainer +- Gene expression data with landmark genes +- Drug fingerprints +- PPI network in CSV format + +## Model Properties + +- **cell_line_views**: `["gene_expression"]` +- **drug_views**: `["fingerprints"]` +- **is_single_drug_model**: `False` +- **early_stopping**: Supported via validation dataset + +## Output + +The `explain()` method returns a dictionary with: + +- `cell_line_id`: The cell line being explained +- `drug_id`: The drug being explained +- `important_edges`: List of tuples `(gene1, gene2, score)` for top-k edges +- `edge_mask`: Full edge importance scores for all edges +- `explanation`: Raw GNNExplainer output + +## Example PPI Network Sources + +Common sources for PPI networks: + +- **STRING**: https://string-db.org/ (comprehensive, includes scores) +- **BioGRID**: https://thebiogrid.org/ (curated interactions) +- **IntAct**: https://www.ebi.ac.uk/intact/ (molecular interaction database) +- **HIPPIE**: http://cbdm-01.zdv.uni-mainz.de/~mschaefer/hippie/ (human integrated protein-protein interaction) + +## Notes + +1. **Gene Order Consistency**: The order of nodes in the PPI graph MUST match the order of genes in the gene expression features. The preprocessing script ensures this by using the same gene list file (e.g., `landmark_genes_reduced.csv`). The model validates this at training time and will raise an error if there's a mismatch. + +2. The PPI graph structure is shared across all samples; only node features (gene expression) vary per cell line + +3. Drug features (fingerprints) are used to distinguish between different drugs + +4. The model uses landmark genes by default - ensure your PPI network includes these genes + +5. GNNExplainer provides drug-cell-line-specific explanations by considering both the PPI network context and drug features + +6. GNNExplainer is computationally intensive; use `top_k_edges` parameter to limit output + +## How Gene Order is Maintained + +When you run: + +```python +graph.x = gene_expr.unsqueeze(1) +``` + +The gene expression vector is assigned to graph nodes. The model ensures correct mapping by: + +1. **PPI Graph Creation**: Uses the same gene list file (e.g., `landmark_genes_reduced.csv`) to define node order +2. **Gene Expression Loading**: `load_and_select_gene_features()` uses the same gene list to order features +3. **Runtime Validation**: The model validates that both orders match before training + +Example: + +- Gene list: `["TP53", "EGFR", "BRCA1", ...]` +- PPI graph nodes: `[0: TP53, 1: EGFR, 2: BRCA1, ...]` +- Gene expression: `[expr_TP53, expr_EGFR, expr_BRCA1, ...]` +- Assignment: Node 0 gets expr_TP53, Node 1 gets expr_EGFR, etc. diff --git a/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py index 5135d9be..07b5a400 100644 --- a/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py +++ b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py @@ -25,7 +25,12 @@ class PPIGraphNet(nn.Module): """Graph Neural Network for processing PPI networks with gene expression and drug features.""" def __init__( - self, num_genes: int, num_drug_features: int, hidden_dim: int = 64, num_gnn_layers: int = 3, dropout: float = 0.2 + self, + num_genes: int, + num_drug_features: int, + hidden_dim: int = 64, + num_gnn_layers: int = 3, + dropout: float = 0.2, ): """Initialize the network. @@ -294,6 +299,7 @@ def _validate_gene_order(self, cell_line_input: FeatureDataset) -> None: :param cell_line_input: FeatureDataset with gene expression features :raises ValueError: If gene order doesn't match or validation fails + :raises RuntimeError: If PPI graph template is not loaded """ if self.ppi_graph_template is None: raise RuntimeError("PPI graph template not loaded") @@ -372,7 +378,6 @@ def train( self._validate_gene_order(cell_line_input) # Determine feature sizes - num_genes = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] num_drug_features = next(iter(drug_input.features.values()))["fingerprints"].shape[0] self.model = PPIGraphGNNModule( @@ -588,6 +593,7 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD :param data_path: Path to the gene expression :param dataset_name: name of the dataset + :raises FileNotFoundError: If PPI graph is not found at the specified path. :return: FeatureDataset containing the cell line gene expression features. """ # Load PPI graph first @@ -618,7 +624,6 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase :param dataset_name: Name of the dataset. :return: FeatureDataset containing drug fingerprints. """ - # Load drug fingerprints return load_drug_fingerprint_features(data_path, dataset_name, fill_na=True) @@ -666,8 +671,7 @@ def load(cls, directory: str) -> "PPIGraphGNN": ppi_graph_path = path / "ppi_graph.pt" if not hpam_path.exists() or not model_file.exists() or not ppi_graph_path.exists(): raise FileNotFoundError( - f"Missing required files in {directory}. " - f"Please make sure all files are present and try again." + f"Missing required files in {directory}. " f"Please make sure all files are present and try again." ) instance = cls() From 999c1be15a6d9d9c0c918faf7c7de2d357ee79ea Mon Sep 17 00:00:00 2001 From: JudithBernett Date: Fri, 6 Mar 2026 16:03:15 +0100 Subject: [PATCH 3/4] moved the explainer --- drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py | 57 +++++++------------- 1 file changed, 20 insertions(+), 37 deletions(-) diff --git a/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py index 07b5a400..3f5b03bf 100644 --- a/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py +++ b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py @@ -389,20 +389,6 @@ def train( learning_rate=self.hyperparameters.get("learning_rate", 0.001), ) - # Initialize GNNExplainer - self.explainer = Explainer( - model=self.model.model, - algorithm=GNNExplainer(epochs=200), - explanation_type="model", - node_mask_type="attributes", - edge_mask_type="object", - model_config=dict( - mode="regression", - task_level="graph", - return_type="raw", - ), - ) - train_dataset = _PPIGraphDataset( response=output.response, cell_line_ids=output.cell_line_ids, @@ -533,8 +519,6 @@ def explain( """ if self.model is None: raise RuntimeError("Model has not been trained yet.") - if self.explainer is None: - raise RuntimeError("Explainer not initialized. Train the model first.") if self.ppi_graph_template is None: raise RuntimeError("PPI graph template not loaded.") @@ -549,15 +533,27 @@ def explain( # Get drug features drug_features = torch.tensor(drug_input.features[drug_id]["fingerprints"], dtype=torch.float32).unsqueeze(0) + # Initialize GNNExplainer + self.explainer = Explainer( + model=self.model.model, + algorithm=GNNExplainer(epochs=200), + explanation_type="model", + node_mask_type="attributes", + edge_mask_type="object", + model_config=dict( + mode="regression", + task_level="graph", + return_type="raw", + ), + ) # Get explanation - with torch.no_grad(): - explanation = self.explainer( - x=graph.x, - edge_index=graph.edge_index, - batch=torch.zeros(graph.num_nodes, dtype=torch.long), - drug_features=drug_features, - ) + explanation = self.explainer( + x=graph.x, + edge_index=graph.edge_index, + batch=torch.zeros(graph.num_nodes, dtype=torch.long), + drug_features=drug_features, + ) # Extract important edges edge_mask = explanation.edge_mask.cpu().numpy() @@ -646,7 +642,7 @@ def save(self, directory: str) -> None: torch.save(self.model.state_dict(), path / "model.pt") # noqa: S614 with open(path / "hyperparameters.json", "w") as f: - json.dump(self.hyperparameters, f) + json.dump(self.model.hparams, f) torch.save(self.ppi_graph_template, path / "ppi_graph.pt") @@ -692,17 +688,4 @@ def load(cls, directory: str) -> "PPIGraphGNN": instance.model.load_state_dict(torch.load(model_file, weights_only=True)) instance.model.eval() - # Reinitialize explainer - instance.explainer = Explainer( - model=instance.model.model, - algorithm=GNNExplainer(epochs=200), - explanation_type="model", - node_mask_type="attributes", - edge_mask_type="object", - model_config=dict( - mode="regression", - task_level="graph", - return_type="raw", - ), - ) return instance From 6d3175cd671ff5756714c079dbc8102353a0ecd6 Mon Sep 17 00:00:00 2001 From: JudithBernett Date: Mon, 9 Mar 2026 11:23:33 +0100 Subject: [PATCH 4/4] mypy fixes --- drevalpy/datasets/curvecurator.py | 15 ++-- .../datasets/featurizer/create_ppi_graphs.py | 2 +- drevalpy/evaluation.py | 4 +- drevalpy/models/DIPK/data_utils.py | 6 +- drevalpy/models/DrugGNN/drug_gnn.py | 2 +- drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py | 6 +- drevalpy/models/SRMF/srmf.py | 2 +- drevalpy/models/drp_model.py | 10 +-- .../visualization/critical_difference_plot.py | 4 +- drevalpy/visualization/utils.py | 74 ++++++++++++------- 10 files changed, 76 insertions(+), 49 deletions(-) diff --git a/drevalpy/datasets/curvecurator.py b/drevalpy/datasets/curvecurator.py index a3655b1f..87b957be 100644 --- a/drevalpy/datasets/curvecurator.py +++ b/drevalpy/datasets/curvecurator.py @@ -15,10 +15,12 @@ import subprocess import warnings from pathlib import Path +from typing import Union import numpy as np import pandas as pd import toml +from pandas.core.groupby import DataFrameGroupBy from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER @@ -40,7 +42,7 @@ def _prepare_raw_data(curve_df: pd.DataFrame, output_dir: Path, prefix: str = "" UserWarning, stacklevel=1, ) - curve_df = curve_df.groupby(["sample", "drug", "dose", "replicate"], as_index=False)["response"].mean() + curve_df = curve_df.groupby(["sample", "drug", "dose", "replicate"], as_index=False)[["response"]].mean() df = curve_df.pivot(index=["sample", "drug"], columns=pivot_columns, values="response") @@ -177,7 +179,7 @@ def ic50(front, back, slope, pec50): back = model_params_df["Back"].values slope = model_params_df["Slope"].values # we need the pEC50 in uM; now it is in M: -log10(EC50[M] * 10^6) = -log10(EC50[M])-6 = pEC50 -6 - pec50 = model_params_df["pEC50_curvecurator"].values - 6 + pec50 = model_params_df["pEC50_curvecurator"].to_numpy(dtype=float) - 6 model_params_df["IC50_curvecurator"] = ic50(front, back, slope, pec50) model_params_df["LN_IC50_curvecurator"] = np.log(model_params_df["IC50_curvecurator"].values) @@ -233,6 +235,7 @@ def preprocess(input_file: str, output_dir: str, dataset_name: str, cores: int, if curve_df["nreplicates"].nunique() > 1: groupby.append("nreplicates") + drug_df_groups: Union[DataFrameGroupBy, list[tuple[str, pd.DataFrame]]] if len(groupby) > 0: drug_df_groups = curve_df.groupby(groupby) else: @@ -301,17 +304,17 @@ def postprocess(output_folder: str, dataset_name: str): with open(output_path / f"{dataset_name}.csv", "w") as f: first_file = True for output_file in curvecurator_output_files: - fitted_curve_data = pd.read_csv(output_file, sep="\t", usecols=required_columns).rename( + fitted_curve_data = pd.read_csv(output_file, sep="\t", usecols=list(required_columns.values())).rename( columns=required_columns ) - fitted_curve_data[[CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER]] = fitted_curve_data.Name.str.split( + fitted_curve_data[[CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER]] = fitted_curve_data["Name"].str.split( "|", expand=True ) fitted_curve_data["EC50_curvecurator"] = ( - np.power(10, -fitted_curve_data["pEC50_curvecurator"].values) * 10**6 + np.power(10, -fitted_curve_data["pEC50_curvecurator"].to_numpy(dtype=float)) * 10**6 ) # in CurveCurator 10^-pEC50 = EC50 _calc_ic50(fitted_curve_data) - fitted_curve_data.to_csv(f, index=None, header=first_file, mode="a") + fitted_curve_data.to_csv(f, index=False, header=first_file, mode="a") first_file = False f.close() diff --git a/drevalpy/datasets/featurizer/create_ppi_graphs.py b/drevalpy/datasets/featurizer/create_ppi_graphs.py index 2799e03e..fb41e1df 100644 --- a/drevalpy/datasets/featurizer/create_ppi_graphs.py +++ b/drevalpy/datasets/featurizer/create_ppi_graphs.py @@ -34,7 +34,7 @@ def _load_ppi_network(ppi_file: Path, gene_list_file: Path) -> Data: genes = gene_list_df["gene"].tolist() else: # Gene expression file -> columns are genes - genes = gene_list_df.columns + genes = list(gene_list_df.columns) genes = [g for g in genes if g not in ["cellosaurus_id", "cell_line_name"]] # Create a mapping from gene name to index diff --git a/drevalpy/evaluation.py b/drevalpy/evaluation.py index 8d239e5e..4dea98e4 100644 --- a/drevalpy/evaluation.py +++ b/drevalpy/evaluation.py @@ -52,7 +52,7 @@ def pearson(y_pred: np.ndarray, y_true: np.ndarray) -> float: if _check_constant_target_or_small_sample(y_true): return np.nan - return pearsonr(y_pred, y_true)[0] + return pearsonr(y_pred, y_true).statistic def spearman(y_pred: np.ndarray, y_true: np.ndarray) -> float: @@ -72,7 +72,7 @@ def spearman(y_pred: np.ndarray, y_true: np.ndarray) -> float: if _check_constant_target_or_small_sample(y_true): return np.nan - return spearmanr(y_pred, y_true)[0] + return spearmanr(y_pred, y_true).statistic def kendall(y_pred: np.ndarray, y_true: np.ndarray) -> float: diff --git a/drevalpy/models/DIPK/data_utils.py b/drevalpy/models/DIPK/data_utils.py index 11926e3f..c5e183f3 100644 --- a/drevalpy/models/DIPK/data_utils.py +++ b/drevalpy/models/DIPK/data_utils.py @@ -53,14 +53,16 @@ def load_bionic_features(data_path: str, dataset_name: str, gene_add_num: int = # Aggregate BIONIC features for selected genes selected_features = [bionic_gene_dict[gene] for gene in top_genes if gene in bionic_gene_dict] if selected_features: - aggregated_feature = np.mean(selected_features, axis=0) + aggregated_feature = np.mean(np.array(selected_features), axis=0) else: # Handle case where no features are found (padding with zeros) aggregated_feature = np.zeros(next(iter(bionic_gene_dict.values())).shape) bionic_feature_dict[cell_line] = aggregated_feature - feature_data = {cell_line: {"bionic_features": features} for cell_line, features in bionic_feature_dict.items()} + feature_data = { + str(cell_line): {"bionic_features": features} for cell_line, features in bionic_feature_dict.items() + } return FeatureDataset(features=feature_data) diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index 59a47194..54f94614 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -231,7 +231,7 @@ class DrugGNN(DRPModel): def __init__(self): """Initialize the DrugGNN model.""" super().__init__() - self.model: DrugGNNModule | None = None + self.model = None self.hyperparameters = {} @classmethod diff --git a/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py index 3f5b03bf..a71284ff 100644 --- a/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py +++ b/drevalpy/models/PPIGraphGNN/ppi_graph_gnn.py @@ -254,10 +254,10 @@ class PPIGraphGNN(DRPModel): def __init__(self): """Initialize the PPIGraphGNN model.""" super().__init__() - self.model: PPIGraphGNNModule | None = None + self.model = None self.hyperparameters = {} - self.ppi_graph_template: Data | None = None - self.explainer: Explainer | None = None + self.ppi_graph_template = None + self.explainer = None @classmethod def get_model_name(cls) -> str: diff --git a/drevalpy/models/SRMF/srmf.py b/drevalpy/models/SRMF/srmf.py index 8e1aaeb0..8123039c 100644 --- a/drevalpy/models/SRMF/srmf.py +++ b/drevalpy/models/SRMF/srmf.py @@ -135,7 +135,7 @@ def train( index=cell_lines, columns=drugs ) # missing rows and columns are filled with NaN - self.w = ~np.isnan(drug_response_matrix) + self.w = pd.DataFrame(~np.isnan(drug_response_matrix)) drug_response_matrix = drug_response_matrix.copy() drug_response_matrix[np.isnan(drug_response_matrix)] = 0 diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 7599be2e..87048e4b 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -38,11 +38,11 @@ class DRPModel(ABC): def __init__(self): """Initialize the DRPModel instance.""" - self.wandb_project: str | None = None - self.wandb_run: Any = None - self.wandb_config: dict[str, Any] | None = None - self.hyperparameters: dict[str, Any] = {} - self._in_hyperparameter_tuning: bool = False # Flag to track if we're in hyperparameter tuning + self.wandb_project = None + self.wandb_run = None + self.wandb_config = None + self.hyperparameters = {} + self._in_hyperparameter_tuning = False # Flag to track if we're in hyperparameter tuning def init_wandb( self, diff --git a/drevalpy/visualization/critical_difference_plot.py b/drevalpy/visualization/critical_difference_plot.py index a7f3f7d0..01c41d61 100644 --- a/drevalpy/visualization/critical_difference_plot.py +++ b/drevalpy/visualization/critical_difference_plot.py @@ -297,7 +297,7 @@ def _critical_difference_diagram( ) # for each algorithm: get the set of algorithms that are not significantly different - crossbar_sets = dict() + crossbar_sets: dict[str, set[str]] = {} for alg, row in adj_matrix.iterrows(): not_different = adj_matrix.columns[row].tolist() crossbar_sets[alg] = set(not_different).union({alg}) @@ -307,7 +307,7 @@ def _critical_difference_diagram( ypos = -0.5 for alg in ranks.index: bar = crossbar_sets[alg] - not_different = crossbar_sets[alg] + not_different = list(crossbar_sets[alg]) if len(not_different) == 1: continue crossbar_levels.append([bar]) diff --git a/drevalpy/visualization/utils.py b/drevalpy/visualization/utils.py index ca4dcd64..be9d3705 100644 --- a/drevalpy/visualization/utils.py +++ b/drevalpy/visualization/utils.py @@ -64,7 +64,9 @@ def _parse_layout(f: TextIO, path_to_layout: str, test_mode: str) -> None: f.write("".join(layout)) -def parse_results(path_to_results: str, dataset: str) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: +def parse_results( + path_to_results: str, dataset: str +) -> tuple[pd.DataFrame, pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame]: """ Parse the results from the given directory. @@ -88,10 +90,10 @@ def parse_results(path_to_results: str, dataset: str) -> tuple[pd.DataFrame, pd. result_files = [file for file in result_files if pattern.match(str(file).replace("\\", "/"))] # inititalize dictionaries to store the evaluation results - evaluation_results = None - evaluation_results_per_drug = None - evaluation_results_per_cell_line = None - true_vs_pred = None + evaluation_results_list: list[pd.DataFrame] = [] + evaluation_results_per_drug_list: list[pd.DataFrame] = [] + evaluation_results_per_cell_line_list: list[pd.DataFrame] = [] + true_vs_pred_list: list[pd.DataFrame] = [] # read every result file and compute the evaluation metrics for file in result_files: @@ -109,24 +111,23 @@ def parse_results(path_to_results: str, dataset: str) -> tuple[pd.DataFrame, pd. model_name, ) = evaluate_file(pred_file=file, test_mode=test_mode, model_name=algorithm) - evaluation_results = ( - overall_eval if evaluation_results is None else pd.concat([evaluation_results, overall_eval]) - ) - true_vs_pred = t_vs_p if true_vs_pred is None else pd.concat([true_vs_pred, t_vs_p]) + evaluation_results_list.append(overall_eval) + true_vs_pred_list.append(t_vs_p) if eval_results_per_drug is not None: - evaluation_results_per_drug = ( - eval_results_per_drug - if evaluation_results_per_drug is None - else pd.concat([evaluation_results_per_drug, eval_results_per_drug]) - ) + evaluation_results_per_drug_list.append(eval_results_per_drug) if eval_results_per_cl is not None: - evaluation_results_per_cell_line = ( - eval_results_per_cl - if evaluation_results_per_cell_line is None - else pd.concat([evaluation_results_per_cell_line, eval_results_per_cl]) - ) + evaluation_results_per_cell_line_list.append(eval_results_per_cl) + + evaluation_results = pd.concat(evaluation_results_list) + evaluation_results_per_drug = ( + pd.concat(evaluation_results_per_drug_list) if evaluation_results_per_drug_list else None + ) + evaluation_results_per_cell_line = ( + pd.concat(evaluation_results_per_cell_line_list) if evaluation_results_per_cell_line_list else None + ) + true_vs_pred = pd.concat(true_vs_pred_list) return ( evaluation_results, @@ -185,10 +186,10 @@ def evaluate_file( eval_results_per_group=evaluation_results_per_cl, model=model, ) - overall_eval = pd.DataFrame.from_dict(overall_eval, orient="index") + overall_eval_df = pd.DataFrame.from_dict(overall_eval, orient="index") return ( - overall_eval, + overall_eval_df, evaluation_results_per_drug, evaluation_results_per_cl, true_vs_pred, @@ -199,11 +200,11 @@ def evaluate_file( @pipeline_function def prep_results( eval_results: pd.DataFrame, - eval_results_per_drug: pd.DataFrame, - eval_results_per_cell_line: pd.DataFrame, + eval_results_per_drug: pd.DataFrame | None, + eval_results_per_cell_line: pd.DataFrame | None, t_vs_p: pd.DataFrame, path_data: pathlib.Path, -) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: +) -> tuple[pd.DataFrame, pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame]: """ Prepare the results by introducing new columns for algorithm, randomization, test_mode, split, CV_split. @@ -450,8 +451,8 @@ def compute_evaluation(df: pd.DataFrame, return_df: pd.DataFrame | None, group_b def write_results( path_out: str, eval_results: pd.DataFrame, - eval_results_per_drug: pd.DataFrame, - eval_results_per_cl: pd.DataFrame, + eval_results_per_drug: pd.DataFrame | None, + eval_results_per_cl: pd.DataFrame | None, t_vs_p: pd.DataFrame, ) -> None: """ @@ -645,6 +646,11 @@ def draw_test_mode_plots( # per group plots if test_mode in ("LPO", "LDO"): + if ev_res_per_drug is None: + raise ValueError( + f"No evaluation results found for test_mode {test_mode} with drug information. " + "Please check if the evaluation was run correctly." + ) _draw_per_grouping_setting_plots( grouping="drug_name", ev_res_per_group=ev_res_per_drug, @@ -653,6 +659,11 @@ def draw_test_mode_plots( result_path=result_path, ) if test_mode in ("LPO", "LCO", "LTO"): + if ev_res_per_cell_line is None: + raise ValueError( + f"No evaluation results found for test_mode {test_mode} with cell line information. " + "Please check if the evaluation was run correctly." + ) _draw_per_grouping_setting_plots( grouping="cell_line_name", ev_res_per_group=ev_res_per_cell_line, @@ -717,6 +728,7 @@ def draw_algorithm_plots( :param test_mode: test_mode :param custom_id: run id passed via command line :param result_path: path to the results + :raises ValueError: if no group-wise evaluation results are found for the given test_mode and model """ eval_results_algorithm = ev_res[(ev_res["test_mode"] == test_mode) & (ev_res["algorithm"] == model)] for plt_type in ["violinplot", "heatmap"]: @@ -743,6 +755,11 @@ def draw_algorithm_plots( out_suffix=f"{model}_{test_mode}", ) if test_mode in ("LPO", "LDO"): + if ev_res_per_drug is None: + raise ValueError( + f"No drug evaluation results found for test_mode {test_mode} and model {model}. " + "Please check if the evaluation was run correctly." + ) _draw_per_grouping_algorithm_plots( grouping="drug_name", model=model, @@ -753,6 +770,11 @@ def draw_algorithm_plots( result_path=result_path, ) if test_mode in ("LPO", "LCO", "LTO"): + if ev_res_per_cell_line is None: + raise ValueError( + f"No cell line evaluation results found for test_mode {test_mode} and model {model}. " + "Please check if the evaluation was run correctly." + ) _draw_per_grouping_algorithm_plots( grouping="cell_line_name", model=model,