diff --git a/.gitignore b/.gitignore index 0de5ab96..5e873b1f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,21 +1,12 @@ # Data -data/cell_line_input -data/response_output -data/mapping -data/GDSC1 -data/GDSC2 -data/CCLE -data/TOYv1 -data/TOYv2 -data/CTRPv1 -data/CTRPv2 -data/meta -data/BeatAML2 -data/PDX_Bruna +data/ # Results directory is created when running the demo notebook results/ +# Wandb directory is created when running the benchmark with wandb +wandb/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index b2b15969..635c1ea3 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -13,8 +13,13 @@ import torch from sklearn.base import TransformerMixin +try: + import wandb +except ImportError: + wandb = None # type: ignore[assignment] + from .datasets.dataset import DrugResponseDataset, FeatureDataset, split_early_stopping_data -from .evaluation import evaluate, get_mode +from .evaluation import get_mode from .models import MODEL_FACTORY, MULTI_DRUG_MODEL_FACTORY, SINGLE_DRUG_MODEL_FACTORY from .models.drp_model import DRPModel from .pipeline_function import pipeline_function @@ -45,6 +50,7 @@ def drug_response_experiment( model_checkpoint_dir: str = "TEMPORARY", hyperparameter_tuning=True, final_model_on_full_data: bool = False, + wandb_project: str | None = None, ) -> None: """ Run the drug response prediction experiment. Save results to disc. @@ -97,6 +103,8 @@ def drug_response_experiment( :param final_model_on_full_data: if True, a final/production model is saved in the results directory. If hyperparameter_tuning is true, the final model is produced according to the hyperparameter tuning procedure which was evaluated in the nested cross validation. + :param wandb_project: if provided, enables wandb logging for all DRPModel instances throughout training. + All hyperparameters and metrics will be logged to the specified wandb project. :raises ValueError: if no cv splits are found """ # Default baseline model, needed for normalization @@ -137,6 +145,7 @@ def drug_response_experiment( ) response_data.save_splits(path=split_path) + # Build the list of models to run (done regardless of whether splits were newly created or loaded) model_list = make_model_list(models + baselines, response_data) for model_name in model_list.keys(): print(f"Running {model_name}") @@ -189,6 +198,16 @@ def drug_response_experiment( ) = get_datasets_from_cv_split(split, model_class, model_name, drug_id) model = model_class() + # Base wandb configuration for this split (used when training actually happens) + base_wandb_config = { + "model_name": model_name, + "drug_id": drug_id, + "split_index": split_index, + "test_mode": test_mode, + "dataset": response_data.dataset_name, + "n_cv_splits": n_cv_splits, + "hyperparameter_tuning": hyperparameter_tuning, + } if not os.path.isfile( prediction_file @@ -205,6 +224,12 @@ def drug_response_experiment( "model_checkpoint_dir": model_checkpoint_dir, } + # During hyperparameter tuning, create separate wandb runs per trial if enabled + if wandb_project is not None: + tuning_inputs["wandb_project"] = wandb_project + tuning_inputs["split_index"] = split_index + tuning_inputs["wandb_base_config"] = base_wandb_config + if multiprocessing: tuning_inputs["ray_path"] = os.path.abspath(os.path.join(result_path, "raytune")) best_hpams = hpam_tune_raytune(**tuning_inputs) @@ -213,6 +238,9 @@ def drug_response_experiment( print(f"Best hyperparameters: {best_hpams}") print("Training model on full train and validation set to predict test set") + + # Log best hyperparameters to wandb (they will be logged when build_model is called) + # The best hyperparameters will be logged via build_model -> log_hyperparameters # save best hyperparameters as json with open( hpam_save_path, @@ -224,6 +252,25 @@ def drug_response_experiment( train_dataset.add_rows(validation_dataset) # use full train val set data for final training train_dataset.shuffle(random_state=42) + # Initialize wandb for the final training on the full train+validation set + # This happens regardless of whether hyperparameter tuning was performed + if wandb_project is not None: + final_run_name = f"{model_name}" + if drug_id is not None: + final_run_name += f"_{drug_id}" + final_run_name += f"_split_{split_index}_final" + + final_config = { + **base_wandb_config, + "phase": "final_training", + } + model.init_wandb( + project=wandb_project, + config=final_config, + name=final_run_name, + tags=[model_name, test_mode, response_data.dataset_name or "unknown", "final"], + ) + test_dataset = train_and_predict( model=model, hpams=best_hpams, @@ -235,6 +282,24 @@ def drug_response_experiment( model_checkpoint_dir=model_checkpoint_dir, ) + # Log final metrics on test set for all models + # Metrics will be logged as test_RMSE, test_R^2, test_Pearson, etc. + # This happens regardless of whether hyperparameter tuning was performed + if ( + wandb_project is not None + and wandb is not None + and len(test_dataset) > 0 + and test_dataset.predictions is not None + and len(test_dataset.predictions) > 0 + ): + # Ensure wandb run is active before logging metrics + if wandb.run is not None: + model.compute_and_log_final_metrics( + test_dataset, + additional_metrics=[hpam_optimization_metric], + prefix="test_", + ) + for cross_study_dataset in cross_study_datasets: print(f"Cross study prediction on {cross_study_dataset.dataset_name}") cross_study_dataset.remove_nan_responses() @@ -259,6 +324,10 @@ def drug_response_experiment( encoding="utf-8", ) as f: best_hpams = json.load(f) + + # Finish wandb run for this split + if wandb_project is not None: + model.finish_wandb() if not is_baseline: if randomization_mode is not None: print(f"Randomization tests for {model_class.get_model_name()}") @@ -1063,7 +1132,20 @@ def train_and_evaluate( response_transformation=response_transformation, model_checkpoint_dir=model_checkpoint_dir, ) - return evaluate(validation_dataset, metric=[metric]) + + # Compute final metrics using DRPModel helper (always includes R^2 and PCC) + # Add primary metric if it's not already included + additional_metrics = None + if metric not in ["R^2", "Pearson"]: + additional_metrics = [metric] + # Use "val_" prefix to clearly denote validation metrics (val_RMSE, val_R^2, val_Pearson) + results = model.compute_and_log_final_metrics( + validation_dataset, + additional_metrics=additional_metrics, + prefix="val_", + ) + + return results def hpam_tune( @@ -1076,6 +1158,10 @@ def hpam_tune( metric: str = "RMSE", path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", + *, + split_index: int | None = None, + wandb_project: str | None = None, + wandb_base_config: dict[str, Any] | None = None, ) -> dict: """ Tune the hyperparameters for the given model in an iterative manner. @@ -1089,6 +1175,9 @@ def hpam_tune( :param metric: metric to evaluate which model is the best :param path_data: path to the data directory, e.g., data/ :param model_checkpoint_dir: directory to save model checkpoints + :param split_index: optional CV split index, used for naming wandb runs + :param wandb_project: optional wandb project name; if provided, enables per-trial wandb runs + :param wandb_base_config: optional base config dict to include in each wandb run :returns: best hyperparameters :raises AssertionError: if hpam_set is empty """ @@ -1097,11 +1186,44 @@ def hpam_tune( if len(hpam_set) == 1: return hpam_set[0] + # Mark that we're in hyperparameter tuning phase + # This prevents updating wandb.config during tuning - we'll only log final best hyperparameters + model._in_hyperparameter_tuning = True + best_hyperparameters = None mode = get_mode(metric) best_score = float("inf") if mode == "min" else float("-inf") - for hyperparameter in hpam_set: + for trial_idx, hyperparameter in enumerate(hpam_set): print(f"Training model with hyperparameters: {hyperparameter}") + + # Create a separate wandb run for each hyperparameter trial if enabled + if wandb_project is not None: + trial_run_name = model.get_model_name() + if split_index is not None: + trial_run_name += f"_split_{split_index}" + trial_run_name += f"_trial_{trial_idx}" + + trial_config: dict[str, Any] = {} + if wandb_base_config is not None: + trial_config.update(wandb_base_config) + trial_config.update( + { + "phase": "hyperparameter_tuning", + "trial_index": trial_idx, + "hyperparameters": hyperparameter, + } + ) + + model.init_wandb( + project=wandb_project, + config=trial_config, + name=trial_run_name, + tags=[model.get_model_name(), "hpam_tuning"], + finish_previous=True, + ) + + # During hyperparameter tuning, don't update wandb config via log_hyperparameters + # Trial hyperparameters are stored in wandb.config for each run score = train_and_evaluate( model=model, hpams=hyperparameter, @@ -1114,7 +1236,12 @@ def hpam_tune( model_checkpoint_dir=model_checkpoint_dir, )[metric] + # Note: train_and_evaluate() already logs val_* metrics once via + # DRPModel.compute_and_log_final_metrics(..., prefix="val_"). + # Avoid logging val_{metric} again here (it would create duplicate points). if np.isnan(score): + if model.is_wandb_enabled(): + model.finish_wandb() continue if (mode == "min" and score < best_score) or (mode == "max" and score > best_score): @@ -1122,6 +1249,10 @@ def hpam_tune( best_score = score best_hyperparameters = hyperparameter + # Close this trial's run after all logging is done + if model.is_wandb_enabled(): + model.finish_wandb() + if best_hyperparameters is None: warnings.warn("all hpams lead to NaN respone. using last hpam combination.", stacklevel=2) best_hyperparameters = hyperparameter diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index ca103ab4..89f0fb7d 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -15,6 +15,7 @@ from ...datasets.dataset import DrugResponseDataset, FeatureDataset from ..drp_model import DRPModel +from ..lightning_metrics_mixin import RegressionMetricsMixin from ..utils import load_and_select_gene_features @@ -86,7 +87,7 @@ def forward(self, drug_graph, cell_features): return out.view(-1) -class DrugGNNModule(pl.LightningModule): +class DrugGNNModule(RegressionMetricsMixin, pl.LightningModule): """The LightningModule for the DrugGNN model.""" def __init__( @@ -115,6 +116,9 @@ def __init__( ) self.criterion = nn.MSELoss() + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def forward(self, batch): """Forward pass of the module. @@ -135,6 +139,10 @@ def training_step(self, batch, batch_idx): outputs = self.model(drug_graph, cell_features) loss = self.criterion(outputs, responses) self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=True) + return loss def validation_step(self, batch, batch_idx): @@ -148,6 +156,9 @@ def validation_step(self, batch, batch_idx): loss = self.criterion(outputs, responses) self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=False) + def predict_step(self, batch, batch_idx, dataloader_idx=0): """A single prediction step. @@ -252,6 +263,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: The hyperparameters. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters def _loader_kwargs(self) -> dict[str, Any]: @@ -326,11 +340,20 @@ def train( **self._loader_kwargs(), ) + # Set up wandb logger if project is provided + loggers = [] + if self.wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(project=self.wandb_project, log_model=False) + loggers.append(logger) + trainer = pl.Trainer( max_epochs=self.hyperparameters.get("epochs", 100), accelerator="auto", devices="auto", callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, + logger=loggers if loggers else True, # Use default logger if no wandb enable_progress_bar=True, log_every_n_steps=int(self.hyperparameters.get("log_every_n_steps", 50)), precision=self.hyperparameters.get("precision", 32), diff --git a/drevalpy/models/MOLIR/molir.py b/drevalpy/models/MOLIR/molir.py index 844bcb23..62e76d87 100644 --- a/drevalpy/models/MOLIR/molir.py +++ b/drevalpy/models/MOLIR/molir.py @@ -64,6 +64,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: Custom hyperparameters for the model, includes mini_batch, layer dimensions (h_dim1, h_dim2, h_dim3), learning_rate, dropout_rate, weight_decay, gamma, epochs, and margin. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters self.selector = VarianceFeatureSelector( view="gene_expression", k=hyperparameters.get("n_gene_expression_features", 1000) @@ -125,6 +128,7 @@ def train( cell_line_input=cell_line_input, output_earlystopping=output_earlystopping, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=self.wandb_project, ) else: print(f"Not enough training data provided ({len(output)}), will predict on randomly initialized model.") diff --git a/drevalpy/models/MOLIR/utils.py b/drevalpy/models/MOLIR/utils.py index f83eb5e2..1f4401f2 100644 --- a/drevalpy/models/MOLIR/utils.py +++ b/drevalpy/models/MOLIR/utils.py @@ -18,7 +18,9 @@ from torch.utils.data import DataLoader, Dataset from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset -from drevalpy.models.drp_model import DRPModel + +from ..drp_model import DRPModel +from ..lightning_metrics_mixin import RegressionMetricsMixin class RegressionDataset(Dataset): @@ -313,7 +315,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.regressor(x) -class MOLIModel(pl.LightningModule): +class MOLIModel(RegressionMetricsMixin, pl.LightningModule): """ PyTorch Lightning module for the MOLIR model. @@ -363,6 +365,9 @@ def __init__( self.cna_encoder = MOLIEncoder(input_dim_cnv, self.h_dim3, self.dropout_rate) self.regressor = MOLIRegressor(self.h_dim1 + self.h_dim2 + self.h_dim3, self.dropout_rate) + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def fit( self, output_train: DrugResponseDataset, @@ -370,6 +375,7 @@ def fit( output_earlystopping: DrugResponseDataset | None = None, patience: int = 5, model_checkpoint_dir: str = "checkpoints", + wandb_project: str | None = None, ) -> None: """ Trains the MOLIR model. @@ -377,11 +383,14 @@ def fit( First, the ranges for the triplet loss are determined using the standard deviation of the training responses. Then, the training and validation data loaders are created. The model is trained using the Lightning Trainer with an early stopping callback and patience of 5. + :param output_train: training dataset containing the response output :param cell_line_input: feature dataset containing the omics data of the cell lines :param output_earlystopping: early stopping dataset :param patience: for early stopping :param model_checkpoint_dir: directory to save the model checkpoints + :param wandb_project: optional wandb project name for logging. If provided, uses WandbLogger + for PyTorch Lightning training. """ self.positive_range, self.negative_range = make_ranges(output_train) @@ -407,9 +416,18 @@ def fit( save_weights_only=True, ) + # Set up wandb logger if project is provided + loggers = [] + if wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + # Initialize the Lightning trainer trainer = pl.Trainer( max_epochs=self.epochs, + logger=loggers if loggers else True, # Use default logger if no wandb callbacks=[ early_stop_callback, self.checkpoint_callback, @@ -521,6 +539,10 @@ def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tens # Compute loss loss = self._compute_loss(z, preds, response) self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(preds, response, is_training=True) + return loss def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: @@ -542,6 +564,10 @@ def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Te # Compute loss val_loss = self._compute_loss(z, preds, response) self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(preds, response, is_training=False) + return val_loss def configure_optimizers(self) -> torch.optim.Optimizer: diff --git a/drevalpy/models/PharmaFormer/pharmaformer.py b/drevalpy/models/PharmaFormer/pharmaformer.py index b6eaf647..d1561d2d 100644 --- a/drevalpy/models/PharmaFormer/pharmaformer.py +++ b/drevalpy/models/PharmaFormer/pharmaformer.py @@ -112,6 +112,9 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: Model hyperparameters including gene_hidden_size, drug_hidden_size, feature_dim, nhead, num_layers, dim_feedforward, dropout, batch_size, lr, epochs, patience """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters # Model will be built in train() when we know the input dimensions @@ -217,6 +220,8 @@ def train( self.model.train() epoch_loss = 0.0 batch_count = 0 + train_predictions = [] + train_targets = [] # Training phase for gene_inputs, smiles_inputs, targets in train_loader: @@ -236,13 +241,31 @@ def train( epoch_loss += loss.detach().item() batch_count += 1 + # Store predictions and targets for R^2 and PCC computation + train_predictions.append(outputs.squeeze().detach().cpu().numpy()) + train_targets.append(targets.detach().cpu().numpy()) + epoch_loss /= batch_count print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Training Loss: {epoch_loss:.4f}") + # Compute and log training R^2 and PCC using DRPModel helper + train_metrics = {"train_loss": epoch_loss} + if len(train_predictions) > 0: + all_train_preds = np.concatenate(train_predictions) + all_train_targets = np.concatenate(train_targets) + perf_metrics = self.compute_performance_metrics(all_train_preds, all_train_targets, prefix="train_") + train_metrics.update(perf_metrics) + + # Log training metrics to wandb if enabled + if self.is_wandb_enabled(): + self.log_metrics(train_metrics, step=epoch) + # Validation phase for early stopping self.model.eval() val_loss = 0.0 val_batch_count = 0 + val_predictions = [] + val_targets = [] with torch.no_grad(): for gene_inputs, smiles_inputs, targets in early_stopping_loader: gene_inputs = gene_inputs.to(self.DEVICE) @@ -255,9 +278,25 @@ def train( val_loss += loss.item() val_batch_count += 1 + # Store predictions and targets for R^2 and PCC computation + val_predictions.append(outputs.squeeze().detach().cpu().numpy()) + val_targets.append(targets.detach().cpu().numpy()) + val_loss /= val_batch_count print(f"PharmaFormer: Epoch [{epoch + 1}/{self.hyperparameters['epochs']}] Validation Loss: {val_loss:.4f}") + # Compute and log validation R^2 and PCC using DRPModel helper + val_metrics = {"val_loss": val_loss} + if len(val_predictions) > 0: + all_val_preds = np.concatenate(val_predictions) + all_val_targets = np.concatenate(val_targets) + perf_metrics = self.compute_performance_metrics(all_val_preds, all_val_targets, prefix="val_") + val_metrics.update(perf_metrics) + + # Log validation metrics to wandb if enabled + if self.is_wandb_enabled(): + self.log_metrics(val_metrics, step=epoch) + # Checkpointing: Save the best model if val_loss < best_val_loss: best_val_loss = val_loss diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index ae58b172..39b778a1 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -37,7 +37,6 @@ def __init__(self): """ super().__init__() self.model = None - self.hyperparameters = None self.methylation_scaler = StandardScaler() self.methylation_pca = None self.pca_ncomp = 100 @@ -62,6 +61,9 @@ def build_model(self, hyperparameters: dict): :param hyperparameters: dictionary containing the hyperparameters units_per_layer, dropout_prob, and methylation_pca_components. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters self.pca_ncomp = hyperparameters["methylation_pca_components"] @@ -139,6 +141,7 @@ def train( patience=5, num_workers=1, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=self.wandb_project, ) def predict( diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 43f2eb49..077e69ff 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -33,7 +33,6 @@ def __init__(self): """ super().__init__() self.model = None - self.hyperparameters = None self.gene_expression_scaler = StandardScaler() @classmethod @@ -51,6 +50,9 @@ def build_model(self, hyperparameters: dict): :param hyperparameters: includes units_per_layer and dropout_prob. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters self.hyperparameters.setdefault("input_dim_gex", None) self.hyperparameters.setdefault("input_dim_fp", None) @@ -128,6 +130,7 @@ def train( patience=5, num_workers=1 if platform.system() == "Windows" else 8, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=self.wandb_project, ) def predict( diff --git a/drevalpy/models/SimpleNeuralNetwork/utils.py b/drevalpy/models/SimpleNeuralNetwork/utils.py index 26b2790b..b388bcab 100644 --- a/drevalpy/models/SimpleNeuralNetwork/utils.py +++ b/drevalpy/models/SimpleNeuralNetwork/utils.py @@ -7,11 +7,14 @@ import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar +from pytorch_lightning.loggers import WandbLogger from torch import nn from torch.utils.data import DataLoader, Dataset from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset +from ..lightning_metrics_mixin import RegressionMetricsMixin + class RegressionDataset(Dataset): """Dataset for regression tasks for the data loader.""" @@ -92,7 +95,7 @@ def __len__(self): return len(self.output.response) -class FeedForwardNetwork(pl.LightningModule): +class FeedForwardNetwork(RegressionMetricsMixin, pl.LightningModule): """Feed forward neural network for regression tasks with basic architecture.""" def __init__(self, hyperparameters: dict[str, int | float | list[int]], input_dim: int) -> None: @@ -137,6 +140,9 @@ def __init__(self, hyperparameters: dict[str, int | float | list[int]], input_di if self.dropout_prob is not None: self.dropout_layer = nn.Dropout(p=self.dropout_prob) + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def fit( self, output_train: DrugResponseDataset, @@ -150,11 +156,13 @@ def fit( patience=5, num_workers: int = 2, model_checkpoint_dir: str = "checkpoints", + wandb_project: str | None = None, ) -> None: """ Fits the model. First, the data is loaded using a DataLoader. Then, the model is trained using the Lightning Trainer. + :param output_train: Response values for training :param cell_line_input: Cell line features :param drug_input: Drug features @@ -166,6 +174,8 @@ def fit( :param patience: patience for early stopping, default is 5 :param num_workers: number of workers for the DataLoader, default is 2 :param model_checkpoint_dir: directory to save the model checkpoints + :param wandb_project: optional wandb project name for logging. If provided, uses WandbLogger + for PyTorch Lightning training. :raises ValueError: if drug_input is missing """ if trainer_params is None: @@ -233,6 +243,14 @@ def fit( trainer_params_copy = trainer_params.copy() del trainer_params_copy["progress_bar_refresh_rate"] + # Set up wandb logger if project is provided + # Note: This method receives wandb_project as parameter, but the model instance + # should have wandb already initialized via DRPModel.init_wandb() + loggers = [] + if wandb_project is not None: + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + # Initialize the Lightning trainer trainer = pl.Trainer( callbacks=[ @@ -240,6 +258,7 @@ def fit( self.checkpoint_callback, progress_bar, ], + logger=loggers if loggers else True, # Use default logger if no wandb default_root_dir=model_checkpoint_dir, devices=1, **trainer_params_copy, @@ -287,6 +306,10 @@ def _forward_loss_and_log(self, x, y, log_as: str): y_pred = self.forward(x) result = self.loss(y_pred, y) self.log(log_as, result, on_step=True, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(y_pred, y, is_training=(log_as == "train_loss")) + return result def training_step(self, batch): diff --git a/drevalpy/models/SuperFELTR/superfeltr.py b/drevalpy/models/SuperFELTR/superfeltr.py index 43268f4f..59da1e1d 100644 --- a/drevalpy/models/SuperFELTR/superfeltr.py +++ b/drevalpy/models/SuperFELTR/superfeltr.py @@ -78,6 +78,9 @@ def build_model(self, hyperparameters) -> None: dropout_rate, weight_decay, out_dim_expr_encoder, out_dim_mutation_encoder, out_dim_cnv_encoder, epochs, variance thresholds for gene expression, mutation, and copy number variation, margin, and learning rate. """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters n_features = hyperparameters.get("n_features_per_view", 1000) @@ -133,6 +136,7 @@ def train( output_earlystopping=output_earlystopping, patience=5, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=self.wandb_project, ) encoders[omic_type] = SuperFELTEncoder.load_from_checkpoint(best_checkpoint.best_model_path) else: @@ -165,6 +169,7 @@ def train( output_earlystopping=output_earlystopping, patience=5, model_checkpoint_dir=model_checkpoint_dir, + wandb_project=self.wandb_project, ) else: print("Not enough training data provided for SuperFELTR Regressor. Using random initialization.") diff --git a/drevalpy/models/SuperFELTR/utils.py b/drevalpy/models/SuperFELTR/utils.py index 4816302d..de507331 100644 --- a/drevalpy/models/SuperFELTR/utils.py +++ b/drevalpy/models/SuperFELTR/utils.py @@ -10,6 +10,7 @@ from torch import nn from ...datasets.dataset import DrugResponseDataset, FeatureDataset +from ..lightning_metrics_mixin import RegressionMetricsMixin from ..MOLIR.utils import create_dataset_and_loaders, generate_triplets_indices @@ -164,7 +165,7 @@ def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Te return triplet_loss -class SuperFELTRegressor(pl.LightningModule): +class SuperFELTRegressor(RegressionMetricsMixin, pl.LightningModule): """ SuperFELT regressor definition. @@ -204,6 +205,9 @@ def __init__( encoder.eval() self.regression_loss = nn.MSELoss() + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the SuperFELTRegressor. @@ -268,6 +272,10 @@ def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tens pred = self.regressor(encoded) loss = self.regression_loss(pred.squeeze(), response) self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(pred.squeeze(), response, is_training=True) + return loss def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: @@ -283,6 +291,10 @@ def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Te pred = self.regressor(encoded) loss = self.regression_loss(pred.squeeze(), response) self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(pred.squeeze(), response, is_training=False) + return loss @@ -294,11 +306,13 @@ def train_superfeltr_model( output_earlystopping: DrugResponseDataset | None = None, patience: int = 5, model_checkpoint_dir: str = "superfeltr_checkpoints", + wandb_project: str | None = None, ) -> pl.callbacks.ModelCheckpoint: """ Trains one encoder or the regressor. First, the dataset and loaders are created. Then, the model is trained with the Lightning trainer. + :param model: either one of the encoders or the regressor :param hpams: hyperparameters for the model :param output_train: response data for training @@ -306,6 +320,8 @@ def train_superfeltr_model( :param output_earlystopping: response data for early stopping :param patience: for early stopping, defaults to 5 :param model_checkpoint_dir: directory to save the model checkpoints + :param wandb_project: optional wandb project name for logging. If provided, uses WandbLogger + for PyTorch Lightning training. :returns: checkpoint callback with the best model :raises ValueError: if the epochs and mini_batch are not integers """ @@ -329,9 +345,18 @@ def train_superfeltr_model( mode="min", save_top_k=1, ) + # Set up wandb logger if project is provided + loggers = [] + if wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(project=wandb_project, log_model=False) + loggers.append(logger) + # Initialize the Lightning trainer trainer = pl.Trainer( max_epochs=hpams["epochs"], + logger=loggers if loggers else True, # Use default logger if no wandb callbacks=[ early_stop_callback, checkpoint_callback, diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 8c60d39e..7599be2e 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -9,13 +9,16 @@ import inspect import os from abc import ABC, abstractmethod +from contextlib import suppress from typing import Any import numpy as np +import wandb import yaml from sklearn.model_selection import ParameterGrid from ..datasets.dataset import DrugResponseDataset, FeatureDataset +from ..evaluation import AVAILABLE_METRICS, evaluate from ..pipeline_function import pipeline_function @@ -33,6 +36,220 @@ class DRPModel(ABC): # Then, the model is trained per drug is_single_drug_model = False + def __init__(self): + """Initialize the DRPModel instance.""" + self.wandb_project: str | None = None + self.wandb_run: Any = None + self.wandb_config: dict[str, Any] | None = None + self.hyperparameters: dict[str, Any] = {} + self._in_hyperparameter_tuning: bool = False # Flag to track if we're in hyperparameter tuning + + def init_wandb( + self, + project: str, + config: dict[str, Any] | None = None, + name: str | None = None, + tags: list[str] | None = None, + finish_previous: bool = True, + ) -> None: + """ + Initialize wandb logging for this model instance. + + :param project: wandb project name + :param config: dictionary of configuration to log (e.g., hyperparameters, dataset info) + :param name: run name (defaults to model name) + :param tags: list of tags for the run + :param finish_previous: whether to finish any existing wandb run before starting a new one + """ + self.wandb_project = project + self.wandb_config = config or {} + + if finish_previous: + wandb.finish() + + run_name = name or self.get_model_name() + wandb.init( + project=project, + config=self.wandb_config, + name=run_name, + tags=tags, + ) + self.wandb_run = wandb.run + + # Define common metric summaries so final/best values are tracked automatically + with suppress(Exception): # pragma: no cover - wandb may not support define_metric in all contexts + wandb.define_metric("epoch", summary="max") + wandb.define_metric("train_loss", summary="min") + wandb.define_metric("val_loss", summary="min") + wandb.define_metric("train_R^2", summary="max") + wandb.define_metric("val_R^2", summary="max") + wandb.define_metric("train_Pearson", summary="max") + wandb.define_metric("val_Pearson", summary="max") + + def log_hyperparameters(self, hyperparameters: dict[str, Any]) -> None: + """ + Log hyperparameters to wandb. + + This method is called automatically by build_model when wandb is enabled. + Subclasses can override this to add additional hyperparameter logging. + + During hyperparameter tuning, config updates are skipped to avoid overwriting. + Only the final best hyperparameters are logged to wandb.config. + + :param hyperparameters: dictionary of hyperparameters to log + """ + if not self.is_wandb_enabled(): + return + + self.hyperparameters = hyperparameters + # Only update wandb.config if we're not in hyperparameter tuning phase + # During tuning, trial hyperparameters are stored in config.hyperparameters + # Nest hyperparameters under a single key to prevent them from appearing as separate table columns + if not self._in_hyperparameter_tuning: + wandb.config.update({"hyperparameters": hyperparameters}) + + def is_wandb_enabled(self) -> bool: + """ + Check if wandb logging is enabled for this model instance. + + :returns: True if wandb is initialized and active, False otherwise + """ + # Check both self.wandb_run and wandb.run to handle cases where + # PyTorch Lightning's WandbLogger might have affected the run state + return self.wandb_project is not None and (self.wandb_run is not None or wandb.run is not None) + + def get_wandb_logger(self) -> Any | None: + """ + Get a WandbLogger for PyTorch Lightning integration. + + This method creates a WandbLogger that uses the existing wandb run. + Returns None if wandb is not enabled. + + :returns: WandbLogger instance or None + """ + if not self.is_wandb_enabled() or self.wandb_project is None: + return None + + from pytorch_lightning.loggers import WandbLogger + + return WandbLogger(project=self.wandb_project, log_model=False) + + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: + """ + Log metrics to wandb. + + Subclasses can call this method to log custom metrics during training. + + :param metrics: dictionary of metric names to values + :param step: optional step number for the metrics + """ + if not self.is_wandb_enabled(): + return + + if step is not None: + wandb.log(metrics, step=step) + else: + wandb.log(metrics) + + def compute_performance_metrics( + self, predictions: np.ndarray, targets: np.ndarray, prefix: str = "" + ) -> dict[str, float]: + """ + Compute R^2 and PCC metrics from predictions and targets. + + This is a convenience method for computing performance metrics consistently + across all models. It always computes R^2 and PCC in addition to any other + metrics that may be needed. + + :param predictions: model predictions array + :param targets: ground truth targets array + :param prefix: optional prefix for metric keys (e.g., ``val_``, ``train_``) + :returns: dictionary of computed metrics with optional prefix + """ + try: + # Always compute R^2 and PCC + metrics = { + "R^2": AVAILABLE_METRICS["R^2"](y_pred=predictions, y_true=targets), + "Pearson": AVAILABLE_METRICS["Pearson"](y_pred=predictions, y_true=targets), + } + + # Add prefix if provided + if prefix: + metrics = {f"{prefix}{k}": v for k, v in metrics.items()} + + return metrics + except Exception: + # Return empty dict if computation fails + return {} + + def compute_and_log_final_metrics( + self, + dataset: DrugResponseDataset, + additional_metrics: list[str] | None = None, + prefix: str = "val_", + ) -> dict[str, float]: + r""" + Compute final performance metrics from a dataset and log them to wandb. + + This method computes R^2 and PCC (always), plus any additional metrics specified. + The metrics are both logged to wandb history and stored in the run summary. + + :param dataset: DrugResponseDataset with predictions and response + :param additional_metrics: optional list of additional metrics to compute (e.g., ["RMSE", "MAE"]) + :param prefix: metric name prefix indicating which split the metrics belong to + (for example, use ``"val"`` for validation and ``"test"`` for test metrics) + :returns: dictionary of computed metrics + """ + if dataset.predictions is None: + return {} + + # Always compute R^2 and PCC + metrics_to_compute = ["R^2", "Pearson"] + if additional_metrics: + metrics_to_compute.extend(additional_metrics) + + results = evaluate(dataset, metric=metrics_to_compute) + + # Log to wandb if enabled + # Check both is_wandb_enabled() and wandb.run to ensure the run is active + if self.is_wandb_enabled() and wandb.run is not None: + # Prefix indicates which split the metrics belong to (e.g. \"val\" or \"test\") + wandb_metrics = {f"{prefix}{k}": v for k, v in results.items()} + # Log to summary only (not history) since these are final metrics logged once + self.log_final_metrics(wandb_metrics) + + return results + + def log_final_metrics(self, metrics: dict[str, float]) -> None: + """ + Store final metrics in the wandb run summary. + + This method is used to record final metrics (e.g., after validation + or after a hyperparameter trial). Metrics are stored with their original + names (e.g., val_RMSE, test_RMSE) without additional prefixes. + + :param metrics: dictionary of metric names to values + """ + if not self.is_wandb_enabled(): + return + + # Ensure wandb.run is active before logging + if wandb.run is None: + return + + for key, value in metrics.items(): + # Store metrics directly without adding "final_" prefix + # The prefix (val_ or test_) already indicates the split + wandb.run.summary[key] = value + + def finish_wandb(self) -> None: + """Finish the wandb run. Call this when training is complete.""" + if not self.is_wandb_enabled(): + return + + wandb.finish() + self.wandb_run = None + @classmethod @abstractmethod @pipeline_function @@ -97,11 +314,16 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: """ Builds the model, for models that use hyperparameters. + Subclasses should call self.log_hyperparameters(hyperparameters) at the beginning + of this method to ensure hyperparameters are logged to wandb if enabled. + :param hyperparameters: hyperparameters for the model Example:: - self.model = ElasticNet(alpha=hyperparameters["alpha"], l1_ratio=hyperparameters["l1_ratio"]) + def build_model(self, hyperparameters: dict[str, Any]) -> None: + self.log_hyperparameters(hyperparameters) # Log to wandb + self.model = ElasticNet(alpha=hyperparameters["alpha"], l1_ratio=hyperparameters["l1_ratio"]) """ @pipeline_function diff --git a/drevalpy/models/lightning_metrics_mixin.py b/drevalpy/models/lightning_metrics_mixin.py new file mode 100644 index 00000000..6139de75 --- /dev/null +++ b/drevalpy/models/lightning_metrics_mixin.py @@ -0,0 +1,112 @@ +"""Mixin class for PyTorch Lightning modules to add R^2 and PCC metrics logging.""" + +import torch + +from ..evaluation import AVAILABLE_METRICS + + +class RegressionMetricsMixin: + """ + Mixin class for PyTorch Lightning modules to automatically compute and log R^2 and PCC metrics. + + This mixin provides: + - Storage for predictions and targets during training/validation steps + - Automatic computation of R^2 and PCC at epoch end + - Consistent logging to wandb via PyTorch Lightning's logging system + + Usage: + class MyModel(RegressionMetricsMixin, pl.LightningModule): + def __init__(self, ...): + super().__init__() + # Initialize your model... + self._init_metrics_storage() # Call this in __init__ + + def training_step(self, batch, batch_idx): + # ... your training logic ... + predictions = self.forward(...) + loss = self.criterion(predictions, targets) + self.log("train_loss", loss, ...) + self._store_predictions(predictions, targets, is_training=True) + return loss + + def validation_step(self, batch, batch_idx): + # ... your validation logic ... + predictions = self.forward(...) + loss = self.criterion(predictions, targets) + self.log("val_loss", loss, ...) + self._store_predictions(predictions, targets, is_training=False) + return loss + """ + + def _init_metrics_storage(self) -> None: + """Initialize storage for predictions and targets.""" + self.train_predictions: list[torch.Tensor] = [] + self.train_targets: list[torch.Tensor] = [] + self.val_predictions: list[torch.Tensor] = [] + self.val_targets: list[torch.Tensor] = [] + + def _store_predictions(self, predictions: torch.Tensor, targets: torch.Tensor, is_training: bool = True) -> None: + """ + Store predictions and targets for epoch-end metric computation. + + :param predictions: model predictions tensor + :param targets: ground truth targets tensor + :param is_training: whether this is from training (True) or validation (False) + """ + # Ensure tensors are detached and on CPU for numpy conversion + preds_cpu = predictions.detach().cpu() + targets_cpu = targets.detach().cpu() + + if is_training: + self.train_predictions.append(preds_cpu) + self.train_targets.append(targets_cpu) + else: + self.val_predictions.append(preds_cpu) + self.val_targets.append(targets_cpu) + + def _compute_epoch_metrics(self, predictions: list[torch.Tensor], targets: list[torch.Tensor]) -> dict[str, float]: + """ + Compute R^2 and PCC metrics from stored predictions and targets. + + :param predictions: list of prediction tensors from the epoch + :param targets: list of target tensors from the epoch + :returns: dictionary with "R^2" and "Pearson" keys, or empty dict if computation fails + """ + if len(predictions) == 0: + return {} + + try: + # Concatenate all predictions and targets from the epoch + all_preds = torch.cat(predictions).numpy() + all_targets = torch.cat(targets).numpy() + + # Compute metrics + r2 = AVAILABLE_METRICS["R^2"](y_pred=all_preds, y_true=all_targets) + pcc = AVAILABLE_METRICS["Pearson"](y_pred=all_preds, y_true=all_targets) + + return {"R^2": r2, "Pearson": pcc} + except Exception: + # If computation fails (e.g., NaN values, insufficient data), return empty dict + return {} + + def on_train_epoch_end(self) -> None: + """ + Epoch-end hook for training. + + Intentionally does NOT log R^2/Pearson per epoch anymore. We only keep + these buffers to allow optional debugging or future extensions. + """ + # Clear stored predictions/targets for next epoch + self.train_predictions.clear() + self.train_targets.clear() + + def on_validation_epoch_end(self) -> None: + """ + Epoch-end hook for validation. + + Intentionally does NOT log R^2/Pearson per epoch anymore. Final metrics + are logged once at the end via DRPModel.compute_and_log_final_metrics(). + """ + # Clear stored predictions/targets for next epoch + self.val_predictions.clear() + self.val_targets.clear() diff --git a/drevalpy/utils.py b/drevalpy/utils.py index ab89b49c..46e3d9d8 100644 --- a/drevalpy/utils.py +++ b/drevalpy/utils.py @@ -163,6 +163,14 @@ def get_parser() -> argparse.ArgumentParser: default="RMSE", help=f"Metric for hyperparameter tuning choose from {list(AVAILABLE_METRICS.keys())} " f"Default is RMSE.", ) + parser.add_argument( + "--wandb_project", + type=str, + default=None, + help=( + "Optional Weights & Biases project name. " "If provided, enables wandb logging for all DRPModel instances." + ), + ) parser.add_argument( "--n_cv_splits", type=int, @@ -348,6 +356,7 @@ def main(args) -> None: model_checkpoint_dir=args.model_checkpoint_dir, hyperparameter_tuning=not args.no_hyperparameter_tuning, final_model_on_full_data=args.final_model_on_full_data, + wandb_project=args.wandb_project, ) diff --git a/poetry.lock b/poetry.lock index e2292254..6d0fed1c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -189,7 +189,6 @@ description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -688,7 +687,6 @@ files = [ {file = "click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6"}, {file = "click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a"}, ] -markers = {main = "extra == \"multiprocessing\""} [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} @@ -1452,6 +1450,40 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "backports-zstd ; python_version < \"3.14\"", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr"] tqdm = ["tqdm"] +[[package]] +name = "gitdb" +version = "4.0.12" +description = "Git Object Database" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf"}, + {file = "gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571"}, +] + +[package.dependencies] +smmap = ">=3.0.1,<6" + +[[package]] +name = "gitpython" +version = "3.1.46" +description = "GitPython is a Python library used to interact with Git repositories" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058"}, + {file = "gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f"}, +] + +[package.dependencies] +gitdb = ">=4.0.1,<5" + +[package.extras] +doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"] +test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy (==1.18.2) ; python_version >= \"3.9\"", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""] + [[package]] name = "h11" version = "0.16.0" @@ -3471,7 +3503,6 @@ description = "" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "protobuf-6.33.4-cp310-abi3-win32.whl", hash = "sha256:918966612c8232fc6c24c78e1cd89784307f5814ad7506c308ee3cf86662850d"}, {file = "protobuf-6.33.4-cp310-abi3-win_amd64.whl", hash = "sha256:8f11ffae31ec67fc2554c2ef891dcb561dae9a2a3ed941f9e134c2db06657dbc"}, @@ -3613,7 +3644,6 @@ description = "Data validation using Python type hints" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"}, {file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"}, @@ -3636,7 +3666,6 @@ description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "pydantic_core-2.41.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:77b63866ca88d804225eaa4af3e664c5faf3568cea95360d21f4725ab6e07146"}, {file = "pydantic_core-2.41.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dfa8a0c812ac681395907e71e1274819dec685fec28273a28905df579ef137e2"}, @@ -4660,6 +4689,69 @@ files = [ cryptography = ">=2.0" jeepney = ">=0.6" +[[package]] +name = "sentry-sdk" +version = "2.49.0" +description = "Python client for Sentry (https://sentry.io)" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "sentry_sdk-2.49.0-py2.py3-none-any.whl", hash = "sha256:6ea78499133874445a20fe9c826c9e960070abeb7ae0cdf930314ab16bb97aa0"}, + {file = "sentry_sdk-2.49.0.tar.gz", hash = "sha256:c1878599cde410d481c04ef50ee3aedd4f600e4d0d253f4763041e468b332c30"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = ">=1.26.11" + +[package.extras] +aiohttp = ["aiohttp (>=3.5)"] +anthropic = ["anthropic (>=0.16)"] +arq = ["arq (>=0.23)"] +asyncpg = ["asyncpg (>=0.23)"] +beam = ["apache-beam (>=2.12)"] +bottle = ["bottle (>=0.12.13)"] +celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] +chalice = ["chalice (>=1.16.0)"] +clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] +django = ["django (>=1.8)"] +falcon = ["falcon (>=1.4)"] +fastapi = ["fastapi (>=0.79.0)"] +flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"] +google-genai = ["google-genai (>=1.29.0)"] +grpcio = ["grpcio (>=1.21.1)", "protobuf (>=3.8.0)"] +http2 = ["httpcore[http2] (==1.*)"] +httpx = ["httpx (>=0.16.0)"] +huey = ["huey (>=2)"] +huggingface-hub = ["huggingface_hub (>=0.22)"] +langchain = ["langchain (>=0.0.210)"] +langgraph = ["langgraph (>=0.6.6)"] +launchdarkly = ["launchdarkly-server-sdk (>=9.8.0)"] +litellm = ["litellm (>=1.77.5)"] +litestar = ["litestar (>=2.0.0)"] +loguru = ["loguru (>=0.5)"] +mcp = ["mcp (>=1.15.0)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] +openfeature = ["openfeature-sdk (>=0.7.1)"] +opentelemetry = ["opentelemetry-distro (>=0.35b0)"] +opentelemetry-experimental = ["opentelemetry-distro"] +opentelemetry-otlp = ["opentelemetry-distro[otlp] (>=0.35b0)"] +pure-eval = ["asttokens", "executing", "pure_eval"] +pydantic-ai = ["pydantic-ai (>=1.0.0)"] +pymongo = ["pymongo (>=3.1)"] +pyspark = ["pyspark (>=2.4.4)"] +quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] +rq = ["rq (>=0.6)"] +sanic = ["sanic (>=0.8)"] +sqlalchemy = ["sqlalchemy (>=1.2)"] +starlette = ["starlette (>=0.19.1)"] +starlite = ["starlite (>=1.48)"] +statsig = ["statsig (>=0.55.3)"] +tornado = ["tornado (>=6)"] +unleash = ["UnleashClient (>=6.0.1)"] + [[package]] name = "setuptools" version = "80.9.0" @@ -4706,6 +4798,18 @@ files = [ {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, ] +[[package]] +name = "smmap" +version = "5.0.2" +description = "A pure Python implementation of a sliding window memory map manager" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e"}, + {file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"}, +] + [[package]] name = "snowballstemmer" version = "3.0.1" @@ -5379,7 +5483,6 @@ description = "Runtime typing introspection tools" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"multiprocessing\"" files = [ {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"}, {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"}, @@ -5458,6 +5561,51 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] +[[package]] +name = "wandb" +version = "0.24.0" +description = "A CLI and library for interacting with the Weights & Biases API." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "wandb-0.24.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:aa9777398ff4b0f04c41359f7d1b95b5d656cb12c37c63903666799212e50299"}, + {file = "wandb-0.24.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:0423fbd58c3926949724feae8aab89d20c68846f9f4f596b80f9ffe1fc298130"}, + {file = "wandb-0.24.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:2b25fc0c123daac97ed32912ac55642c65013cc6e3a898e88ca2d917fc8eadc0"}, + {file = "wandb-0.24.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:9485344b4667944b5b77294185bae8469cfa4074869bec0e74f54f8492234cc2"}, + {file = "wandb-0.24.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:51b2b9a9d7d6b35640f12a46a48814fd4516807ad44f586b819ed6560f8de1fd"}, + {file = "wandb-0.24.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:11f7e7841f31eff82c82a677988889ad3aa684c6de61ff82145333b5214ec860"}, + {file = "wandb-0.24.0-py3-none-win32.whl", hash = "sha256:42af348998b00d4309ae790c5374040ac6cc353ab21567f4e29c98c9376dee8e"}, + {file = "wandb-0.24.0-py3-none-win_amd64.whl", hash = "sha256:32604eddcd362e1ed4a2e2ce5f3a239369c4a193af223f3e66603481ac91f336"}, + {file = "wandb-0.24.0-py3-none-win_arm64.whl", hash = "sha256:e0f2367552abfca21b0f3a03405fbf48f1e14de9846e70f73c6af5da57afd8ef"}, + {file = "wandb-0.24.0.tar.gz", hash = "sha256:4715a243b3d460b6434b9562e935dfd9dfdf5d6e428cfb4c3e7ce4fd44460ab3"}, +] + +[package.dependencies] +click = ">=8.0.1" +gitpython = ">=1.0.0,<3.1.29 || >3.1.29" +packaging = "*" +platformdirs = "*" +protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5.28.0 || >5.28.0,<7", markers = "python_version > \"3.9\" or sys_platform != \"linux\""} +pydantic = "<3" +pyyaml = "*" +requests = ">=2.0.0,<3" +sentry-sdk = ">=2.0.0" +typing-extensions = ">=4.8,<5" + +[package.extras] +aws = ["boto3", "botocore (>=1.5.76)"] +azure = ["azure-identity", "azure-storage-blob"] +gcp = ["google-cloud-storage"] +importers = ["filelock", "mlflow", "polars (<=1.2.1)", "rich", "tenacity"] +kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] +launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore (>=1.5.76)", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "jsonschema", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "tornado (>=6.5.0) ; python_version >= \"3.9\"", "typing-extensions"] +media = ["bokeh", "imageio (>=2.28.1)", "moviepy (>=1.0.0)", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit", "soundfile"] +models = ["cloudpickle"] +perf = ["orjson"] +sweeps = ["sweeps (>=0.2.0)"] +workspaces = ["wandb-workspaces"] + [[package]] name = "watchfiles" version = "1.1.1" @@ -6164,4 +6312,4 @@ multiprocessing = ["pydantic", "ray"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "cb7bf71d59b895c0cfbe611d73c357cf78478175d923b41ae5a3565aaf47437e" +content-hash = "b48d33e2c6e66c3fa8aa5b42db423e8577a44e19090b766f65051f4b9587dde4" diff --git a/pyproject.toml b/pyproject.toml index 02886951..17a62ab4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ toml = {version = "^0.10.2"} poetry = "^2.0.1" starlette = ">=0.49.1" pydantic = { version = ">=2.5", optional = true } +wandb = "^0.24.0" [tool.poetry.requires-plugins] poetry-plugin-export = ">=1.8" diff --git a/tests/test_hpam_tune.py b/tests/test_hpam_tune.py index 38fcd98b..ccbcd8e0 100644 --- a/tests/test_hpam_tune.py +++ b/tests/test_hpam_tune.py @@ -63,6 +63,9 @@ def test_hpam_tune(tmp_path): metric="RMSE", path_data="../data", model_checkpoint_dir="TEMPORARY", + split_index=None, + wandb_project=None, + wandb_base_config=None, ) assert best in hpam_set diff --git a/tests/test_main.py b/tests/test_main.py index a86d7ed1..cadc9312 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -41,6 +41,7 @@ "model_checkpoint_dir": "TEMPORARY", "no_hyperparameter_tuning": True, "final_model_on_full_data": True, + "wandb_project": None, } ], )