diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml
new file mode 100644
index 0000000..89ad595
--- /dev/null
+++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml
@@ -0,0 +1,62 @@
+bound:
+  delta: 0.025
+  delta_test: 0.01
+dist_init:
+  seed: 110
+factory:
+  bounds:
+  - kl
+  - mcallester
+  data_loader:
+    name: cifar10
+    params:
+      dataset_path: ./data/cifar10
+  losses:
+  - nll_loss
+  - scaled_nll_loss
+  - 01_loss
+  metrics:
+  - accuracy_micro_metric
+  - accuracy_macro_metric
+  - f1_micro_metric
+  - f1_macro_metric
+  model:
+    name: conv15
+    params:
+      dataset: cifar10
+      in_channels: 3
+  posterior_objective:
+    name: bbb
+    params:
+      kl_penalty: 1.0
+  prior_objective:
+    name: iwae
+    params:
+      n: 10
+      kl_penalty: 0.01
+log_wandb: true
+mcsamples: 3000
+pmin: 5.0e-05
+posterior:
+  training:
+    epochs: 1
+    lr: 0.0001
+    momentum: 0.9
+    seed: 1135
+prior:
+  training:
+    epochs: 100
+    lr: 0.0001
+    momentum: 0.95
+    seed: 1135
+sigma: 0.005
+split_config:
+  batch_size: 250
+  dataset_loader_seed: 112
+  seed: 111
+split_strategy:
+  prior_percent: 0.7
+  prior_type: learnt
+  self_certified: true
+  train_percent: 1.0
+  val_percent: 0.0
\ No newline at end of file
diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml
new file mode 100644
index 0000000..5fc4dc6
--- /dev/null
+++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml
@@ -0,0 +1,62 @@
+bound:
+  delta: 0.025
+  delta_test: 0.01
+dist_init:
+  seed: 110
+factory:
+  bounds:
+  - kl
+  - mcallester
+  data_loader:
+    name: cifar10
+    params:
+      dataset_path: ./data/cifar10
+  losses:
+  - nll_loss
+  - scaled_nll_loss
+  - 01_loss
+  metrics:
+  - accuracy_micro_metric
+  - accuracy_macro_metric
+  - f1_micro_metric
+  - f1_macro_metric
+  model:
+    name: conv15
+    params:
+      dataset: cifar10
+      in_channels: 3
+  posterior_objective:
+    name: bbb
+    params:
+      kl_penalty: 1.0
+  prior_objective:
+    name: iwae
+    params:
+      n: 20
+      kl_penalty: 0.01
+log_wandb: true
+mcsamples: 3000
+pmin: 5.0e-05
+posterior:
+  training:
+    epochs: 1
+    lr: 0.0001
+    momentum: 0.9
+    seed: 1135
+prior:
+  training:
+    epochs: 100
+    lr: 0.0001
+    momentum: 0.95
+    seed: 1135
+sigma: 0.005
+split_config:
+  batch_size: 250
+  dataset_loader_seed: 112
+  seed: 111
+split_strategy:
+  prior_percent: 0.7
+  prior_type: learnt
+  self_certified: true
+  train_percent: 1.0
+  val_percent: 0.0
\ No newline at end of file
diff --git a/core/layer/AbstractProbLayer.py b/core/layer/AbstractProbLayer.py
index fbc9bd0..407c6f2 100644
--- a/core/layer/AbstractProbLayer.py
+++ b/core/layer/AbstractProbLayer.py
@@ -19,6 +19,8 @@ class AbstractProbLayer(nn.Module, ABC):
     _bias_dist: AbstractVariable
     _prior_weight_dist: AbstractVariable
     _prior_bias_dist: AbstractVariable
+    _sampled_weight: Tensor
+    _sampled_bias: Tensor
 
     def probabilistic(self, mode: bool = True):
         """
@@ -64,4 +66,6 @@ def sample_from_distribution(self) -> tuple[Tensor, Tensor]:
             sampled_bias = self._bias_dist.mu if self._bias_dist else None
         else:
             raise ValueError("Only training with probabilistic mode is allowed")
-        return sampled_weight, sampled_bias
+        self._sampled_weight = sampled_weight
+        self._sampled_bias = sampled_bias
+        return sampled_weight, sampled_bias
\ No newline at end of file
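Note: the cached `_sampled_weight` / `_sampled_bias` tensors are what allow an objective to score the exact weights used in the preceding forward pass. A minimal sketch of the pattern, with a hypothetical `ToyProbLayer` standing in for any layer that follows the contract above:

```python
import torch
import torch.distributions as dists

# Hypothetical stand-in for a layer following the AbstractProbLayer contract:
# the forward pass caches the weights it actually sampled.
class ToyProbLayer:
    def __init__(self, shape):
        self.mu = torch.zeros(shape)
        self.sigma = torch.ones(shape)
        self._sampled_weight = None

    def forward(self, x):
        # reparameterised sample, cached for later log-density queries
        self._sampled_weight = self.mu + self.sigma * torch.randn_like(self.mu)
        return x @ self._sampled_weight

layer = ToyProbLayer((3, 2))
out = layer.forward(torch.randn(4, 3))
# log q(w) for the very weights used above -- impossible without the cache
log_q = dists.Normal(layer.mu, layer.sigma).log_prob(layer._sampled_weight).sum()
```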
diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py
new file mode 100644
index 0000000..3f478d4
--- /dev/null
+++ b/core/objective/IWAEObjective.py
@@ -0,0 +1,120 @@
+import logging
+import math
+from typing import Dict, Optional
+
+import torch
+import torch.distributions as dists
+import wandb
+from torch import nn, Tensor
+
+from core.model import bounded_call
+from core.layer.utils import get_torch_layers
+from core.objective import AbstractObjective
+
+
+class IWAEObjective(AbstractObjective):
+    def __init__(self, kl_penalty: float, n: int, temperature: float = 1.0) -> None:
+        self.kl_penalty = kl_penalty  # usually 1 / |D|; not used in calculate below
+        self.temperature = temperature
+        self.k = n
+        logging.info(f"IWAEObjective: k={self.k}, temperature={self.temperature}")
+
+    # -------- helpers to compute log p(w) and log q(w) -------------------
+    @staticmethod
+    def _log_prior(model: nn.Module, eps: float = 1e-6) -> Tensor:
+        device = next(model.parameters()).device
+        dtype = next(model.parameters()).dtype
+        s = torch.zeros(1, device=device, dtype=dtype)
+
+        for _, l in get_torch_layers(model):
+            s += dists.Normal(l._prior_weight_dist.mu,
+                              l._prior_weight_dist.sigma + eps
+                              ).log_prob(l._sampled_weight).sum()
+            if l._sampled_bias is not None:
+                s += dists.Normal(l._prior_bias_dist.mu,
+                                  l._prior_bias_dist.sigma + eps
+                                  ).log_prob(l._sampled_bias).sum()
+        return s
+
+    @staticmethod
+    def _log_post(model: nn.Module, eps: float = 1e-6) -> Tensor:
+        device = next(model.parameters()).device
+        dtype = next(model.parameters()).dtype
+        s = torch.zeros(1, device=device, dtype=dtype)
+
+        for _, l in get_torch_layers(model):
+            s += dists.Normal(l._weight_dist.mu,
+                              l._weight_dist.sigma + eps
+                              ).log_prob(l._sampled_weight).sum()
+            if l._sampled_bias is not None:
+                s += dists.Normal(l._bias_dist.mu,
+                                  l._bias_dist.sigma + eps
+                                  ).log_prob(l._sampled_bias).sum()
+        return s
+
+    # --------------------------------------------------------------------
+    def calculate(
+        self,
+        model: nn.Module,
+        data: Tensor,
+        target: Tensor,
+        epoch: int,
+        batch_idx: int,
+        dataset_size: int,
+        pmin: Optional[float] = None,
+        wandb_params: Optional[Dict] = None,
+    ) -> Tensor:
+        batch_size = data.size(0)
+        scale = dataset_size / batch_size  # N / |B|
+        log_ws = []  # list of k scalar log weights
+
+        temp = self.temperature
+        # quadratic warm-up of the prior/posterior term over the first 70 epochs
+        beta = min(1.0, (epoch / 70) ** 2)
+
+        for i in range(self.k):
+            # each forward pass resamples the weights of every probabilistic
+            # layer and caches them as _sampled_weight / _sampled_bias
+            logits = bounded_call(model, data, pmin) if pmin is not None else model(data)
+
+            if torch.isnan(logits).any() or torch.isinf(logits).any():
+                logging.warning(f"NaN/Inf in logits at epoch {epoch}, batch {batch_idx}")
+                logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits))
+
+            log_px = dists.Categorical(logits=logits).log_prob(target)  # (batch,)
+            log_lik = scale * log_px.sum()  # scalar, rescaled to the full dataset
+
+            # single-sample estimate of log p(w) - log q(w) (minus the KL integrand)
+            kl = beta * (self._log_prior(model) - self._log_post(model))
+            log_w = log_lik + temp * kl  # scalar log importance weight
+            log_ws.append(log_w)
+
+            # -------------------- per-sample logging --------------------
+            if wandb_params and wandb_params.get("log_wandb", False):
+                tag = wandb_params["name_wandb"]
+                wandb.log({
+                    f"{tag}/epoch": epoch,
+                    f"{tag}/batch": batch_idx,
+                    f"{tag}/sample": i,
+                    f"{tag}/log_likelihood": log_lik.detach(),
+                    f"{tag}/kl": kl.detach(),
+                    f"{tag}/log_weight": log_w.detach(),
+                })
+
+        # ----------- PB-IWAE loss (one scalar) ---------------------------
+        log_ws_tensor = torch.stack(log_ws)  # (k,)
+        loss = -(torch.logsumexp(log_ws_tensor, dim=0) - math.log(self.k))
+
+        # ----------- final logging ---------------------------------------
+        if wandb_params and wandb_params.get("log_wandb", False):
+            wandb.log({f"{wandb_params['name_wandb']}/iwae_loss": loss})
+
+        if batch_idx == 0:
+            # log_px and kl here come from the last of the k importance samples
+            logging.info(
+                f"[Epoch {epoch:03d} | Batch {batch_idx:04d}] "
+                f"IWAE-loss {loss.item():.4f} "
+                f"| mean log_px {(log_px.mean()).item():.4f} "
+                f"| KL {kl.item():.2f}"
+            )
+        return loss
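The returned loss is the k-sample importance-weighted bound, `-(logsumexp(log_ws) - log k)`. A self-contained sketch of just that estimator, using made-up log-weight values:

```python
import math
import torch

# per-sample log weights: log_w_i = log p(x | w_i) + log p(w_i) - log q(w_i)
log_ws = torch.tensor([-310.2, -305.7, -312.9])  # illustrative values only
k = log_ws.numel()
# stable evaluation of -log( (1/k) * sum_i exp(log_w_i) )
iwae_loss = -(torch.logsumexp(log_ws, dim=0) - math.log(k))
```

With k = 1 this reduces to a single-sample (ELBO-style) objective; larger k tightens the bound at the cost of k forward passes per batch.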
diff --git a/core/objective/__init__.py b/core/objective/__init__.py
index 6ff21a5..d909d87 100644
--- a/core/objective/__init__.py
+++ b/core/objective/__init__.py
@@ -17,3 +17,4 @@
 from core.objective.FQuadObjective import FQuadObjective
 from core.objective.McAllesterObjective import McAllesterObjective
 from core.objective.TolstikhinObjective import TolstikhinObjective
+from core.objective.IWAEObjective import IWAEObjective
diff --git a/core/training.py b/core/training.py
index e5ed2cb..507f2a6 100644
--- a/core/training.py
+++ b/core/training.py
@@ -8,7 +8,7 @@ import wandb
 
 from core.distribution.utils import DistributionT, compute_kl
 from core.model import bounded_call
-from core.objective import AbstractObjective
+from core.objective import AbstractObjective, IWAEObjective
 
 
 def __raise_exception_on_invalid_value(value: torch.Tensor):
@@ -75,9 +75,9 @@ def train(
         None: The model (and its posterior) are updated in-place over the specified epochs.
     """
     criterion = torch.nn.NLLLoss()
-    optimizer = torch.optim.SGD(
-        model.parameters(), lr=parameters["lr"], momentum=parameters["momentum"]
-    )
+    # Adam replaces the previous SGD-with-momentum optimizer.
+    optimizer = torch.optim.Adam(model.parameters(), lr=parameters["lr"])
+
     dataset_size = len(train_loader.dataset)
     if "seed" in parameters:
         torch.manual_seed(parameters["seed"])
@@ -89,11 +89,27 @@
             output = bounded_call(model, data, parameters["pmin"])
         else:
             output = model(data)
-        kl = compute_kl(posterior, prior)
-        loss = criterion(output, target)
-        objective_value = objective.calculate(loss, kl, parameters["num_samples"])
+        if isinstance(objective, IWAEObjective):
+            # IWAE draws its own weight samples inside calculate;
+            # loss and kl are recomputed here only for logging.
+            objective_value = objective.calculate(model,
+                                                  data,
+                                                  target,
+                                                  epoch=epoch,
+                                                  batch_idx=_i,
+                                                  dataset_size=dataset_size,
+                                                  pmin=parameters.get("pmin"),
+                                                  wandb_params=wandb_params)
+            with torch.no_grad():
+                loss = criterion(model(data), target)
+                kl = compute_kl(posterior, prior)
+        else:
+            kl = compute_kl(posterior, prior)
+            loss = criterion(output, target)
+            objective_value = objective.calculate(loss, kl, parameters["num_samples"])
         __raise_exception_on_invalid_value(objective_value)
         objective_value.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
         optimizer.step()
         logging.info(
             f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl / parameters['num_samples']}"
diff --git a/scripts/utils/factory/ObjectiveFactory.py b/scripts/utils/factory/ObjectiveFactory.py
index 59e6aad..7c187a8 100644
--- a/scripts/utils/factory/ObjectiveFactory.py
+++ b/scripts/utils/factory/ObjectiveFactory.py
@@ -5,6 +5,7 @@
     FQuadObjective,
     McAllesterObjective,
     TolstikhinObjective,
+    IWAEObjective,
 )
 
 from scripts.utils.factory import AbstractFactory
@@ -17,3 +18,4 @@ def __init__(self) -> None:
         self.register_creator("fquad", FQuadObjective)
         self.register_creator("mcallester", McAllesterObjective)
         self.register_creator("tolstikhin", TolstikhinObjective)
+        self.register_creator("iwae", IWAEObjective)
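For reference, a sketch of how the `prior_objective` block in the configs above could reach the new creator. It assumes `AbstractFactory` resolves registered names through a `create(name, **params)`-style method, which is not shown in this diff:

```python
from scripts.utils.factory.ObjectiveFactory import ObjectiveFactory

# mirrors the YAML:
#   prior_objective:
#     name: iwae
#     params:
#       n: 10
#       kl_penalty: 0.01
config = {"name": "iwae", "params": {"n": 10, "kl_penalty": 0.01}}

factory = ObjectiveFactory()
# create(...) is assumed here; only register_creator appears in this diff
objective = factory.create(config["name"], **config["params"])
```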