Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
# Data
data/cell_line_input
data/response_output
data/mapping
data/GDSC1
data/GDSC2
data/CCLE
data/TOYv1
data/TOYv2
data/CTRPv1
data/CTRPv2
data/meta
data/BeatAML2
data/PDX_Bruna
data/

# Results directory is created when running the demo notebook
results/

# Wandb directory is created when running the benchmark with wandb
wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
137 changes: 134 additions & 3 deletions drevalpy/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
import torch
from sklearn.base import TransformerMixin

try:
import wandb
except ImportError:
wandb = None # type: ignore[assignment]

from .datasets.dataset import DrugResponseDataset, FeatureDataset, split_early_stopping_data
from .evaluation import evaluate, get_mode
from .evaluation import get_mode
from .models import MODEL_FACTORY, MULTI_DRUG_MODEL_FACTORY, SINGLE_DRUG_MODEL_FACTORY
from .models.drp_model import DRPModel
from .pipeline_function import pipeline_function
Expand Down Expand Up @@ -45,6 +50,7 @@ def drug_response_experiment(
model_checkpoint_dir: str = "TEMPORARY",
hyperparameter_tuning=True,
final_model_on_full_data: bool = False,
wandb_project: str | None = None,
) -> None:
"""
Run the drug response prediction experiment. Save results to disc.
Expand Down Expand Up @@ -97,6 +103,8 @@ def drug_response_experiment(
:param final_model_on_full_data: if True, a final/production model is saved in the results directory.
If hyperparameter_tuning is true, the final model is produced according to the hyperparameter tuning procedure
which was evaluated in the nested cross validation.
:param wandb_project: if provided, enables wandb logging for all DRPModel instances throughout training.
All hyperparameters and metrics will be logged to the specified wandb project.
:raises ValueError: if no cv splits are found
"""
# Default baseline model, needed for normalization
Expand Down Expand Up @@ -137,6 +145,7 @@ def drug_response_experiment(
)
response_data.save_splits(path=split_path)

# Build the list of models to run (done regardless of whether splits were newly created or loaded)
model_list = make_model_list(models + baselines, response_data)
for model_name in model_list.keys():
print(f"Running {model_name}")
Expand Down Expand Up @@ -189,6 +198,16 @@ def drug_response_experiment(
) = get_datasets_from_cv_split(split, model_class, model_name, drug_id)

model = model_class()
# Base wandb configuration for this split (used when training actually happens)
base_wandb_config = {
"model_name": model_name,
"drug_id": drug_id,
"split_index": split_index,
"test_mode": test_mode,
"dataset": response_data.dataset_name,
"n_cv_splits": n_cv_splits,
"hyperparameter_tuning": hyperparameter_tuning,
}

