From a6fe5c6b03a50127961d793321aff62ab854ffe1 Mon Sep 17 00:00:00 2001 From: nictru Date: Wed, 14 Jan 2026 13:57:04 +0000 Subject: [PATCH 01/22] Add wandb logging --- drevalpy/experiment.py | 62 ++++++- drevalpy/models/DrugGNN/drug_gnn.py | 21 +++ drevalpy/models/MOLIR/molir.py | 4 + drevalpy/models/MOLIR/utils.py | 21 +++ drevalpy/models/PharmaFormer/pharmaformer.py | 11 ++ .../multiomics_neural_network.py | 7 + .../simple_neural_network.py | 7 + drevalpy/models/SimpleNeuralNetwork/utils.py | 22 +++ drevalpy/models/SuperFELTR/superfeltr.py | 5 + drevalpy/models/SuperFELTR/utils.py | 21 +++ drevalpy/models/drp_model.py | 93 +++++++++- poetry.lock | 162 +++++++++++++++++- pyproject.toml | 1 + 13 files changed, 427 insertions(+), 10 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index b55681a2..9c830f4b 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -45,6 +45,7 @@ def drug_response_experiment( model_checkpoint_dir: str = "TEMPORARY", hyperparameter_tuning=True, final_model_on_full_data: bool = False, + wandb_project: str | None = None, ) -> None: """ Run the drug response prediction experiment. Save results to disc. @@ -97,6 +98,8 @@ def drug_response_experiment( :param final_model_on_full_data: if True, a final/production model is saved in the results directory. If hyperparameter_tuning is true, the final model is produced according to the hyperparameter tuning procedure which was evaluated in the nested cross validation. + :param wandb_project: if provided, enables wandb logging for all DRPModel instances throughout training. + All hyperparameters and metrics will be logged to the specified wandb project. :raises ValueError: if no cv splits are found """ # Default baseline model, needed for normalization @@ -190,6 +193,29 @@ def drug_response_experiment( model = model_class() + # Initialize wandb if project is provided (before hyperparameter tuning) + if wandb_project is not None: + run_name = f"{model_name}" + if drug_id is not None: + run_name += f"_{drug_id}" + run_name += f"_split_{split_index}" + + config = { + "model_name": model_name, + "drug_id": drug_id, + "split_index": split_index, + "test_mode": test_mode, + "dataset": response_data.dataset_name, + "n_cv_splits": n_cv_splits, + "hyperparameter_tuning": hyperparameter_tuning, + } + model.init_wandb( + project=wandb_project, + config=config, + name=run_name, + tags=[model_name, test_mode, response_data.dataset_name or "unknown"], + ) + if not os.path.isfile( prediction_file ): # if this split has not been run yet (or for a single drug model, this drug_id) @@ -213,6 +239,9 @@ def drug_response_experiment( print(f"Best hyperparameters: {best_hpams}") print("Training model on full train and validation set to predict test set") + + # Log best hyperparameters to wandb (they will be logged when build_model is called) + # The best hyperparameters will be logged via build_model -> log_hyperparameters # save best hyperparameters as json with open( hpam_save_path, @@ -259,6 +288,10 @@ def drug_response_experiment( encoding="utf-8", ) as f: best_hpams = json.load(f) + + # Finish wandb run for this split + if wandb_project is not None: + model.finish_wandb() if not is_baseline: if randomization_mode is not None: print(f"Randomization tests for {model_class.get_model_name()}") @@ -1057,7 +1090,15 @@ def train_and_evaluate( response_transformation=response_transformation, model_checkpoint_dir=model_checkpoint_dir, ) - return evaluate(validation_dataset, metric=[metric]) + results = evaluate(validation_dataset, metric=[metric]) + + # Log validation metrics to wandb if enabled + if hasattr(model, "wandb_run") and model.wandb_run is not None: + # Prefix metrics with "val_" to distinguish from training metrics + wandb_metrics = {f"val_{k}": v for k, v in results.items()} + model.log_metrics(wandb_metrics) + + return results def hpam_tune( @@ -1091,11 +1132,18 @@ def hpam_tune( if len(hpam_set) == 1: return hpam_set[0] + # Mark that we're in hyperparameter tuning phase + # This prevents updating wandb.config during tuning - we'll only log final best hyperparameters + model._in_hyperparameter_tuning = True + best_hyperparameters = None mode = get_mode(metric) best_score = float("inf") if mode == "min" else float("-inf") - for hyperparameter in hpam_set: + for trial_idx, hyperparameter in enumerate(hpam_set): print(f"Training model with hyperparameters: {hyperparameter}") + + # During hyperparameter tuning, don't update wandb config for each trial + # Instead, we'll log trial hyperparameters as metrics score = train_and_evaluate( model=model, hpams=hyperparameter, @@ -1111,11 +1159,21 @@ def hpam_tune( if np.isnan(score): continue + # Log trial hyperparameters and result to wandb if enabled + if hasattr(model, "wandb_run") and model.wandb_run is not None: + trial_metrics = {f"trial_{trial_idx}_{k}": v for k, v in hyperparameter.items()} + trial_metrics[f"trial_{trial_idx}_{metric}"] = score + model.log_metrics(trial_metrics) + if (mode == "min" and score < best_score) or (mode == "max" and score > best_score): print(f"current best {metric} score: {np.round(score, 3)}") best_score = score best_hyperparameters = hyperparameter + # Log best score so far to wandb if enabled + if hasattr(model, "wandb_run") and model.wandb_run is not None: + model.log_metrics({f"best_{metric}": best_score}) + if best_hyperparameters is None: warnings.warn("all hpams lead to NaN respone. using last hpam combination.", stacklevel=2) best_hyperparameters = hyperparameter diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index ca103ab4..20f45cc8 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -252,6 +252,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: The hyperparameters. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters def _loader_kwargs(self) -> dict[str, Any]: @@ -326,11 +329,29 @@ def train( **self._loader_kwargs(), ) + # Set up wandb logger if project is provided + loggers = [] + wandb_project = getattr(self, "wandb_project", None) + if wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + import wandb + + if wandb.run is not None: + # Use existing wandb run + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + else: + # If wandb is not initialized, create a new logger + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + trainer = pl.Trainer( max_epochs=self.hyperparameters.get("epochs", 100), accelerator="auto", devices="auto", callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, + logger=loggers if loggers else True, # Use default logger if no wandb enable_progress_bar=True, log_every_n_steps=int(self.hyperparameters.get("log_every_n_steps", 50)), precision=self.hyperparameters.get("precision", 32), diff --git a/drevalpy/models/MOLIR/molir.py b/drevalpy/models/MOLIR/molir.py index 844bcb23..5ccc450d 100644 --- a/drevalpy/models/MOLIR/molir.py +++ b/drevalpy/models/MOLIR/molir.py @@ -64,6 +64,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: Custom hyperparameters for the model, includes mini_batch, layer dimensions (h_dim1, h_dim2, h_dim3), learning_rate, dropout_rate, weight_decay, gamma, epochs, and margin. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters self.selector = VarianceFeatureSelector( view="gene_expression", k=hyperparameters.get("n_gene_expression_features", 1000) @@ -125,6 +128,7 @@ def train( cell_line_input=cell_line_input, output_earlystopping=output_earlystopping, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=getattr(self, "wandb_project", None), ) else: print(f"Not enough training data provided ({len(output)}), will predict on randomly initialized model.") diff --git a/drevalpy/models/MOLIR/utils.py b/drevalpy/models/MOLIR/utils.py index f83eb5e2..cefc804f 100644 --- a/drevalpy/models/MOLIR/utils.py +++ b/drevalpy/models/MOLIR/utils.py @@ -370,6 +370,7 @@ def fit( output_earlystopping: DrugResponseDataset | None = None, patience: int = 5, model_checkpoint_dir: str = "checkpoints", + wandb_project: str | None = None, ) -> None: """ Trains the MOLIR model. @@ -377,11 +378,14 @@ def fit( First, the ranges for the triplet loss are determined using the standard deviation of the training responses. Then, the training and validation data loaders are created. The model is trained using the Lightning Trainer with an early stopping callback and patience of 5. + :param output_train: training dataset containing the response output :param cell_line_input: feature dataset containing the omics data of the cell lines :param output_earlystopping: early stopping dataset :param patience: for early stopping :param model_checkpoint_dir: directory to save the model checkpoints + :param wandb_project: optional wandb project name for logging. If provided, uses WandbLogger + for PyTorch Lightning training. """ self.positive_range, self.negative_range = make_ranges(output_train) @@ -407,9 +411,26 @@ def fit( save_weights_only=True, ) + # Set up wandb logger if project is provided + loggers = [] + if wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + import wandb + + if wandb.run is not None: + # Use existing wandb run + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + else: + # If wandb is not initialized, create a new logger + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + # Initialize the Lightning trainer trainer = pl.Trainer( max_epochs=self.epochs, + logger=loggers if loggers else True, # Use default logger if no wandb callbacks=[ early_stop_callback, self.checkpoint_callback, diff --git a/drevalpy/models/PharmaFormer/pharmaformer.py b/drevalpy/models/PharmaFormer/pharmaformer.py index b6eaf647..439d6bd6 100644 --- a/drevalpy/models/PharmaFormer/pharmaformer.py +++ b/drevalpy/models/PharmaFormer/pharmaformer.py @@ -112,6 +112,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: Model hyperparameters including gene_hidden_size, drug_hidden_size, feature_dim, nhead, num_layers, dim_feedforward, dropout, batch_size, lr, epochs, patience """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters # Model will be built in train() when we know the input dimensions @@ -239,6 +242,10 @@ def train( epoch_loss /= batch_count print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Training Loss: {epoch_loss:.4f}") + # Log training loss to wandb if enabled + if hasattr(self, "wandb_run") and self.wandb_run is not None: + self.log_metrics({"train_loss": epoch_loss}, step=epoch) + # Validation phase for early stopping self.model.eval() val_loss = 0.0 @@ -258,6 +265,10 @@ def train( val_loss /= val_batch_count print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Validation Loss: {val_loss:.4f}") + # Log validation loss to wandb if enabled + if hasattr(self, "wandb_run") and self.wandb_run is not None: + self.log_metrics({"val_loss": val_loss}, step=epoch) + # Checkpointing: Save the best model if val_loss < best_val_loss: best_val_loss = val_loss diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index ae58b172..59372c60 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -62,6 +62,9 @@ def build_model(self, hyperparameters: dict): :param hyperparameters: dictionary containing the hyperparameters units_per_layer, dropout_prob, and methylation_pca_components. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters self.pca_ncomp = hyperparameters["methylation_pca_components"] @@ -124,6 +127,9 @@ def train( "ignore", message=".*does not have many workers which may be a bottleneck.*", ) + # Get wandb project from parent model if available + wandb_project = getattr(self, "wandb_project", None) + self.model.fit( output_train=output, cell_line_input=cell_line_input, @@ -139,6 +145,7 @@ def train( patience=5, num_workers=1, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=wandb_project, ) def predict( diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 43f2eb49..18d86fb5 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -51,6 +51,9 @@ def build_model(self, hyperparameters: dict): :param hyperparameters: includes units_per_layer and dropout_prob. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters self.hyperparameters.setdefault("input_dim_gex", None) self.hyperparameters.setdefault("input_dim_fp", None) @@ -113,6 +116,9 @@ def train( print("Probably, your training dataset is small.") + # Get wandb project from parent model if available + wandb_project = getattr(self, "wandb_project", None) + self.model.fit( output_train=output, cell_line_input=cell_line_input, @@ -128,6 +134,7 @@ def train( patience=5, num_workers=1 if platform.system() == "Windows" else 8, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=wandb_project, ) def predict( diff --git a/drevalpy/models/SimpleNeuralNetwork/utils.py b/drevalpy/models/SimpleNeuralNetwork/utils.py index 26b2790b..0b2c31d3 100644 --- a/drevalpy/models/SimpleNeuralNetwork/utils.py +++ b/drevalpy/models/SimpleNeuralNetwork/utils.py @@ -7,6 +7,7 @@ import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar +from pytorch_lightning.loggers import WandbLogger from torch import nn from torch.utils.data import DataLoader, Dataset @@ -150,11 +151,13 @@ def fit( patience=5, num_workers: int = 2, model_checkpoint_dir: str = "checkpoints", + wandb_project: str | None = None, ) -> None: """ Fits the model. First, the data is loaded using a DataLoader. Then, the model is trained using the Lightning Trainer. + :param output_train: Response values for training :param cell_line_input: Cell line features :param drug_input: Drug features @@ -166,6 +169,8 @@ def fit( :param patience: patience for early stopping, default is 5 :param num_workers: number of workers for the DataLoader, default is 2 :param model_checkpoint_dir: directory to save the model checkpoints + :param wandb_project: optional wandb project name for logging. If provided, uses WandbLogger + for PyTorch Lightning training. :raises ValueError: if drug_input is missing """ if trainer_params is None: @@ -233,6 +238,22 @@ def fit( trainer_params_copy = trainer_params.copy() del trainer_params_copy["progress_bar_refresh_rate"] + # Set up wandb logger if project is provided + loggers = [] + if wandb_project is not None: + # Use existing wandb run if available, otherwise create new logger + # The wandb run should already be initialized by DRPModel.init_wandb() + import wandb + + if wandb.run is not None: + # Use existing wandb run + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + else: + # If wandb is not initialized, create a new logger + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + # Initialize the Lightning trainer trainer = pl.Trainer( callbacks=[ @@ -240,6 +261,7 @@ def fit( self.checkpoint_callback, progress_bar, ], + logger=loggers if loggers else True, # Use default logger if no wandb default_root_dir=model_checkpoint_dir, devices=1, **trainer_params_copy, diff --git a/drevalpy/models/SuperFELTR/superfeltr.py b/drevalpy/models/SuperFELTR/superfeltr.py index 43268f4f..4e92866b 100644 --- a/drevalpy/models/SuperFELTR/superfeltr.py +++ b/drevalpy/models/SuperFELTR/superfeltr.py @@ -78,6 +78,9 @@ def build_model(self, hyperparameters) -> None: dropout_rate, weight_decay, out_dim_expr_encoder, out_dim_mutation_encoder, out_dim_cnv_encoder, epochs, variance thresholds for gene expression, mutation, and copy number variation, margin, and learning rate. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters n_features = hyperparameters.get("n_features_per_view", 1000) @@ -133,6 +136,7 @@ def train( output_earlystopping=output_earlystopping, patience=5, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=getattr(self, "wandb_project", None), ) encoders[omic_type] = SuperFELTEncoder.load_from_checkpoint(best_checkpoint.best_model_path) else: @@ -165,6 +169,7 @@ def train( output_earlystopping=output_earlystopping, patience=5, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=getattr(self, "wandb_project", None), ) else: print("Not enough training data provided for SuperFELTR Regressor. Using random initialization.") diff --git a/drevalpy/models/SuperFELTR/utils.py b/drevalpy/models/SuperFELTR/utils.py index 4816302d..da232af1 100644 --- a/drevalpy/models/SuperFELTR/utils.py +++ b/drevalpy/models/SuperFELTR/utils.py @@ -294,11 +294,13 @@ def train_superfeltr_model( output_earlystopping: DrugResponseDataset | None = None, patience: int = 5, model_checkpoint_dir: str = "superfeltr_checkpoints", + wandb_project: str | None = None, ) -> pl.callbacks.ModelCheckpoint: """ Trains one encoder or the regressor. First, the dataset and loaders are created. Then, the model is trained with the Lightning trainer. + :param model: either one of the encoders or the regressor :param hpams: hyperparameters for the model :param output_train: response data for training @@ -306,6 +308,8 @@ def train_superfeltr_model( :param output_earlystopping: response data for early stopping :param patience: for early stopping, defaults to 5 :param model_checkpoint_dir: directory to save the model checkpoints + :param wandb_project: optional wandb project name for logging. If provided, uses WandbLogger + for PyTorch Lightning training. :returns: checkpoint callback with the best model :raises ValueError: if the epochs and mini_batch are not integers """ @@ -329,9 +333,26 @@ def train_superfeltr_model( mode="min", save_top_k=1, ) + # Set up wandb logger if project is provided + loggers = [] + if wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + import wandb + + if wandb.run is not None: + # Use existing wandb run + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + else: + # If wandb is not initialized, create a new logger + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + # Initialize the Lightning trainer trainer = pl.Trainer( max_epochs=hpams["epochs"], + logger=loggers if loggers else True, # Use default logger if no wandb callbacks=[ early_stop_callback, checkpoint_callback, diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 8c60d39e..3b5e1460 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -15,6 +15,8 @@ import yaml from sklearn.model_selection import ParameterGrid +import wandb + from ..datasets.dataset import DrugResponseDataset, FeatureDataset from ..pipeline_function import pipeline_function @@ -33,6 +35,90 @@ class DRPModel(ABC): # Then, the model is trained per drug is_single_drug_model = False + def __init__(self): + """Initialize the DRPModel instance.""" + self.wandb_project: str | None = None + self.wandb_run: Any = None + self.wandb_config: dict[str, Any] | None = None + self.hyperparameters: dict[str, Any] | None = None + self._in_hyperparameter_tuning: bool = False # Flag to track if we're in hyperparameter tuning + + def init_wandb( + self, + project: str, + config: dict[str, Any] | None = None, + name: str | None = None, + tags: list[str] | None = None, + reinit: bool = True, + ) -> None: + """ + Initialize wandb logging for this model instance. + + :param project: wandb project name + :param config: dictionary of configuration to log (e.g., hyperparameters, dataset info) + :param name: run name (defaults to model name) + :param tags: list of tags for the run + :param reinit: whether to reinitialize wandb if already initialized + """ + self.wandb_project = project + self.wandb_config = config or {} + + run_name = name or self.get_model_name() + wandb.init( + project=project, + config=self.wandb_config, + name=run_name, + tags=tags, + reinit=reinit, + ) + self.wandb_run = wandb.run + + def log_hyperparameters(self, hyperparameters: dict[str, Any]) -> None: + """ + Log hyperparameters to wandb. + + This method is called automatically by build_model when wandb is enabled. + Subclasses can override this to add additional hyperparameter logging. + + During hyperparameter tuning, config updates are skipped to avoid overwriting. + Only the final best hyperparameters are logged to wandb.config. + + :param hyperparameters: dictionary of hyperparameters to log + """ + if self.wandb_run is None: + return + + self.hyperparameters = hyperparameters + # Only update wandb.config if we're not in hyperparameter tuning phase + # During tuning, trial hyperparameters are logged as metrics instead + if not getattr(self, "_in_hyperparameter_tuning", False): + wandb.config.update(hyperparameters) + + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: + """ + Log metrics to wandb. + + Subclasses can call this method to log custom metrics during training. + + :param metrics: dictionary of metric names to values + :param step: optional step number for the metrics + """ + if self.wandb_run is None: + return + + if step is not None: + wandb.log(metrics, step=step) + else: + wandb.log(metrics) + + def finish_wandb(self) -> None: + """Finish the wandb run. Call this when training is complete.""" + if self.wandb_run is None: + return + + wandb.finish() + self.wandb_run = None + @classmethod @abstractmethod @pipeline_function @@ -97,11 +183,16 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: """ Builds the model, for models that use hyperparameters. + Subclasses should call self.log_hyperparameters(hyperparameters) at the beginning + of this method to ensure hyperparameters are logged to wandb if enabled. + :param hyperparameters: hyperparameters for the model Example:: - self.model = ElasticNet(alpha=hyperparameters["alpha"], l1_ratio=hyperparameters["l1_ratio"]) + def build_model(self, hyperparameters: dict[str, Any]) -> None: + self.log_hyperparameters(hyperparameters) # Log to wandb + self.model = ElasticNet(alpha=hyperparameters["alpha"], l1_ratio=hyperparameters["l1_ratio"]) """ @pipeline_function diff --git a/poetry.lock b/poetry.lock index e2292254..6d0fed1c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -189,7 +189,6 @@ description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -688,7 +687,6 @@ files = [ {file = "click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6"}, {file = "click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a"}, ] -markers = {main = "extra == \"multiprocessing\""} [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} @@ -1452,6 +1450,40 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "backports-zstd ; python_version < \"3.14\"", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr"] tqdm = ["tqdm"] +[[package]] +name = "gitdb" +version = "4.0.12" +description = "Git Object Database" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf"}, + {file = "gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571"}, +] + +[package.dependencies] +smmap = ">=3.0.1,<6" + +[[package]] +name = "gitpython" +version = "3.1.46" +description = "GitPython is a Python library used to interact with Git repositories" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058"}, + {file = "gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f"}, +] + +[package.dependencies] +gitdb = ">=4.0.1,<5" + +[package.extras] +doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"] +test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy (==1.18.2) ; python_version >= \"3.9\"", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""] + [[package]] name = "h11" version = "0.16.0" @@ -3471,7 +3503,6 @@ description = "" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "protobuf-6.33.4-cp310-abi3-win32.whl", hash = "sha256:918966612c8232fc6c24c78e1cd89784307f5814ad7506c308ee3cf86662850d"}, {file = "protobuf-6.33.4-cp310-abi3-win_amd64.whl", hash = "sha256:8f11ffae31ec67fc2554c2ef891dcb561dae9a2a3ed941f9e134c2db06657dbc"}, @@ -3613,7 +3644,6 @@ description = "Data validation using Python type hints" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"}, {file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"}, @@ -3636,7 +3666,6 @@ description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "pydantic_core-2.41.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:77b63866ca88d804225eaa4af3e664c5faf3568cea95360d21f4725ab6e07146"}, {file = "pydantic_core-2.41.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dfa8a0c812ac681395907e71e1274819dec685fec28273a28905df579ef137e2"}, @@ -4660,6 +4689,69 @@ files = [ cryptography = ">=2.0" jeepney = ">=0.6" +[[package]] +name = "sentry-sdk" +version = "2.49.0" +description = "Python client for Sentry (https://sentry.io)" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "sentry_sdk-2.49.0-py2.py3-none-any.whl", hash = "sha256:6ea78499133874445a20fe9c826c9e960070abeb7ae0cdf930314ab16bb97aa0"}, + {file = "sentry_sdk-2.49.0.tar.gz", hash = "sha256:c1878599cde410d481c04ef50ee3aedd4f600e4d0d253f4763041e468b332c30"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = ">=1.26.11" + +[package.extras] +aiohttp = ["aiohttp (>=3.5)"] +anthropic = ["anthropic (>=0.16)"] +arq = ["arq (>=0.23)"] +asyncpg = ["asyncpg (>=0.23)"] +beam = ["apache-beam (>=2.12)"] +bottle = ["bottle (>=0.12.13)"] +celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] +chalice = ["chalice (>=1.16.0)"] +clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] +django = ["django (>=1.8)"] +falcon = ["falcon (>=1.4)"] +fastapi = ["fastapi (>=0.79.0)"] +flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"] +google-genai = ["google-genai (>=1.29.0)"] +grpcio = ["grpcio (>=1.21.1)", "protobuf (>=3.8.0)"] +http2 = ["httpcore[http2] (==1.*)"] +httpx = ["httpx (>=0.16.0)"] +huey = ["huey (>=2)"] +huggingface-hub = ["huggingface_hub (>=0.22)"] +langchain = ["langchain (>=0.0.210)"] +langgraph = ["langgraph (>=0.6.6)"] +launchdarkly = ["launchdarkly-server-sdk (>=9.8.0)"] +litellm = ["litellm (>=1.77.5)"] +litestar = ["litestar (>=2.0.0)"] +loguru = ["loguru (>=0.5)"] +mcp = ["mcp (>=1.15.0)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] +openfeature = ["openfeature-sdk (>=0.7.1)"] +opentelemetry = ["opentelemetry-distro (>=0.35b0)"] +opentelemetry-experimental = ["opentelemetry-distro"] +opentelemetry-otlp = ["opentelemetry-distro[otlp] (>=0.35b0)"] +pure-eval = ["asttokens", "executing", "pure_eval"] +pydantic-ai = ["pydantic-ai (>=1.0.0)"] +pymongo = ["pymongo (>=3.1)"] +pyspark = ["pyspark (>=2.4.4)"] +quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] +rq = ["rq (>=0.6)"] +sanic = ["sanic (>=0.8)"] +sqlalchemy = ["sqlalchemy (>=1.2)"] +starlette = ["starlette (>=0.19.1)"] +starlite = ["starlite (>=1.48)"] +statsig = ["statsig (>=0.55.3)"] +tornado = ["tornado (>=6)"] +unleash = ["UnleashClient (>=6.0.1)"] + [[package]] name = "setuptools" version = "80.9.0" @@ -4706,6 +4798,18 @@ files = [ {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, ] +[[package]] +name = "smmap" +version = "5.0.2" +description = "A pure Python implementation of a sliding window memory map manager" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e"}, + {file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"}, +] + [[package]] name = "snowballstemmer" version = "3.0.1" @@ -5379,7 +5483,6 @@ description = "Runtime typing introspection tools" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"}, {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"}, @@ -5458,6 +5561,51 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] +[[package]] +name = "wandb" +version = "0.24.0" +description = "A CLI and library for interacting with the Weights & Biases API." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "wandb-0.24.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:aa9777398ff4b0f04c41359f7d1b95b5d656cb12c37c63903666799212e50299"}, + {file = "wandb-0.24.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:0423fbd58c3926949724feae8aab89d20c68846f9f4f596b80f9ffe1fc298130"}, + {file = "wandb-0.24.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:2b25fc0c123daac97ed32912ac55642c65013cc6e3a898e88ca2d917fc8eadc0"}, + {file = "wandb-0.24.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:9485344b4667944b5b77294185bae8469cfa4074869bec0e74f54f8492234cc2"}, + {file = "wandb-0.24.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:51b2b9a9d7d6b35640f12a46a48814fd4516807ad44f586b819ed6560f8de1fd"}, + {file = "wandb-0.24.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:11f7e7841f31eff82c82a677988889ad3aa684c6de61ff82145333b5214ec860"}, + {file = "wandb-0.24.0-py3-none-win32.whl", hash = "sha256:42af348998b00d4309ae790c5374040ac6cc353ab21567f4e29c98c9376dee8e"}, + {file = "wandb-0.24.0-py3-none-win_amd64.whl", hash = "sha256:32604eddcd362e1ed4a2e2ce5f3a239369c4a193af223f3e66603481ac91f336"}, + {file = "wandb-0.24.0-py3-none-win_arm64.whl", hash = "sha256:e0f2367552abfca21b0f3a03405fbf48f1e14de9846e70f73c6af5da57afd8ef"}, + {file = "wandb-0.24.0.tar.gz", hash = "sha256:4715a243b3d460b6434b9562e935dfd9dfdf5d6e428cfb4c3e7ce4fd44460ab3"}, +] + +[package.dependencies] +click = ">=8.0.1" +gitpython = ">=1.0.0,<3.1.29 || >3.1.29" +packaging = "*" +platformdirs = "*" +protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5.28.0 || >5.28.0,<7", markers = "python_version > \"3.9\" or sys_platform != \"linux\""} +pydantic = "<3" +pyyaml = "*" +requests = ">=2.0.0,<3" +sentry-sdk = ">=2.0.0" +typing-extensions = ">=4.8,<5" + +[package.extras] +aws = ["boto3", "botocore (>=1.5.76)"] +azure = ["azure-identity", "azure-storage-blob"] +gcp = ["google-cloud-storage"] +importers = ["filelock", "mlflow", "polars (<=1.2.1)", "rich", "tenacity"] +kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] +launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore (>=1.5.76)", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "jsonschema", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "tornado (>=6.5.0) ; python_version >= \"3.9\"", "typing-extensions"] +media = ["bokeh", "imageio (>=2.28.1)", "moviepy (>=1.0.0)", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit", "soundfile"] +models = ["cloudpickle"] +perf = ["orjson"] +sweeps = ["sweeps (>=0.2.0)"] +workspaces = ["wandb-workspaces"] + [[package]] name = "watchfiles" version = "1.1.1" @@ -6164,4 +6312,4 @@ multiprocessing = ["pydantic", "ray"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "cb7bf71d59b895c0cfbe611d73c357cf78478175d923b41ae5a3565aaf47437e" +content-hash = "b48d33e2c6e66c3fa8aa5b42db423e8577a44e19090b766f65051f4b9587dde4" diff --git a/pyproject.toml b/pyproject.toml index 02886951..17a62ab4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ toml = {version = "^0.10.2"} poetry = "^2.0.1" starlette = ">=0.49.1" pydantic = { version = ">=2.5", optional = true } +wandb = "^0.24.0" [tool.poetry.requires-plugins] poetry-plugin-export = ">=1.8" From 249da90859c719c0a438732a8e57b708d4ddd6b9 Mon Sep 17 00:00:00 2001 From: nictru Date: Wed, 14 Jan 2026 14:03:17 +0000 Subject: [PATCH 02/22] Update deprecated re-init option --- drevalpy/models/drp_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 3b5e1460..54e12408 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -49,7 +49,7 @@ def init_wandb( config: dict[str, Any] | None = None, name: str | None = None, tags: list[str] | None = None, - reinit: bool = True, + finish_previous: bool = True, ) -> None: """ Initialize wandb logging for this model instance. @@ -58,7 +58,7 @@ def init_wandb( :param config: dictionary of configuration to log (e.g., hyperparameters, dataset info) :param name: run name (defaults to model name) :param tags: list of tags for the run - :param reinit: whether to reinitialize wandb if already initialized + :param finish_previous: whether to finish any existing wandb run before starting a new one """ self.wandb_project = project self.wandb_config = config or {} @@ -69,7 +69,7 @@ def init_wandb( config=self.wandb_config, name=run_name, tags=tags, - reinit=reinit, + finish_previous=finish_previous, ) self.wandb_run = wandb.run From 9c3b6aea592343387433916164060a3444b31b82 Mon Sep 17 00:00:00 2001 From: nictru Date: Wed, 14 Jan 2026 14:04:10 +0000 Subject: [PATCH 03/22] Improve gitignore --- .gitignore | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 0de5ab96..5e873b1f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,21 +1,12 @@ # Data -data/cell_line_input -data/response_output -data/mapping -data/GDSC1 -data/GDSC2 -data/CCLE -data/TOYv1 -data/TOYv2 -data/CTRPv1 -data/CTRPv2 -data/meta -data/BeatAML2 -data/PDX_Bruna +data/ # Results directory is created when running the demo notebook results/ +# Wandb directory is created when running the benchmark with wandb +wandb/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From d74815025347947692b8329f3bca70a70981543c Mon Sep 17 00:00:00 2001 From: nictru Date: Wed, 14 Jan 2026 14:13:05 +0000 Subject: [PATCH 04/22] Simplify wandb implementation --- drevalpy/experiment.py | 6 ++-- drevalpy/models/DrugGNN/drug_gnn.py | 15 ++------- drevalpy/models/MOLIR/molir.py | 2 +- drevalpy/models/MOLIR/utils.py | 12 ++----- drevalpy/models/PharmaFormer/pharmaformer.py | 4 +-- .../multiomics_neural_network.py | 2 +- .../simple_neural_network.py | 2 +- drevalpy/models/SimpleNeuralNetwork/utils.py | 16 +++------- drevalpy/models/SuperFELTR/superfeltr.py | 4 +-- drevalpy/models/SuperFELTR/utils.py | 12 ++----- drevalpy/models/drp_model.py | 32 ++++++++++++++++--- 11 files changed, 49 insertions(+), 58 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 9c830f4b..48e4a975 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -1093,7 +1093,7 @@ def train_and_evaluate( results = evaluate(validation_dataset, metric=[metric]) # Log validation metrics to wandb if enabled - if hasattr(model, "wandb_run") and model.wandb_run is not None: + if model.is_wandb_enabled(): # Prefix metrics with "val_" to distinguish from training metrics wandb_metrics = {f"val_{k}": v for k, v in results.items()} model.log_metrics(wandb_metrics) @@ -1160,7 +1160,7 @@ def hpam_tune( continue # Log trial hyperparameters and result to wandb if enabled - if hasattr(model, "wandb_run") and model.wandb_run is not None: + if model.is_wandb_enabled(): trial_metrics = {f"trial_{trial_idx}_{k}": v for k, v in hyperparameter.items()} trial_metrics[f"trial_{trial_idx}_{metric}"] = score model.log_metrics(trial_metrics) @@ -1171,7 +1171,7 @@ def hpam_tune( best_hyperparameters = hyperparameter # Log best score so far to wandb if enabled - if hasattr(model, "wandb_run") and model.wandb_run is not None: + if model.is_wandb_enabled(): model.log_metrics({f"best_{metric}": best_score}) if best_hyperparameters is None: diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index 20f45cc8..d8e2bc39 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -331,20 +331,11 @@ def train( # Set up wandb logger if project is provided loggers = [] - wandb_project = getattr(self, "wandb_project", None) - if wandb_project is not None: + if self.wandb_project is not None: from pytorch_lightning.loggers import WandbLogger - import wandb - - if wandb.run is not None: - # Use existing wandb run - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) - else: - # If wandb is not initialized, create a new logger - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) + logger = WandbLogger(project=self.wandb_project, log_model=False) + loggers.append(logger) trainer = pl.Trainer( max_epochs=self.hyperparameters.get("epochs", 100), diff --git a/drevalpy/models/MOLIR/molir.py b/drevalpy/models/MOLIR/molir.py index 5ccc450d..62e76d87 100644 --- a/drevalpy/models/MOLIR/molir.py +++ b/drevalpy/models/MOLIR/molir.py @@ -128,7 +128,7 @@ def train( cell_line_input=cell_line_input, output_earlystopping=output_earlystopping, model_checkpoint_dir=model_checkpoint_dir, - wandb_project=getattr(self, "wandb_project", None), + wandb_project=self.wandb_project, ) else: print(f"Not enough training data provided ({len(output)}), will predict on randomly initialized model.") diff --git a/drevalpy/models/MOLIR/utils.py b/drevalpy/models/MOLIR/utils.py index cefc804f..d404e27f 100644 --- a/drevalpy/models/MOLIR/utils.py +++ b/drevalpy/models/MOLIR/utils.py @@ -416,16 +416,8 @@ def fit( if wandb_project is not None: from pytorch_lightning.loggers import WandbLogger - import wandb - - if wandb.run is not None: - # Use existing wandb run - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) - else: - # If wandb is not initialized, create a new logger - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) # Initialize the Lightning trainer trainer = pl.Trainer( diff --git a/drevalpy/models/PharmaFormer/pharmaformer.py b/drevalpy/models/PharmaFormer/pharmaformer.py index 439d6bd6..39268989 100644 --- a/drevalpy/models/PharmaFormer/pharmaformer.py +++ b/drevalpy/models/PharmaFormer/pharmaformer.py @@ -243,7 +243,7 @@ def train( print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Training Loss: {epoch_loss:.4f}") # Log training loss to wandb if enabled - if hasattr(self, "wandb_run") and self.wandb_run is not None: + if self.is_wandb_enabled(): self.log_metrics({"train_loss": epoch_loss}, step=epoch) # Validation phase for early stopping @@ -266,7 +266,7 @@ def train( print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Validation Loss: {val_loss:.4f}") # Log validation loss to wandb if enabled - if hasattr(self, "wandb_run") and self.wandb_run is not None: + if self.is_wandb_enabled(): self.log_metrics({"val_loss": val_loss}, step=epoch) # Checkpointing: Save the best model diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index 59372c60..7c6865b3 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -128,7 +128,7 @@ def train( message=".*does not have many workers which may be a bottleneck.*", ) # Get wandb project from parent model if available - wandb_project = getattr(self, "wandb_project", None) + wandb_project = self.wandb_project self.model.fit( output_train=output, diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 18d86fb5..5e0e054b 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -117,7 +117,7 @@ def train( print("Probably, your training dataset is small.") # Get wandb project from parent model if available - wandb_project = getattr(self, "wandb_project", None) + wandb_project = self.wandb_project self.model.fit( output_train=output, diff --git a/drevalpy/models/SimpleNeuralNetwork/utils.py b/drevalpy/models/SimpleNeuralNetwork/utils.py index 0b2c31d3..92bb03e6 100644 --- a/drevalpy/models/SimpleNeuralNetwork/utils.py +++ b/drevalpy/models/SimpleNeuralNetwork/utils.py @@ -239,20 +239,12 @@ def fit( del trainer_params_copy["progress_bar_refresh_rate"] # Set up wandb logger if project is provided + # Note: This method receives wandb_project as parameter, but the model instance + # should have wandb already initialized via DRPModel.init_wandb() loggers = [] if wandb_project is not None: - # Use existing wandb run if available, otherwise create new logger - # The wandb run should already be initialized by DRPModel.init_wandb() - import wandb - - if wandb.run is not None: - # Use existing wandb run - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) - else: - # If wandb is not initialized, create a new logger - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) # Initialize the Lightning trainer trainer = pl.Trainer( diff --git a/drevalpy/models/SuperFELTR/superfeltr.py b/drevalpy/models/SuperFELTR/superfeltr.py index 4e92866b..59da1e1d 100644 --- a/drevalpy/models/SuperFELTR/superfeltr.py +++ b/drevalpy/models/SuperFELTR/superfeltr.py @@ -136,7 +136,7 @@ def train( output_earlystopping=output_earlystopping, patience=5, model_checkpoint_dir=model_checkpoint_dir, - wandb_project=getattr(self, "wandb_project", None), + wandb_project=self.wandb_project, ) encoders[omic_type] = SuperFELTEncoder.load_from_checkpoint(best_checkpoint.best_model_path) else: @@ -169,7 +169,7 @@ def train( output_earlystopping=output_earlystopping, patience=5, model_checkpoint_dir=model_checkpoint_dir, - wandb_project=getattr(self, "wandb_project", None), + wandb_project=self.wandb_project, ) else: print("Not enough training data provided for SuperFELTR Regressor. Using random initialization.") diff --git a/drevalpy/models/SuperFELTR/utils.py b/drevalpy/models/SuperFELTR/utils.py index da232af1..5d5aee93 100644 --- a/drevalpy/models/SuperFELTR/utils.py +++ b/drevalpy/models/SuperFELTR/utils.py @@ -338,16 +338,8 @@ def train_superfeltr_model( if wandb_project is not None: from pytorch_lightning.loggers import WandbLogger - import wandb - - if wandb.run is not None: - # Use existing wandb run - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) - else: - # If wandb is not initialized, create a new logger - logger = WandbLogger(project=wandb_project, log_model=False) - loggers.append(logger) + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) # Initialize the Lightning trainer trainer = pl.Trainer( diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 54e12408..99b574e8 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -85,15 +85,39 @@ def log_hyperparameters(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: dictionary of hyperparameters to log """ - if self.wandb_run is None: + if not self.is_wandb_enabled(): return self.hyperparameters = hyperparameters # Only update wandb.config if we're not in hyperparameter tuning phase # During tuning, trial hyperparameters are logged as metrics instead - if not getattr(self, "_in_hyperparameter_tuning", False): + if not self._in_hyperparameter_tuning: wandb.config.update(hyperparameters) + def is_wandb_enabled(self) -> bool: + """ + Check if wandb logging is enabled for this model instance. + + :returns: True if wandb is initialized and active, False otherwise + """ + return self.wandb_run is not None + + def get_wandb_logger(self) -> Any | None: + """ + Get a WandbLogger for PyTorch Lightning integration. + + This method creates a WandbLogger that uses the existing wandb run. + Returns None if wandb is not enabled. + + :returns: WandbLogger instance or None + """ + if not self.is_wandb_enabled() or self.wandb_project is None: + return None + + from pytorch_lightning.loggers import WandbLogger + + return WandbLogger(project=self.wandb_project, log_model=False) + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: """ Log metrics to wandb. @@ -103,7 +127,7 @@ def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> Non :param metrics: dictionary of metric names to values :param step: optional step number for the metrics """ - if self.wandb_run is None: + if not self.is_wandb_enabled(): return if step is not None: @@ -113,7 +137,7 @@ def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> Non def finish_wandb(self) -> None: """Finish the wandb run. Call this when training is complete.""" - if self.wandb_run is None: + if not self.is_wandb_enabled(): return wandb.finish() From 1aa852fb5a9982fa86f5f7820b2d1d370dcbe57b Mon Sep 17 00:00:00 2001 From: nictru Date: Wed, 14 Jan 2026 14:15:41 +0000 Subject: [PATCH 05/22] Simplify wandb implementation --- .../models/SimpleNeuralNetwork/multiomics_neural_network.py | 5 +---- drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index 7c6865b3..dc6a7b5b 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -127,9 +127,6 @@ def train( "ignore", message=".*does not have many workers which may be a bottleneck.*", ) - # Get wandb project from parent model if available - wandb_project = self.wandb_project - self.model.fit( output_train=output, cell_line_input=cell_line_input, @@ -145,7 +142,7 @@ def train( patience=5, num_workers=1, model_checkpoint_dir=model_checkpoint_dir, - wandb_project=wandb_project, + wandb_project=self.wandb_project, ) def predict( diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 5e0e054b..7ab82bbf 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -116,9 +116,6 @@ def train( print("Probably, your training dataset is small.") - # Get wandb project from parent model if available - wandb_project = self.wandb_project - self.model.fit( output_train=output, cell_line_input=cell_line_input, @@ -134,7 +131,7 @@ def train( patience=5, num_workers=1 if platform.system() == "Windows" else 8, model_checkpoint_dir=model_checkpoint_dir, - wandb_project=wandb_project, + wandb_project=self.wandb_project, ) def predict( From 1a4462bd8dd8ed7da53ff154cd6dc3d6284cdb38 Mon Sep 17 00:00:00 2001 From: Nico Trummer Date: Thu, 15 Jan 2026 09:31:58 +0100 Subject: [PATCH 06/22] Test Judith's proposed fix Co-authored-by: Judith Bernett <38618495+JudithBernett@users.noreply.github.com> --- drevalpy/models/drp_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 99b574e8..1df8b6b6 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -16,6 +16,10 @@ from sklearn.model_selection import ParameterGrid import wandb +import numpy as np +import wandb +import yaml +from sklearn.model_selection import ParameterGrid from ..datasets.dataset import DrugResponseDataset, FeatureDataset from ..pipeline_function import pipeline_function From f444b597a7d5a25558e11abd0bd517c29691d5c8 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 08:39:07 +0000 Subject: [PATCH 07/22] Update CI isort version --- .pre-commit-config.yaml | 2 +- drevalpy/models/drp_model.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fdb212c9..f4ed4412 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,7 +49,7 @@ repos: hooks: - id: prettier - repo: https://github.com/pycqa/isort - rev: 6.0.1 + rev: 7.0.0 hooks: - id: isort name: isort (python) diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 1df8b6b6..99b574e8 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -16,10 +16,6 @@ from sklearn.model_selection import ParameterGrid import wandb -import numpy as np -import wandb -import yaml -from sklearn.model_selection import ParameterGrid from ..datasets.dataset import DrugResponseDataset, FeatureDataset from ..pipeline_function import pipeline_function From 7c53a18c5c67f11fb2db7fb097c082f786a4d971 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 08:44:06 +0000 Subject: [PATCH 08/22] Attempt changing pre-commit isort config --- .pre-commit-config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4ed4412..8f723834 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,12 +53,15 @@ repos: hooks: - id: isort name: isort (python) + args: ["--settings-path=pyproject.toml"] - id: isort name: isort (cython) types: [cython] + args: ["--settings-path=pyproject.toml"] - id: isort name: isort (pyi) types: [pyi] + args: ["--settings-path=pyproject.toml"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: From feee8b3473ae833f835d932f77196e9a362fabfc Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 08:48:30 +0000 Subject: [PATCH 09/22] Try making isort diff visible in CI --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f723834..7788c995 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,7 @@ repos: hooks: - id: isort name: isort (python) - args: ["--settings-path=pyproject.toml"] + args: ["--settings-path=pyproject.toml", "--diff"] - id: isort name: isort (cython) types: [cython] From 2b247fc2c19842a9562ca77a6f202c94b2e06019 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 08:52:49 +0000 Subject: [PATCH 10/22] Revert pre-commit config changes --- .pre-commit-config.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7788c995..1c3d4be0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,19 +49,17 @@ repos: hooks: - id: prettier - repo: https://github.com/pycqa/isort - rev: 7.0.0 + rev: 6.0.1 hooks: - id: isort name: isort (python) - args: ["--settings-path=pyproject.toml", "--diff"] + args: ["--diff"] - id: isort name: isort (cython) types: [cython] - args: ["--settings-path=pyproject.toml"] - id: isort name: isort (pyi) types: [pyi] - args: ["--settings-path=pyproject.toml"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: From df9817341f418cf1d249ec976da57e9dce2c6ba2 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 08:54:50 +0000 Subject: [PATCH 11/22] Remove isort diff arg --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c3d4be0..fdb212c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,6 @@ repos: hooks: - id: isort name: isort (python) - args: ["--diff"] - id: isort name: isort (cython) types: [cython] From fa20fd1765ca9b468a86ade6db803d0519e48e64 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 09:00:18 +0000 Subject: [PATCH 12/22] Add back isort diff arg --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fdb212c9..1c3d4be0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,6 +53,7 @@ repos: hooks: - id: isort name: isort (python) + args: ["--diff"] - id: isort name: isort (cython) types: [cython] From 0fa25e1c6e0cf7a7495071e3897ddbf3f8e1c09a Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 09:03:33 +0000 Subject: [PATCH 13/22] Also add check arg --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c3d4be0..180ddd74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,7 @@ repos: hooks: - id: isort name: isort (python) - args: ["--diff"] + args: ["--diff", "--check"] - id: isort name: isort (cython) types: [cython] From 583fdfaa3b3af5757bf7c76e335b0787a3163a29 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 09:07:04 +0000 Subject: [PATCH 14/22] Attempt fixing isort issue --- .pre-commit-config.yaml | 1 - drevalpy/models/drp_model.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 180ddd74..fdb212c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,6 @@ repos: hooks: - id: isort name: isort (python) - args: ["--diff", "--check"] - id: isort name: isort (cython) types: [cython] diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 99b574e8..0bb2cf24 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -12,11 +12,10 @@ from typing import Any import numpy as np +import wandb import yaml from sklearn.model_selection import ParameterGrid -import wandb - from ..datasets.dataset import DrugResponseDataset, FeatureDataset from ..pipeline_function import pipeline_function From b02bb63d63bf4814cf50b0a2d8cd3204db7a380b Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 09:15:49 +0000 Subject: [PATCH 15/22] Fix mypy --- drevalpy/models/DrugGNN/drug_gnn.py | 8 ++++++++ .../SimpleNeuralNetwork/multiomics_neural_network.py | 1 + .../models/SimpleNeuralNetwork/simple_neural_network.py | 4 ++++ drevalpy/models/drp_model.py | 4 +++- 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index d8e2bc39..2d746446 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -258,6 +258,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: self.hyperparameters = hyperparameters def _loader_kwargs(self) -> dict[str, Any]: + assert ( + self.hyperparameters is not None + ), "hyperparameters must be set via build_model() before calling this method" num_workers = int(self.hyperparameters.get("num_workers", 4)) kw = { "num_workers": num_workers, @@ -289,6 +292,7 @@ def train( raise ValueError("Drug input is required for DrugGNN") # Determine feature sizes + assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" num_node_features = next(iter(drug_input.features.values()))["drug_graph"].num_node_features num_cell_features = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] @@ -337,6 +341,7 @@ def train( logger = WandbLogger(project=self.wandb_project, log_model=False) loggers.append(logger) + assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" trainer = pl.Trainer( max_epochs=self.hyperparameters.get("epochs", 100), accelerator="auto", @@ -383,6 +388,9 @@ def predict( cell_line_features=cell_line_input, drug_features=drug_input, ) + assert ( + self.hyperparameters is not None + ), "hyperparameters must be set via build_model() before calling predict()" predict_loader = DataLoader( predict_dataset, batch_size=self.hyperparameters.get("batch_size", 32), diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index dc6a7b5b..0bdf9e1f 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -117,6 +117,7 @@ def train( self.dim_cnv = dim_cnv self.dim_fp = dim_fingerprint + assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" self.model = FeedForwardNetwork( hyperparameters=self.hyperparameters, input_dim=dim_gex + dim_met + dim_mut + dim_cnv + dim_fingerprint, diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 7ab82bbf..69d010db 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -91,6 +91,7 @@ def train( gene_expression_scaler=self.gene_expression_scaler, ) + assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" dim_gex = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] dim_fingerprint = next(iter(drug_input.features.values()))[self.drug_views[0]].shape[0] self.hyperparameters["input_dim_gex"] = dim_gex @@ -116,6 +117,9 @@ def train( print("Probably, your training dataset is small.") + assert ( + self.hyperparameters is not None + ), "hyperparameters must be set via build_model() before calling train()" self.model.fit( output_train=output, cell_line_input=cell_line_input, diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 0bb2cf24..6c32dc3e 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -62,13 +62,15 @@ def init_wandb( self.wandb_project = project self.wandb_config = config or {} + if finish_previous: + wandb.finish() + run_name = name or self.get_model_name() wandb.init( project=project, config=self.wandb_config, name=run_name, tags=tags, - finish_previous=finish_previous, ) self.wandb_run = wandb.run From 30af5144e31a917f8014d7fa8be920baa8358f1d Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 15 Jan 2026 09:31:32 +0000 Subject: [PATCH 16/22] Make hyperparameters non-optional --- drevalpy/models/DrugGNN/drug_gnn.py | 8 -------- .../SimpleNeuralNetwork/multiomics_neural_network.py | 2 -- .../models/SimpleNeuralNetwork/simple_neural_network.py | 5 ----- drevalpy/models/drp_model.py | 2 +- 4 files changed, 1 insertion(+), 16 deletions(-) diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index 2d746446..d8e2bc39 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -258,9 +258,6 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: self.hyperparameters = hyperparameters def _loader_kwargs(self) -> dict[str, Any]: - assert ( - self.hyperparameters is not None - ), "hyperparameters must be set via build_model() before calling this method" num_workers = int(self.hyperparameters.get("num_workers", 4)) kw = { "num_workers": num_workers, @@ -292,7 +289,6 @@ def train( raise ValueError("Drug input is required for DrugGNN") # Determine feature sizes - assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" num_node_features = next(iter(drug_input.features.values()))["drug_graph"].num_node_features num_cell_features = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] @@ -341,7 +337,6 @@ def train( logger = WandbLogger(project=self.wandb_project, log_model=False) loggers.append(logger) - assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" trainer = pl.Trainer( max_epochs=self.hyperparameters.get("epochs", 100), accelerator="auto", @@ -388,9 +383,6 @@ def predict( cell_line_features=cell_line_input, drug_features=drug_input, ) - assert ( - self.hyperparameters is not None - ), "hyperparameters must be set via build_model() before calling predict()" predict_loader = DataLoader( predict_dataset, batch_size=self.hyperparameters.get("batch_size", 32), diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index 0bdf9e1f..39b778a1 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -37,7 +37,6 @@ def __init__(self): """ super().__init__() self.model = None - self.hyperparameters = None self.methylation_scaler = StandardScaler() self.methylation_pca = None self.pca_ncomp = 100 @@ -117,7 +116,6 @@ def train( self.dim_cnv = dim_cnv self.dim_fp = dim_fingerprint - assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" self.model = FeedForwardNetwork( hyperparameters=self.hyperparameters, input_dim=dim_gex + dim_met + dim_mut + dim_cnv + dim_fingerprint, diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 69d010db..077e69ff 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -33,7 +33,6 @@ def __init__(self): """ super().__init__() self.model = None - self.hyperparameters = None self.gene_expression_scaler = StandardScaler() @classmethod @@ -91,7 +90,6 @@ def train( gene_expression_scaler=self.gene_expression_scaler, ) - assert self.hyperparameters is not None, "hyperparameters must be set via build_model() before calling train()" dim_gex = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] dim_fingerprint = next(iter(drug_input.features.values()))[self.drug_views[0]].shape[0] self.hyperparameters["input_dim_gex"] = dim_gex @@ -117,9 +115,6 @@ def train( print("Probably, your training dataset is small.") - assert ( - self.hyperparameters is not None - ), "hyperparameters must be set via build_model() before calling train()" self.model.fit( output_train=output, cell_line_input=cell_line_input, diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 6c32dc3e..96da076c 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -39,7 +39,7 @@ def __init__(self): self.wandb_project: str | None = None self.wandb_run: Any = None self.wandb_config: dict[str, Any] | None = None - self.hyperparameters: dict[str, Any] | None = None + self.hyperparameters: dict[str, Any] = {} self._in_hyperparameter_tuning: bool = False # Flag to track if we're in hyperparameter tuning def init_wandb( From 0d362c3e414fb7f778d53940467193f897aa71e4 Mon Sep 17 00:00:00 2001 From: nictru Date: Mon, 19 Jan 2026 15:19:20 +0000 Subject: [PATCH 17/22] Clean wandb implementation --- drevalpy/experiment.py | 148 ++++++++++++++----- drevalpy/models/DrugGNN/drug_gnn.py | 13 +- drevalpy/models/MOLIR/utils.py | 17 ++- drevalpy/models/PharmaFormer/pharmaformer.py | 36 ++++- drevalpy/models/SimpleNeuralNetwork/utils.py | 11 +- drevalpy/models/SuperFELTR/utils.py | 14 +- drevalpy/models/drp_model.py | 100 ++++++++++++- drevalpy/models/lightning_metrics_mixin.py | 112 ++++++++++++++ drevalpy/utils.py | 9 ++ tests/test_hpam_tune.py | 3 + 10 files changed, 415 insertions(+), 48 deletions(-) create mode 100644 drevalpy/models/lightning_metrics_mixin.py diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 48e4a975..13b2e1f2 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -14,7 +14,7 @@ from sklearn.base import TransformerMixin from .datasets.dataset import DrugResponseDataset, FeatureDataset, split_early_stopping_data -from .evaluation import evaluate, get_mode +from .evaluation import get_mode from .models import MODEL_FACTORY, MULTI_DRUG_MODEL_FACTORY, SINGLE_DRUG_MODEL_FACTORY from .models.drp_model import DRPModel from .pipeline_function import pipeline_function @@ -140,6 +140,7 @@ def drug_response_experiment( ) response_data.save_splits(path=split_path) + # Build the list of models to run (done regardless of whether splits were newly created or loaded) model_list = make_model_list(models + baselines, response_data) for model_name in model_list.keys(): print(f"Running {model_name}") @@ -192,29 +193,16 @@ def drug_response_experiment( ) = get_datasets_from_cv_split(split, model_class, model_name, drug_id) model = model_class() - - # Initialize wandb if project is provided (before hyperparameter tuning) - if wandb_project is not None: - run_name = f"{model_name}" - if drug_id is not None: - run_name += f"_{drug_id}" - run_name += f"_split_{split_index}" - - config = { - "model_name": model_name, - "drug_id": drug_id, - "split_index": split_index, - "test_mode": test_mode, - "dataset": response_data.dataset_name, - "n_cv_splits": n_cv_splits, - "hyperparameter_tuning": hyperparameter_tuning, - } - model.init_wandb( - project=wandb_project, - config=config, - name=run_name, - tags=[model_name, test_mode, response_data.dataset_name or "unknown"], - ) + # Base wandb configuration for this split (used when training actually happens) + base_wandb_config = { + "model_name": model_name, + "drug_id": drug_id, + "split_index": split_index, + "test_mode": test_mode, + "dataset": response_data.dataset_name, + "n_cv_splits": n_cv_splits, + "hyperparameter_tuning": hyperparameter_tuning, + } if not os.path.isfile( prediction_file @@ -231,6 +219,12 @@ def drug_response_experiment( "model_checkpoint_dir": model_checkpoint_dir, } + # During hyperparameter tuning, create separate wandb runs per trial if enabled + if wandb_project is not None: + tuning_inputs["wandb_project"] = wandb_project + tuning_inputs["split_index"] = split_index + tuning_inputs["wandb_base_config"] = base_wandb_config + if multiprocessing: tuning_inputs["ray_path"] = os.path.abspath(os.path.join(result_path, "raytune")) best_hpams = hpam_tune_raytune(**tuning_inputs) @@ -253,6 +247,25 @@ def drug_response_experiment( train_dataset.add_rows(validation_dataset) # use full train val set data for final training train_dataset.shuffle(random_state=42) + # Initialize wandb for the final training on the full train+validation set + if wandb_project is not None: + final_run_name = f"{model_name}" + if drug_id is not None: + final_run_name += f"_{drug_id}" + final_run_name += f"_split_{split_index}_final" + + final_config = { + **base_wandb_config, + "phase": "final_training", + "best_hyperparameters": best_hpams, + } + model.init_wandb( + project=wandb_project, + config=final_config, + name=final_run_name, + tags=[model_name, test_mode, response_data.dataset_name or "unknown", "final"], + ) + test_dataset = train_and_predict( model=model, hpams=best_hpams, @@ -264,6 +277,20 @@ def drug_response_experiment( model_checkpoint_dir=model_checkpoint_dir, ) + # Log final metrics on test set for all models + # Metrics will be logged as test_RMSE, test_R^2, test_Pearson, etc. + if ( + model.is_wandb_enabled() + and len(test_dataset) > 0 + and test_dataset.predictions is not None + and len(test_dataset.predictions) > 0 + ): + model.compute_and_log_final_metrics( + test_dataset, + additional_metrics=[hpam_optimization_metric], + prefix="test_", + ) + for cross_study_dataset in cross_study_datasets: print(f"Cross study prediction on {cross_study_dataset.dataset_name}") cross_study_dataset.remove_nan_responses() @@ -1090,13 +1117,18 @@ def train_and_evaluate( response_transformation=response_transformation, model_checkpoint_dir=model_checkpoint_dir, ) - results = evaluate(validation_dataset, metric=[metric]) - # Log validation metrics to wandb if enabled - if model.is_wandb_enabled(): - # Prefix metrics with "val_" to distinguish from training metrics - wandb_metrics = {f"val_{k}": v for k, v in results.items()} - model.log_metrics(wandb_metrics) + # Compute final metrics using DRPModel helper (always includes R^2 and PCC) + # Add primary metric if it's not already included + additional_metrics = None + if metric not in ["R^2", "Pearson"]: + additional_metrics = [metric] + # Use "val_" prefix to clearly denote validation metrics (val_RMSE, val_R^2, val_Pearson) + results = model.compute_and_log_final_metrics( + validation_dataset, + additional_metrics=additional_metrics, + prefix="val_", + ) return results @@ -1111,6 +1143,10 @@ def hpam_tune( metric: str = "RMSE", path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", + *, + split_index: int | None = None, + wandb_project: str | None = None, + wandb_base_config: dict[str, Any] | None = None, ) -> dict: """ Tune the hyperparameters for the given model in an iterative manner. @@ -1124,6 +1160,9 @@ def hpam_tune( :param metric: metric to evaluate which model is the best :param path_data: path to the data directory, e.g., data/ :param model_checkpoint_dir: directory to save model checkpoints + :param split_index: optional CV split index, used for naming wandb runs + :param wandb_project: optional wandb project name; if provided, enables per-trial wandb runs + :param wandb_base_config: optional base config dict to include in each wandb run :returns: best hyperparameters :raises AssertionError: if hpam_set is empty """ @@ -1142,8 +1181,34 @@ def hpam_tune( for trial_idx, hyperparameter in enumerate(hpam_set): print(f"Training model with hyperparameters: {hyperparameter}") - # During hyperparameter tuning, don't update wandb config for each trial - # Instead, we'll log trial hyperparameters as metrics + # Create a separate wandb run for each hyperparameter trial if enabled + if wandb_project is not None: + trial_run_name = model.get_model_name() + if split_index is not None: + trial_run_name += f"_split_{split_index}" + trial_run_name += f"_trial_{trial_idx}" + + trial_config: dict[str, Any] = {} + if wandb_base_config is not None: + trial_config.update(wandb_base_config) + trial_config.update( + { + "phase": "hyperparameter_tuning", + "trial_index": trial_idx, + "hyperparameters": hyperparameter, + } + ) + + model.init_wandb( + project=wandb_project, + config=trial_config, + name=trial_run_name, + tags=[model.get_model_name(), "hpam_tuning"], + finish_previous=True, + ) + + # During hyperparameter tuning, don't update wandb config via log_hyperparameters + # Trial hyperparameters are stored in wandb.config for each run score = train_and_evaluate( model=model, hpams=hyperparameter, @@ -1157,22 +1222,29 @@ def hpam_tune( )[metric] if np.isnan(score): + # Finish the wandb run for this trial if it exists, even when score is NaN + if model.is_wandb_enabled(): + # Log NaN metric with validation prefix for clarity (e.g., val_RMSE) + model.log_metrics({f"val_{metric}": score}) + model.log_final_metrics({f"val_{metric}": score}) + model.finish_wandb() continue - # Log trial hyperparameters and result to wandb if enabled + # Log trial result to wandb if enabled if model.is_wandb_enabled(): - trial_metrics = {f"trial_{trial_idx}_{k}": v for k, v in hyperparameter.items()} - trial_metrics[f"trial_{trial_idx}_{metric}"] = score - model.log_metrics(trial_metrics) + # Log using validation-prefixed metric name (e.g., val_RMSE) + model.log_metrics({f"val_{metric}": score}) + model.log_final_metrics({f"val_{metric}": score}) + model.finish_wandb() if (mode == "min" and score < best_score) or (mode == "max" and score > best_score): print(f"current best {metric} score: {np.round(score, 3)}") best_score = score best_hyperparameters = hyperparameter - # Log best score so far to wandb if enabled + # Log best score so far to wandb if enabled, using a clear name if model.is_wandb_enabled(): - model.log_metrics({f"best_{metric}": best_score}) + model.log_metrics({f"best_val_{metric}": best_score}) if best_hyperparameters is None: warnings.warn("all hpams lead to NaN respone. using last hpam combination.", stacklevel=2) diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index d8e2bc39..89f0fb7d 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -15,6 +15,7 @@ from ...datasets.dataset import DrugResponseDataset, FeatureDataset from ..drp_model import DRPModel +from ..lightning_metrics_mixin import RegressionMetricsMixin from ..utils import load_and_select_gene_features @@ -86,7 +87,7 @@ def forward(self, drug_graph, cell_features): return out.view(-1) -class DrugGNNModule(pl.LightningModule): +class DrugGNNModule(RegressionMetricsMixin, pl.LightningModule): """The LightningModule for the DrugGNN model.""" def __init__( @@ -115,6 +116,9 @@ def __init__( ) self.criterion = nn.MSELoss() + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def forward(self, batch): """Forward pass of the module. @@ -135,6 +139,10 @@ def training_step(self, batch, batch_idx): outputs = self.model(drug_graph, cell_features) loss = self.criterion(outputs, responses) self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=True) + return loss def validation_step(self, batch, batch_idx): @@ -148,6 +156,9 @@ def validation_step(self, batch, batch_idx): loss = self.criterion(outputs, responses) self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=False) + def predict_step(self, batch, batch_idx, dataloader_idx=0): """A single prediction step. diff --git a/drevalpy/models/MOLIR/utils.py b/drevalpy/models/MOLIR/utils.py index d404e27f..1f4401f2 100644 --- a/drevalpy/models/MOLIR/utils.py +++ b/drevalpy/models/MOLIR/utils.py @@ -18,7 +18,9 @@ from torch.utils.data import DataLoader, Dataset from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset -from drevalpy.models.drp_model import DRPModel + +from ..drp_model import DRPModel +from ..lightning_metrics_mixin import RegressionMetricsMixin class RegressionDataset(Dataset): @@ -313,7 +315,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.regressor(x) -class MOLIModel(pl.LightningModule): +class MOLIModel(RegressionMetricsMixin, pl.LightningModule): """ PyTorch Lightning module for the MOLIR model. @@ -363,6 +365,9 @@ def __init__( self.cna_encoder = MOLIEncoder(input_dim_cnv, self.h_dim3, self.dropout_rate) self.regressor = MOLIRegressor(self.h_dim1 + self.h_dim2 + self.h_dim3, self.dropout_rate) + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def fit( self, output_train: DrugResponseDataset, @@ -534,6 +539,10 @@ def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tens # Compute loss loss = self._compute_loss(z, preds, response) self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(preds, response, is_training=True) + return loss def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: @@ -555,6 +564,10 @@ def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Te # Compute loss val_loss = self._compute_loss(z, preds, response) self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(preds, response, is_training=False) + return val_loss def configure_optimizers(self) -> torch.optim.Optimizer: diff --git a/drevalpy/models/PharmaFormer/pharmaformer.py b/drevalpy/models/PharmaFormer/pharmaformer.py index 39268989..d1561d2d 100644 --- a/drevalpy/models/PharmaFormer/pharmaformer.py +++ b/drevalpy/models/PharmaFormer/pharmaformer.py @@ -220,6 +220,8 @@ def train( self.model.train() epoch_loss = 0.0 batch_count = 0 + train_predictions = [] + train_targets = [] # Training phase for gene_inputs, smiles_inputs, targets in train_loader: @@ -239,17 +241,31 @@ def train( epoch_loss += loss.detach().item() batch_count += 1 + # Store predictions and targets for R^2 and PCC computation + train_predictions.append(outputs.squeeze().detach().cpu().numpy()) + train_targets.append(targets.detach().cpu().numpy()) + epoch_loss /= batch_count print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Training Loss: {epoch_loss:.4f}") - # Log training loss to wandb if enabled + # Compute and log training R^2 and PCC using DRPModel helper + train_metrics = {"train_loss": epoch_loss} + if len(train_predictions) > 0: + all_train_preds = np.concatenate(train_predictions) + all_train_targets = np.concatenate(train_targets) + perf_metrics = self.compute_performance_metrics(all_train_preds, all_train_targets, prefix="train_") + train_metrics.update(perf_metrics) + + # Log training metrics to wandb if enabled if self.is_wandb_enabled(): - self.log_metrics({"train_loss": epoch_loss}, step=epoch) + self.log_metrics(train_metrics, step=epoch) # Validation phase for early stopping self.model.eval() val_loss = 0.0 val_batch_count = 0 + val_predictions = [] + val_targets = [] with torch.no_grad(): for gene_inputs, smiles_inputs, targets in early_stopping_loader: gene_inputs = gene_inputs.to(self.DEVICE) @@ -262,12 +278,24 @@ def train( val_loss += loss.item() val_batch_count += 1 + # Store predictions and targets for R^2 and PCC computation + val_predictions.append(outputs.squeeze().detach().cpu().numpy()) + val_targets.append(targets.detach().cpu().numpy()) + val_loss /= val_batch_count print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Validation Loss: {val_loss:.4f}") - # Log validation loss to wandb if enabled + # Compute and log validation R^2 and PCC using DRPModel helper + val_metrics = {"val_loss": val_loss} + if len(val_predictions) > 0: + all_val_preds = np.concatenate(val_predictions) + all_val_targets = np.concatenate(val_targets) + perf_metrics = self.compute_performance_metrics(all_val_preds, all_val_targets, prefix="val_") + val_metrics.update(perf_metrics) + + # Log validation metrics to wandb if enabled if self.is_wandb_enabled(): - self.log_metrics({"val_loss": val_loss}, step=epoch) + self.log_metrics(val_metrics, step=epoch) # Checkpointing: Save the best model if val_loss < best_val_loss: diff --git a/drevalpy/models/SimpleNeuralNetwork/utils.py b/drevalpy/models/SimpleNeuralNetwork/utils.py index 92bb03e6..b388bcab 100644 --- a/drevalpy/models/SimpleNeuralNetwork/utils.py +++ b/drevalpy/models/SimpleNeuralNetwork/utils.py @@ -13,6 +13,8 @@ from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset +from ..lightning_metrics_mixin import RegressionMetricsMixin + class RegressionDataset(Dataset): """Dataset for regression tasks for the data loader.""" @@ -93,7 +95,7 @@ def __len__(self): return len(self.output.response) -class FeedForwardNetwork(pl.LightningModule): +class FeedForwardNetwork(RegressionMetricsMixin, pl.LightningModule): """Feed forward neural network for regression tasks with basic architecture.""" def __init__(self, hyperparameters: dict[str, int | float | list[int]], input_dim: int) -> None: @@ -138,6 +140,9 @@ def __init__(self, hyperparameters: dict[str, int | float | list[int]], input_di if self.dropout_prob is not None: self.dropout_layer = nn.Dropout(p=self.dropout_prob) + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def fit( self, output_train: DrugResponseDataset, @@ -301,6 +306,10 @@ def _forward_loss_and_log(self, x, y, log_as: str): y_pred = self.forward(x) result = self.loss(y_pred, y) self.log(log_as, result, on_step=True, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(y_pred, y, is_training=(log_as == "train_loss")) + return result def training_step(self, batch): diff --git a/drevalpy/models/SuperFELTR/utils.py b/drevalpy/models/SuperFELTR/utils.py index 5d5aee93..de507331 100644 --- a/drevalpy/models/SuperFELTR/utils.py +++ b/drevalpy/models/SuperFELTR/utils.py @@ -10,6 +10,7 @@ from torch import nn from ...datasets.dataset import DrugResponseDataset, FeatureDataset +from ..lightning_metrics_mixin import RegressionMetricsMixin from ..MOLIR.utils import create_dataset_and_loaders, generate_triplets_indices @@ -164,7 +165,7 @@ def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Te return triplet_loss -class SuperFELTRegressor(pl.LightningModule): +class SuperFELTRegressor(RegressionMetricsMixin, pl.LightningModule): """ SuperFELT regressor definition. @@ -204,6 +205,9 @@ def __init__( encoder.eval() self.regression_loss = nn.MSELoss() + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the SuperFELTRegressor. @@ -268,6 +272,10 @@ def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tens pred = self.regressor(encoded) loss = self.regression_loss(pred.squeeze(), response) self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(pred.squeeze(), response, is_training=True) + return loss def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: @@ -283,6 +291,10 @@ def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Te pred = self.regressor(encoded) loss = self.regression_loss(pred.squeeze(), response) self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(pred.squeeze(), response, is_training=False) + return loss diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 96da076c..bda22c30 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -9,14 +9,17 @@ import inspect import os from abc import ABC, abstractmethod +from contextlib import suppress from typing import Any import numpy as np -import wandb import yaml from sklearn.model_selection import ParameterGrid +import wandb + from ..datasets.dataset import DrugResponseDataset, FeatureDataset +from ..evaluation import AVAILABLE_METRICS, evaluate from ..pipeline_function import pipeline_function @@ -74,6 +77,16 @@ def init_wandb( ) self.wandb_run = wandb.run + # Define common metric summaries so final/best values are tracked automatically + with suppress(Exception): # pragma: no cover - wandb may not support define_metric in all contexts + wandb.define_metric("epoch", summary="max") + wandb.define_metric("train_loss", summary="min") + wandb.define_metric("val_loss", summary="min") + wandb.define_metric("train_R^2", summary="max") + wandb.define_metric("val_R^2", summary="max") + wandb.define_metric("train_Pearson", summary="max") + wandb.define_metric("val_Pearson", summary="max") + def log_hyperparameters(self, hyperparameters: dict[str, Any]) -> None: """ Log hyperparameters to wandb. @@ -136,6 +149,91 @@ def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> Non else: wandb.log(metrics) + def compute_performance_metrics( + self, predictions: np.ndarray, targets: np.ndarray, prefix: str = "" + ) -> dict[str, float]: + """ + Compute R^2 and PCC metrics from predictions and targets. + + This is a convenience method for computing performance metrics consistently + across all models. It always computes R^2 and PCC in addition to any other + metrics that may be needed. + + :param predictions: model predictions array + :param targets: ground truth targets array + :param prefix: optional prefix for metric keys (e.g., ``val_``, ``train_``) + :returns: dictionary of computed metrics with optional prefix + """ + try: + # Always compute R^2 and PCC + metrics = { + "R^2": AVAILABLE_METRICS["R^2"](y_pred=predictions, y_true=targets), + "Pearson": AVAILABLE_METRICS["Pearson"](y_pred=predictions, y_true=targets), + } + + # Add prefix if provided + if prefix: + metrics = {f"{prefix}{k}": v for k, v in metrics.items()} + + return metrics + except Exception: + # Return empty dict if computation fails + return {} + + def compute_and_log_final_metrics( + self, + dataset: DrugResponseDataset, + additional_metrics: list[str] | None = None, + prefix: str = "val_", + ) -> dict[str, float]: + r""" + Compute final performance metrics from a dataset and log them to wandb. + + This method computes R^2 and PCC (always), plus any additional metrics specified. + The metrics are both logged to wandb history and stored in the run summary. + + :param dataset: DrugResponseDataset with predictions and response + :param additional_metrics: optional list of additional metrics to compute (e.g., ["RMSE", "MAE"]) + :param prefix: metric name prefix indicating which split the metrics belong to + (for example, use ``"val"`` for validation and ``"test"`` for test metrics) + :returns: dictionary of computed metrics + """ + if dataset.predictions is None: + return {} + + # Always compute R^2 and PCC + metrics_to_compute = ["R^2", "Pearson"] + if additional_metrics: + metrics_to_compute.extend(additional_metrics) + + results = evaluate(dataset, metric=metrics_to_compute) + + # Log to wandb if enabled + if self.is_wandb_enabled(): + # Prefix indicates which split the metrics belong to (e.g. \"val\" or \"test\") + wandb_metrics = {f"{prefix}{k}": v for k, v in results.items()} + self.log_metrics(wandb_metrics) + self.log_final_metrics(wandb_metrics) + + return results + + def log_final_metrics(self, metrics: dict[str, float]) -> None: + """ + Store final metrics in the wandb run summary. + + This method is used to record final or best metrics (e.g., after validation + or after a hyperparameter trial) separate from the per-step logs. + + :param metrics: dictionary of metric names to values + """ + if not self.is_wandb_enabled(): + return + + for key, value in metrics.items(): + # Prefix with "final_" to distinguish from history metrics if not already prefixed + summary_key = key if key.startswith("final_") else f"final_{key}" + wandb.run.summary[summary_key] = value + def finish_wandb(self) -> None: """Finish the wandb run. Call this when training is complete.""" if not self.is_wandb_enabled(): diff --git a/drevalpy/models/lightning_metrics_mixin.py b/drevalpy/models/lightning_metrics_mixin.py new file mode 100644 index 00000000..6139de75 --- /dev/null +++ b/drevalpy/models/lightning_metrics_mixin.py @@ -0,0 +1,112 @@ +"""Mixin class for PyTorch Lightning modules to add R^2 and PCC metrics logging.""" + +import torch + +from ..evaluation import AVAILABLE_METRICS + + +class RegressionMetricsMixin: + """ + Mixin class for PyTorch Lightning modules to automatically compute and log R^2 and PCC metrics. + + This mixin provides: + - Storage for predictions and targets during training/validation steps + - Automatic computation of R^2 and PCC at epoch end + - Consistent logging to wandb via PyTorch Lightning's logging system + + Usage: + class MyModel(RegressionMetricsMixin, pl.LightningModule): + def __init__(self, ...): + super().__init__() + # Initialize your model... + self._init_metrics_storage() # Call this in __init__ + + def training_step(self, batch, batch_idx): + # ... your training logic ... + predictions = self.forward(...) + loss = self.criterion(predictions, targets) + self.log("train_loss", loss, ...) + self._store_predictions(predictions, targets, is_training=True) + return loss + + def validation_step(self, batch, batch_idx): + # ... your validation logic ... + predictions = self.forward(...) + loss = self.criterion(predictions, targets) + self.log("val_loss", loss, ...) + self._store_predictions(predictions, targets, is_training=False) + return loss + """ + + def _init_metrics_storage(self) -> None: + """Initialize storage for predictions and targets.""" + self.train_predictions: list[torch.Tensor] = [] + self.train_targets: list[torch.Tensor] = [] + self.val_predictions: list[torch.Tensor] = [] + self.val_targets: list[torch.Tensor] = [] + + def _store_predictions(self, predictions: torch.Tensor, targets: torch.Tensor, is_training: bool = True) -> None: + """ + Store predictions and targets for epoch-end metric computation. + + :param predictions: model predictions tensor + :param targets: ground truth targets tensor + :param is_training: whether this is from training (True) or validation (False) + """ + # Ensure tensors are detached and on CPU for numpy conversion + preds_cpu = predictions.detach().cpu() + targets_cpu = targets.detach().cpu() + + if is_training: + self.train_predictions.append(preds_cpu) + self.train_targets.append(targets_cpu) + else: + self.val_predictions.append(preds_cpu) + self.val_targets.append(targets_cpu) + + def _compute_epoch_metrics(self, predictions: list[torch.Tensor], targets: list[torch.Tensor]) -> dict[str, float]: + """ + Compute R^2 and PCC metrics from stored predictions and targets. + + :param predictions: list of prediction tensors from the epoch + :param targets: list of target tensors from the epoch + :returns: dictionary with "R^2" and "Pearson" keys, or empty dict if computation fails + """ + if len(predictions) == 0: + return {} + + try: + # Concatenate all predictions and targets from the epoch + all_preds = torch.cat(predictions).numpy() + all_targets = torch.cat(targets).numpy() + + # Compute metrics + r2 = AVAILABLE_METRICS["R^2"](y_pred=all_preds, y_true=all_targets) + pcc = AVAILABLE_METRICS["Pearson"](y_pred=all_preds, y_true=all_targets) + + return {"R^2": r2, "Pearson": pcc} + except Exception: + # If computation fails (e.g., NaN values, insufficient data), return empty dict + return {} + + def on_train_epoch_end(self) -> None: + """ + Epoch-end hook for training. + + Intentionally does NOT log R^2/Pearson per epoch anymore. We only keep + these buffers to allow optional debugging or future extensions. + """ + # Clear stored predictions/targets for next epoch + self.train_predictions.clear() + self.train_targets.clear() + + def on_validation_epoch_end(self) -> None: + """ + Epoch-end hook for validation. + + Intentionally does NOT log R^2/Pearson per epoch anymore. Final metrics + are logged once at the end via DRPModel.compute_and_log_final_metrics(). + """ + # Clear stored predictions/targets for next epoch + self.val_predictions.clear() + self.val_targets.clear() diff --git a/drevalpy/utils.py b/drevalpy/utils.py index ab89b49c..46e3d9d8 100644 --- a/drevalpy/utils.py +++ b/drevalpy/utils.py @@ -163,6 +163,14 @@ def get_parser() -> argparse.ArgumentParser: default="RMSE", help=f"Metric for hyperparameter tuning choose from {list(AVAILABLE_METRICS.keys())} " f"Default is RMSE.", ) + parser.add_argument( + "--wandb_project", + type=str, + default=None, + help=( + "Optional Weights & Biases project name. " "If provided, enables wandb logging for all DRPModel instances." + ), + ) parser.add_argument( "--n_cv_splits", type=int, @@ -348,6 +356,7 @@ def main(args) -> None: model_checkpoint_dir=args.model_checkpoint_dir, hyperparameter_tuning=not args.no_hyperparameter_tuning, final_model_on_full_data=args.final_model_on_full_data, + wandb_project=args.wandb_project, ) diff --git a/tests/test_hpam_tune.py b/tests/test_hpam_tune.py index 38fcd98b..ccbcd8e0 100644 --- a/tests/test_hpam_tune.py +++ b/tests/test_hpam_tune.py @@ -63,6 +63,9 @@ def test_hpam_tune(tmp_path): metric="RMSE", path_data="../data", model_checkpoint_dir="TEMPORARY", + split_index=None, + wandb_project=None, + wandb_base_config=None, ) assert best in hpam_set From 77f185d71d8ae5e25cbe9883d8651e838e492efd Mon Sep 17 00:00:00 2001 From: nictru Date: Mon, 19 Jan 2026 15:34:52 +0000 Subject: [PATCH 18/22] Prevent duplicate RMSE logging --- drevalpy/experiment.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 13b2e1f2..931525a3 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -1221,31 +1221,28 @@ def hpam_tune( model_checkpoint_dir=model_checkpoint_dir, )[metric] + # Note: train_and_evaluate() already logs val_* metrics once via + # DRPModel.compute_and_log_final_metrics(..., prefix="val_"). + # Avoid logging val_{metric} again here (it would create duplicate points). if np.isnan(score): - # Finish the wandb run for this trial if it exists, even when score is NaN if model.is_wandb_enabled(): - # Log NaN metric with validation prefix for clarity (e.g., val_RMSE) - model.log_metrics({f"val_{metric}": score}) - model.log_final_metrics({f"val_{metric}": score}) model.finish_wandb() continue - # Log trial result to wandb if enabled - if model.is_wandb_enabled(): - # Log using validation-prefixed metric name (e.g., val_RMSE) - model.log_metrics({f"val_{metric}": score}) - model.log_final_metrics({f"val_{metric}": score}) - model.finish_wandb() - if (mode == "min" and score < best_score) or (mode == "max" and score > best_score): print(f"current best {metric} score: {np.round(score, 3)}") best_score = score best_hyperparameters = hyperparameter - # Log best score so far to wandb if enabled, using a clear name + # Log best score so far to wandb if enabled, using a clear name. + # This is separate from val_{metric} and won't duplicate the val metric series. if model.is_wandb_enabled(): model.log_metrics({f"best_val_{metric}": best_score}) + # Close this trial's run after all logging is done + if model.is_wandb_enabled(): + model.finish_wandb() + if best_hyperparameters is None: warnings.warn("all hpams lead to NaN respone. using last hpam combination.", stacklevel=2) best_hyperparameters = hyperparameter From 146520ad1a565058a780624ffe4f36c2218a1cdc Mon Sep 17 00:00:00 2001 From: nictru Date: Mon, 19 Jan 2026 17:06:43 +0000 Subject: [PATCH 19/22] Make sure test losses are logged properly --- drevalpy/experiment.py | 6 ------ drevalpy/models/drp_model.py | 29 +++++++++++++++++++---------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 931525a3..4b4dd02e 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -257,7 +257,6 @@ def drug_response_experiment( final_config = { **base_wandb_config, "phase": "final_training", - "best_hyperparameters": best_hpams, } model.init_wandb( project=wandb_project, @@ -1234,11 +1233,6 @@ def hpam_tune( best_score = score best_hyperparameters = hyperparameter - # Log best score so far to wandb if enabled, using a clear name. - # This is separate from val_{metric} and won't duplicate the val metric series. - if model.is_wandb_enabled(): - model.log_metrics({f"best_val_{metric}": best_score}) - # Close this trial's run after all logging is done if model.is_wandb_enabled(): model.finish_wandb() diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index bda22c30..3b2a5c29 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -104,9 +104,10 @@ def log_hyperparameters(self, hyperparameters: dict[str, Any]) -> None: self.hyperparameters = hyperparameters # Only update wandb.config if we're not in hyperparameter tuning phase - # During tuning, trial hyperparameters are logged as metrics instead + # During tuning, trial hyperparameters are stored in config.hyperparameters + # Nest hyperparameters under a single key to prevent them from appearing as separate table columns if not self._in_hyperparameter_tuning: - wandb.config.update(hyperparameters) + wandb.config.update({"hyperparameters": hyperparameters}) def is_wandb_enabled(self) -> bool: """ @@ -114,7 +115,9 @@ def is_wandb_enabled(self) -> bool: :returns: True if wandb is initialized and active, False otherwise """ - return self.wandb_run is not None + # Check both self.wandb_run and wandb.run to handle cases where + # PyTorch Lightning's WandbLogger might have affected the run state + return self.wandb_project is not None and (self.wandb_run is not None or wandb.run is not None) def get_wandb_logger(self) -> Any | None: """ @@ -209,10 +212,11 @@ def compute_and_log_final_metrics( results = evaluate(dataset, metric=metrics_to_compute) # Log to wandb if enabled - if self.is_wandb_enabled(): + # Check both is_wandb_enabled() and wandb.run to ensure the run is active + if self.is_wandb_enabled() and wandb.run is not None: # Prefix indicates which split the metrics belong to (e.g. \"val\" or \"test\") wandb_metrics = {f"{prefix}{k}": v for k, v in results.items()} - self.log_metrics(wandb_metrics) + # Log to summary only (not history) since these are final metrics logged once self.log_final_metrics(wandb_metrics) return results @@ -221,18 +225,23 @@ def log_final_metrics(self, metrics: dict[str, float]) -> None: """ Store final metrics in the wandb run summary. - This method is used to record final or best metrics (e.g., after validation - or after a hyperparameter trial) separate from the per-step logs. + This method is used to record final metrics (e.g., after validation + or after a hyperparameter trial). Metrics are stored with their original + names (e.g., val_RMSE, test_RMSE) without additional prefixes. :param metrics: dictionary of metric names to values """ if not self.is_wandb_enabled(): return + # Ensure wandb.run is active before logging + if wandb.run is None: + return + for key, value in metrics.items(): - # Prefix with "final_" to distinguish from history metrics if not already prefixed - summary_key = key if key.startswith("final_") else f"final_{key}" - wandb.run.summary[summary_key] = value + # Store metrics directly without adding "final_" prefix + # The prefix (val_ or test_) already indicates the split + wandb.run.summary[key] = value def finish_wandb(self) -> None: """Finish the wandb run. Call this when training is complete.""" From 94c2c24523b3b41ed2daa71cac6579be34df6492 Mon Sep 17 00:00:00 2001 From: nictru Date: Tue, 20 Jan 2026 08:11:46 +0000 Subject: [PATCH 20/22] Fix non-hyperparam optimization wandb behavior --- drevalpy/experiment.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 4b4dd02e..a82c3e72 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -13,6 +13,11 @@ import torch from sklearn.base import TransformerMixin +try: + import wandb +except ImportError: + wandb = None # type: ignore[assignment] + from .datasets.dataset import DrugResponseDataset, FeatureDataset, split_early_stopping_data from .evaluation import get_mode from .models import MODEL_FACTORY, MULTI_DRUG_MODEL_FACTORY, SINGLE_DRUG_MODEL_FACTORY @@ -248,6 +253,7 @@ def drug_response_experiment( train_dataset.shuffle(random_state=42) # Initialize wandb for the final training on the full train+validation set + # This happens regardless of whether hyperparameter tuning was performed if wandb_project is not None: final_run_name = f"{model_name}" if drug_id is not None: @@ -278,17 +284,21 @@ def drug_response_experiment( # Log final metrics on test set for all models # Metrics will be logged as test_RMSE, test_R^2, test_Pearson, etc. + # This happens regardless of whether hyperparameter tuning was performed if ( - model.is_wandb_enabled() + wandb_project is not None + and wandb is not None and len(test_dataset) > 0 and test_dataset.predictions is not None and len(test_dataset.predictions) > 0 ): - model.compute_and_log_final_metrics( - test_dataset, - additional_metrics=[hpam_optimization_metric], - prefix="test_", - ) + # Ensure wandb run is active before logging metrics + if wandb.run is not None: + model.compute_and_log_final_metrics( + test_dataset, + additional_metrics=[hpam_optimization_metric], + prefix="test_", + ) for cross_study_dataset in cross_study_datasets: print(f"Cross study prediction on {cross_study_dataset.dataset_name}") From 2bed7fdeba744468f25fbb997ea6b7e11beb1db5 Mon Sep 17 00:00:00 2001 From: nictru Date: Tue, 20 Jan 2026 08:13:55 +0000 Subject: [PATCH 21/22] Fix drp_model for CI --- drevalpy/models/drp_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 3b2a5c29..7599be2e 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -13,11 +13,10 @@ from typing import Any import numpy as np +import wandb import yaml from sklearn.model_selection import ParameterGrid -import wandb - from ..datasets.dataset import DrugResponseDataset, FeatureDataset from ..evaluation import AVAILABLE_METRICS, evaluate from ..pipeline_function import pipeline_function From c7100c711edc16ded125ffcea9d7eab63a8cfd81 Mon Sep 17 00:00:00 2001 From: nictru Date: Tue, 20 Jan 2026 08:20:45 +0000 Subject: [PATCH 22/22] FixFix wandb_project test error --- tests/test_main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_main.py b/tests/test_main.py index a86d7ed1..cadc9312 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -41,6 +41,7 @@ "model_checkpoint_dir": "TEMPORARY", "no_hyperparameter_tuning": True, "final_model_on_full_data": True, + "wandb_project": None, } ], )