if not os.path.isfile(
prediction_file
Expand All @@ -205,6 +224,12 @@ def drug_response_experiment(
"model_checkpoint_dir": model_checkpoint_dir,
}

# During hyperparameter tuning, create separate wandb runs per trial if enabled
if wandb_project is not None:
tuning_inputs["wandb_project"] = wandb_project
tuning_inputs["split_index"] = split_index
tuning_inputs["wandb_base_config"] = base_wandb_config

if multiprocessing:
tuning_inputs["ray_path"] = os.path.abspath(os.path.join(result_path, "raytune"))
best_hpams = hpam_tune_raytune(**tuning_inputs)
Expand All @@ -213,6 +238,9 @@ def drug_response_experiment(

print(f"Best hyperparameters: {best_hpams}")
print("Training model on full train and validation set to predict test set")

# Log best hyperparameters to wandb (they will be logged when build_model is called)
# The best hyperparameters will be logged via build_model -> log_hyperparameters
# save best hyperparameters as json
with open(
hpam_save_path,
Expand All @@ -224,6 +252,25 @@ def drug_response_experiment(
train_dataset.add_rows(validation_dataset) # use full train val set data for final training
train_dataset.shuffle(random_state=42)

# Initialize wandb for the final training on the full train+validation set
# This happens regardless of whether hyperparameter tuning was performed
if wandb_project is not None:
final_run_name = f"{model_name}"
if drug_id is not None:
final_run_name += f"_{drug_id}"
final_run_name += f"_split_{split_index}_final"

final_config = {
**base_wandb_config,
"phase": "final_training",
}
model.init_wandb(
project=wandb_project,
config=final_config,
name=final_run_name,
tags=[model_name, test_mode, response_data.dataset_name or "unknown", "final"],
)

test_dataset = train_and_predict(
model=model,
hpams=best_hpams,
Expand All @@ -235,6 +282,24 @@ def drug_response_experiment(
model_checkpoint_dir=model_checkpoint_dir,
)

# Log final metrics on test set for all models
# Metrics will be logged as test_RMSE, test_R^2, test_Pearson, etc.
# This happens regardless of whether hyperparameter tuning was performed
if (
wandb_project is not None
and wandb is not None
and len(test_dataset) > 0
and test_dataset.predictions is not None
and len(test_dataset.predictions) > 0
):
# Ensure wandb run is active before logging metrics
if wandb.run is not None:
model.compute_and_log_final_metrics(
test_dataset,
additional_metrics=[hpam_optimization_metric],
prefix="test_",
)

for cross_study_dataset in cross_study_datasets:
print(f"Cross study prediction on {cross_study_dataset.dataset_name}")
cross_study_dataset.remove_nan_responses()
Expand All @@ -259,6 +324,10 @@ def drug_response_experiment(
encoding="utf-8",
) as f:
best_hpams = json.load(f)

# Finish wandb run for this split
if wandb_project is not None:
model.finish_wandb()
if not is_baseline:
if randomization_mode is not None:
print(f"Randomization tests for {model_class.get_model_name()}")
Expand Down Expand Up @@ -1063,7 +1132,20 @@ def train_and_evaluate(
response_transformation=response_transformation,
model_checkpoint_dir=model_checkpoint_dir,
)
return evaluate(validation_dataset, metric=[metric])

# Compute final metrics using DRPModel helper (always includes R^2 and PCC)
# Add primary metric if it's not already included
additional_metrics = None
if metric not in ["R^2", "Pearson"]:
additional_metrics = [metric]
# Use "val_" prefix to clearly denote validation metrics (val_RMSE, val_R^2, val_Pearson)
results = model.compute_and_log_final_metrics(
validation_dataset,
additional_metrics=additional_metrics,
prefix="val_",
)

return results


def hpam_tune(
Expand All @@ -1076,6 +1158,10 @@ def hpam_tune(
metric: str = "RMSE",
path_data: str = "data",
model_checkpoint_dir: str = "TEMPORARY",
*,
split_index: int | None = None,
wandb_project: str | None = None,
wandb_base_config: dict[str, Any] | None = None,
) -> dict:
"""
Tune the hyperparameters for the given model in an iterative manner.
Expand All @@ -1089,6 +1175,9 @@ def hpam_tune(
:param metric: metric to evaluate which model is the best
:param path_data: path to the data directory, e.g., data/
:param model_checkpoint_dir: directory to save model checkpoints
:param split_index: optional CV split index, used for naming wandb runs
:param wandb_project: optional wandb project name; if provided, enables per-trial wandb runs
:param wandb_base_config: optional base config dict to include in each wandb run
:returns: best hyperparameters
:raises AssertionError: if hpam_set is empty
"""
Expand All @@ -1097,11 +1186,44 @@ def hpam_tune(
if len(hpam_set) == 1:
return hpam_set[0]

# Mark that we're in hyperparameter tuning phase
# This prevents updating wandb.config during tuning - we'll only log final best hyperparameters
model._in_hyperparameter_tuning = True

best_hyperparameters = None
mode = get_mode(metric)
best_score = float("inf") if mode == "min" else float("-inf")
for hyperparameter in hpam_set:
for trial_idx, hyperparameter in enumerate(hpam_set):
print(f"Training model with hyperparameters: {hyperparameter}")

# Create a separate wandb run for each hyperparameter trial if enabled
if wandb_project is not None:
trial_run_name = model.get_model_name()
if split_index is not None:
trial_run_name += f"_split_{split_index}"
trial_run_name += f"_trial_{trial_idx}"

trial_config: dict[str, Any] = {}
if wandb_base_config is not None:
trial_config.update(wandb_base_config)
trial_config.update(
{
"phase": "hyperparameter_tuning",
"trial_index": trial_idx,
"hyperparameters": hyperparameter,
}
)

model.init_wandb(
project=wandb_project,
config=trial_config,
name=trial_run_name,
tags=[model.get_model_name(), "hpam_tuning"],
finish_previous=True,
)

# During hyperparameter tuning, don't update wandb config via log_hyperparameters
# Trial hyperparameters are stored in wandb.config for each run
score = train_and_evaluate(
model=model,
hpams=hyperparameter,
Expand All @@ -1114,14 +1236,23 @@ def hpam_tune(
model_checkpoint_dir=model_checkpoint_dir,
)[metric]

# Note: train_and_evaluate() already logs val_* metrics once via
# DRPModel.compute_and_log_final_metrics(..., prefix="val_").
# Avoid logging val_{metric} again here (it would create duplicate points).
if np.isnan(score):
if model.is_wandb_enabled():
model.finish_wandb()
continue

if (mode == "min" and score < best_score) or (mode == "max" and score > best_score):
print(f"current best {metric} score: {np.round(score, 3)}")
best_score = score
best_hyperparameters = hyperparameter

# Close this trial's run after all logging is done
if model.is_wandb_enabled():
model.finish_wandb()

if best_hyperparameters is None:
warnings.warn("all hpams lead to NaN respone. using last hpam combination.", stacklevel=2)
best_hyperparameters = hyperparameter
Expand Down
25 changes: 24 additions & 1 deletion drevalpy/models/DrugGNN/drug_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ...datasets.dataset import DrugResponseDataset, FeatureDataset
from ..drp_model import DRPModel
from ..lightning_metrics_mixin import RegressionMetricsMixin
from ..utils import load_and_select_gene_features


Expand Down Expand Up @@ -86,7 +87,7 @@ def forward(self, drug_graph, cell_features):
return out.view(-1)


class DrugGNNModule(pl.LightningModule):
class DrugGNNModule(RegressionMetricsMixin, pl.LightningModule):
"""The LightningModule for the DrugGNN model."""

def __init__(
Expand Down Expand Up @@ -115,6 +116,9 @@ def __init__(
)
self.criterion = nn.MSELoss()

# Initialize metrics storage for epoch-end R^2 and PCC computation
self._init_metrics_storage()

def forward(self, batch):
"""Forward pass of the module.

Expand All @@ -135,6 +139,10 @@ def training_step(self, batch, batch_idx):
outputs = self.model(drug_graph, cell_features)
loss = self.criterion(outputs, responses)
self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0))

# Store predictions and targets for epoch-end metrics via mixin
self._store_predictions(outputs, responses, is_training=True)

return loss

def validation_step(self, batch, batch_idx):
Expand All @@ -148,6 +156,9 @@ def validation_step(self, batch, batch_idx):
loss = self.criterion(outputs, responses)
self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0))

# Store predictions and targets for epoch-end metrics via mixin
self._store_predictions(outputs, responses, is_training=False)

def predict_step(self, batch, batch_idx, dataloader_idx=0):
"""A single prediction step.

Expand Down Expand Up @@ -252,6 +263,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None:

:param hyperparameters: The hyperparameters.
"""
# Log hyperparameters to wandb if enabled
self.log_hyperparameters(hyperparameters)

self.hyperparameters = hyperparameters

def _loader_kwargs(self) -> dict[str, Any]:
Expand Down Expand Up @@ -326,11 +340,20 @@ def train(
**self._loader_kwargs(),
)

# Set up wandb logger if project is provided
loggers = []
if self.wandb_project is not None:
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project=self.wandb_project, log_model=False)
loggers.append(logger)

trainer = pl.Trainer(
max_epochs=self.hyperparameters.get("epochs", 100),
accelerator="auto",
devices="auto",
callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None,
logger=loggers if loggers else True, # Use default logger if no wandb
enable_progress_bar=True,
log_every_n_steps=int(self.hyperparameters.get("log_every_n_steps", 50)),
precision=self.hyperparameters.get("precision", 32),
Expand Down
4 changes: 4 additions & 0 deletions drevalpy/models/MOLIR/molir.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None:
:param hyperparameters: Custom hyperparameters for the model, includes mini_batch, layer dimensions (h_dim1,
h_dim2, h_dim3), learning_rate, dropout_rate, weight_decay, gamma, epochs, and margin.
"""
# Log hyperparameters to wandb if enabled
self.log_hyperparameters(hyperparameters)

self.hyperparameters = hyperparameters
self.selector = VarianceFeatureSelector(
view="gene_expression", k=hyperparameters.get("n_gene_expression_features", 1000)
Expand Down Expand Up @@ -125,6 +128,7 @@ def train(
cell_line_input=cell_line_input,
output_earlystopping=output_earlystopping,
model_checkpoint_dir=model_checkpoint_dir,
wandb_project=self.wandb_project,
)
else:
print(f"Not enough training data provided ({len(output)}), will predict on randomly initialized model.")
Expand Down
Loading
Loading