From 267bec31f3ef96665e534220334a97c42ae64b18 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sun, 25 May 2025 10:57:24 +0200 Subject: [PATCH 1/7] adapting iwae to the new code --- .../conv15_sigma_0.01_kl_0.01_lr_0.1.yaml | 63 ++++++++++ core/objective/IWAEObjective.py | 113 ++++++++++++++++++ core/training.py | 32 +++-- scripts/utils/factory/ObjectiveFactory.py | 2 + 4 files changed, 203 insertions(+), 7 deletions(-) create mode 100644 config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1.yaml create mode 100644 core/objective/IWAEObjective.py diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1.yaml new file mode 100644 index 0000000..3bb3103 --- /dev/null +++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1.yaml @@ -0,0 +1,63 @@ +bound: + delta: 0.025 + delta_test: 0.01 +dist_init: + seed: 110 +factory: + bounds: + - kl + - mcallester + data_loader: + name: cifar10 + params: + dataset_path: ./data/cifar10 + losses: + - nll_loss + - scaled_nll_loss + - 01_loss + metrics: + - accuracy_micro_metric + - accuracy_macro_metric + - f1_micro_metric + - f1_macro_metric + model: + name: conv15 + params: + dataset: cifar10 + in_channels: 3 + posterior_objective: + name: iwae # adapted + params: + delta: 0.025 + kl_penalty: 1.0 + prior_objective: + name: iwae # adapted + params: + delta: 0.025 + kl_penalty: 0.01 +log_wandb: true +mcsamples: 3000 # adapted +pmin: 5.0e-05 +posterior: + training: + epochs: 1 + lr: 0.0001 # adapted + momentum: 0.9 + seed: 1135 +prior: + training: + epochs: 100 + lr: 0.0005 # adapted + momentum: 0.95 + seed: 1135 +sigma: 0.01 +split_config: + batch_size: 250 + dataset_loader_seed: 112 + seed: 111 +split_strategy: + prior_percent: 0.7 + prior_type: learnt + self_certified: true + train_percent: 1.0 + val_percent: 0.0 diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py new file mode 100644 index 0000000..e7ce0ee --- /dev/null +++ b/core/objective/IWAEObjective.py @@ -0,0 +1,113 @@ +import logging, math +from typing import Dict, Optional + +import torch, torch.distributions as dists, wandb +from torch import nn, Tensor + +from core.model import bounded_call +from core.layer.utils import get_torch_layers +from core.objective import AbstractObjective + + +class IWAEObjective(AbstractObjective): + def __init__(self, kl_penalty: float, n: int, temperature: float = 1.0) -> None: + self.kl_penalty = kl_penalty # usually 1 / |D| + self.temperature = temperature + + # -------- helpers to compute log p(w) and log q(w) ------------------- + @staticmethod + def _log_prior(model: nn.Module, eps: float = 1e-6) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + s = torch.zeros(1, device=device, dtype=dtype) + + for _, l in get_torch_layers(model): + s += dists.Normal(l._prior_weight_dist.mu, + l._prior_weight_dist.sigma + eps + ).log_prob(l._sampled_weight).sum() + s += dists.Normal(l._prior_bias_dist.mu, + l._prior_bias_dist.sigma + eps + ).log_prob(l._sampled_bias).sum() + return s + + @staticmethod + def _log_post(model: nn.Module, eps: float = 1e-6) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + s = torch.zeros(1, device=device, dtype=dtype) + + for _, l in get_torch_layers(model): + s += dists.Normal(l._weight_dist.mu, + l._weight_dist.sigma + eps + ).log_prob(l._sampled_weight).sum() + s += dists.Normal(l._bias_dist.mu, + l._bias_dist.sigma + eps + ).log_prob(l._sampled_bias).sum() + return s + + # -------------------------------------------------------------------- + def calculate( + self, + model: nn.Module, + data: Tensor, + target: Tensor, + epoch: int, + batch_idx: int, + dataset_size: int, + pmin: Optional[float] = None, + wandb_params: Optional[Dict] = None, + ) -> Tensor: + + + batch_size = data.size(0) + scale = dataset_size / batch_size # N / |B| + log_ws = [] # list[k] of scalars + + kl_pen = 1 / dataset_size + temp = 1.0 + self.k = 20 + + for l in range(self.k): + # sample w and compute log p(x|w) + logits = bounded_call(model, data, pmin) if pmin is not None else model(data) + + if torch.isnan(logits).any() or torch.isinf(logits).any(): + logging.warning(f"NaN/Inf in logits at epoch {epoch}, batch {batch_idx}") + logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits)) + + log_px = dists.Categorical(logits=logits).log_prob(target) # (batch,) + log_lik = scale * log_px.sum() # scalar + + # global KL part + kl = (self._log_prior(model) - self._log_post(model)) * kl_pen + log_w = log_lik + temp * kl # scalar + log_ws.append(log_w) + + # -------------------- per-sample logging -------------------- + if wandb_params and wandb_params.get("log_wandb", False): + tag = wandb_params["name_wandb"] + wandb.log({ + f"{tag}/epoch": epoch, + f"{tag}/batch": batch_idx, + f"{tag}/sample": l, + f"{tag}/log_likelihood": log_lik.detach(), + f"{tag}/kl": kl.detach(), + f"{tag}/log_weight": log_w.detach(), + }) + + # ----------- PB-IWAE loss (one scalar) --------------------------- + log_ws_tensor = torch.stack(log_ws) # (k,) + loss = -(torch.logsumexp(log_ws_tensor, dim=0) - math.log(self.k)) + + # ----------- final logging -------------------------------------- + if wandb_params and wandb_params.get("log_wandb", False): + wandb.log({f"{wandb_params['name_wandb']}/iwae_loss": loss}) + + if batch_idx == 0: + logging.info( + f"[Epoch {epoch:03d} | Batch {batch_idx:04d}] " + f"IWAE-loss {loss.item():.4f} " + f"| mean log_px {(log_px.mean()).item():.4f} " + f"| KL {kl.item():.2f}" + ) + return loss diff --git a/core/training.py b/core/training.py index e5ed2cb..507f2a6 100644 --- a/core/training.py +++ b/core/training.py @@ -8,7 +8,7 @@ import wandb from core.distribution.utils import DistributionT, compute_kl from core.model import bounded_call -from core.objective import AbstractObjective +from core.objective import AbstractObjective, IWAEObjective def __raise_exception_on_invalid_value(value: torch.Tensor): @@ -75,9 +75,13 @@ def train( None: The model (and its posterior) are updated in-place over the specified epochs. """ criterion = torch.nn.NLLLoss() - optimizer = torch.optim.SGD( - model.parameters(), lr=parameters["lr"], momentum=parameters["momentum"] - ) + #optimizer = torch.optim.SGD( + # model.parameters(), lr=parameters["lr"], momentum=parameters["momentum"] + #) + + optimizer = torch.optim.Adam(model.parameters(), + lr=parameters['lr']) + dataset_size = len(train_loader.dataset) if "seed" in parameters: torch.manual_seed(parameters["seed"]) @@ -89,11 +93,25 @@ def train( output = bounded_call(model, data, parameters["pmin"]) else: output = model(data) - kl = compute_kl(posterior, prior) - loss = criterion(output, target) - objective_value = objective.calculate(loss, kl, parameters["num_samples"]) + if isinstance(objective, IWAEObjective): + objective_value = objective.calculate(model, + data, + target, + epoch=epoch, + batch_idx=_i, + dataset_size=dataset_size, + pmin=parameters.get('pmin', None), + wandb_params=wandb_params) + with torch.no_grad(): + loss = criterion(model(data), target) + kl = compute_kl(posterior, prior) + else: + kl = compute_kl(posterior, prior) + loss = criterion(output, target) + objective_value = objective.calculate(loss, kl, parameters["num_samples"]) __raise_exception_on_invalid_value(objective_value) objective_value.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) optimizer.step() logging.info( f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl / parameters['num_samples']}" diff --git a/scripts/utils/factory/ObjectiveFactory.py b/scripts/utils/factory/ObjectiveFactory.py index 59e6aad..7c187a8 100644 --- a/scripts/utils/factory/ObjectiveFactory.py +++ b/scripts/utils/factory/ObjectiveFactory.py @@ -5,6 +5,7 @@ FQuadObjective, McAllesterObjective, TolstikhinObjective, + IWAEObjective ) from scripts.utils.factory import AbstractFactory @@ -17,3 +18,4 @@ def __init__(self) -> None: self.register_creator("fquad", FQuadObjective) self.register_creator("mcallester", McAllesterObjective) self.register_creator("tolstikhin", TolstikhinObjective) + self.register_creator("iwae", IWAEObjective) From 53ba96434a2383011c30a1e5a00d6388fbdd0dcf Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sun, 25 May 2025 17:49:43 +0200 Subject: [PATCH 2/7] added new configs --- ...onv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml} | 15 +++-- ...conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml | 62 +++++++++++++++++++ core/objective/IWAEObjective.py | 10 +-- 3 files changed, 75 insertions(+), 12 deletions(-) rename config/iwae_test_configs/{conv15_sigma_0.01_kl_0.01_lr_0.1.yaml => conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml} (82%) create mode 100644 config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml similarity index 82% rename from config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1.yaml rename to config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml index 3bb3103..0f617c5 100644 --- a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1.yaml +++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml @@ -26,28 +26,27 @@ factory: dataset: cifar10 in_channels: 3 posterior_objective: - name: iwae # adapted + name: bbb params: - delta: 0.025 kl_penalty: 1.0 prior_objective: - name: iwae # adapted + name: iwae params: - delta: 0.025 + n: 10 kl_penalty: 0.01 log_wandb: true -mcsamples: 3000 # adapted +mcsamples: 3000 pmin: 5.0e-05 posterior: training: epochs: 1 - lr: 0.0001 # adapted + lr: 0.0001 momentum: 0.9 seed: 1135 prior: training: epochs: 100 - lr: 0.0005 # adapted + lr: 0.0005 momentum: 0.95 seed: 1135 sigma: 0.01 @@ -60,4 +59,4 @@ split_strategy: prior_type: learnt self_certified: true train_percent: 1.0 - val_percent: 0.0 + val_percent: 0.0 \ No newline at end of file diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml new file mode 100644 index 0000000..50eb546 --- /dev/null +++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml @@ -0,0 +1,62 @@ +bound: + delta: 0.025 + delta_test: 0.01 +dist_init: + seed: 110 +factory: + bounds: + - kl + - mcallester + data_loader: + name: cifar10 + params: + dataset_path: ./data/cifar10 + losses: + - nll_loss + - scaled_nll_loss + - 01_loss + metrics: + - accuracy_micro_metric + - accuracy_macro_metric + - f1_micro_metric + - f1_macro_metric + model: + name: conv15 + params: + dataset: cifar10 + in_channels: 3 + posterior_objective: + name: bbb + params: + kl_penalty: 1.0 + prior_objective: + name: iwae + params: + n: 20 + kl_penalty: 0.01 +log_wandb: true +mcsamples: 3000 +pmin: 5.0e-05 +posterior: + training: + epochs: 1 + lr: 0.0001 + momentum: 0.9 + seed: 1135 +prior: + training: + epochs: 100 + lr: 0.0005 + momentum: 0.95 + seed: 1135 +sigma: 0.01 +split_config: + batch_size: 250 + dataset_loader_seed: 112 + seed: 111 +split_strategy: + prior_percent: 0.7 + prior_type: learnt + self_certified: true + train_percent: 1.0 + val_percent: 0.0 \ No newline at end of file diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index e7ce0ee..5e972bb 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -13,6 +13,9 @@ class IWAEObjective(AbstractObjective): def __init__(self, kl_penalty: float, n: int, temperature: float = 1.0) -> None: self.kl_penalty = kl_penalty # usually 1 / |D| self.temperature = temperature + self.k=n + print(self.temperature) + print(self.k) # -------- helpers to compute log p(w) and log q(w) ------------------- @staticmethod @@ -63,9 +66,8 @@ def calculate( scale = dataset_size / batch_size # N / |B| log_ws = [] # list[k] of scalars - kl_pen = 1 / dataset_size - temp = 1.0 - self.k = 20 + temp = self.temperature + # self.k = 20 for l in range(self.k): # sample w and compute log p(x|w) @@ -79,7 +81,7 @@ def calculate( log_lik = scale * log_px.sum() # scalar # global KL part - kl = (self._log_prior(model) - self._log_post(model)) * kl_pen + kl = self._log_prior(model) - self._log_post(model) log_w = log_lik + temp * kl # scalar log_ws.append(log_w) From 7a94253b0a20e4a39e7450c6c19af5dbe266610b Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sun, 25 May 2025 18:12:42 +0200 Subject: [PATCH 3/7] fixed init --- core/objective/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/core/objective/__init__.py b/core/objective/__init__.py index 6ff21a5..d909d87 100644 --- a/core/objective/__init__.py +++ b/core/objective/__init__.py @@ -17,3 +17,4 @@ from core.objective.FQuadObjective import FQuadObjective from core.objective.McAllesterObjective import McAllesterObjective from core.objective.TolstikhinObjective import TolstikhinObjective +from core.objective.IWAEObjective import IWAEObjective From cd38adc1584d5e022349f10a2e5e41b111d566f7 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sun, 25 May 2025 18:28:45 +0200 Subject: [PATCH 4/7] added inner vars for prob layer --- core/layer/AbstractProbLayer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/layer/AbstractProbLayer.py b/core/layer/AbstractProbLayer.py index fbc9bd0..407c6f2 100644 --- a/core/layer/AbstractProbLayer.py +++ b/core/layer/AbstractProbLayer.py @@ -19,6 +19,8 @@ class AbstractProbLayer(nn.Module, ABC): _bias_dist: AbstractVariable _prior_weight_dist: AbstractVariable _prior_bias_dist: AbstractVariable + _sampled_weight: Tensor + _sampled_bias: Tensor def probabilistic(self, mode: bool = True): """ @@ -64,4 +66,6 @@ def sample_from_distribution(self) -> tuple[Tensor, Tensor]: sampled_bias = self._bias_dist.mu if self._bias_dist else None else: raise ValueError("Only training with probabilistic mode is allowed") - return sampled_weight, sampled_bias + self._sampled_weight = sampled_weight + self._sampled_bias = sampled_bias + return sampled_weight, sampled_bias \ No newline at end of file From 97b6e4d34d4e70f713753a930ec20b201db31f96 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sun, 25 May 2025 18:59:53 +0200 Subject: [PATCH 5/7] adding beta --- .../conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml | 2 +- .../conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml | 2 +- core/objective/IWAEObjective.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml index 0f617c5..4708058 100644 --- a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml +++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml @@ -49,7 +49,7 @@ prior: lr: 0.0005 momentum: 0.95 seed: 1135 -sigma: 0.01 +sigma: 0.005 split_config: batch_size: 250 dataset_loader_seed: 112 diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml index 50eb546..042b1a7 100644 --- a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml +++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml @@ -49,7 +49,7 @@ prior: lr: 0.0005 momentum: 0.95 seed: 1135 -sigma: 0.01 +sigma: 0.005 split_config: batch_size: 250 dataset_loader_seed: 112 diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index 5e972bb..27f5dc9 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -67,7 +67,7 @@ def calculate( log_ws = [] # list[k] of scalars temp = self.temperature - # self.k = 20 + beta = min(1.0, epoch / 20) for l in range(self.k): # sample w and compute log p(x|w) @@ -81,7 +81,7 @@ def calculate( log_lik = scale * log_px.sum() # scalar # global KL part - kl = self._log_prior(model) - self._log_post(model) + kl = beta * (self._log_prior(model) - self._log_post(model)) log_w = log_lik + temp * kl # scalar log_ws.append(log_w) From 534e9d212f5258a5be1f93b267df29dfd9048919 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sun, 25 May 2025 19:20:13 +0200 Subject: [PATCH 6/7] fix --- core/objective/IWAEObjective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index 27f5dc9..5664bad 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -67,7 +67,7 @@ def calculate( log_ws = [] # list[k] of scalars temp = self.temperature - beta = min(1.0, epoch / 20) + beta = min(1.0, (epoch / 40) ** 2) for l in range(self.k): # sample w and compute log p(x|w) From cdb89cbee6a4b02e98184a46b4688ea8512a4984 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sun, 25 May 2025 20:48:48 +0200 Subject: [PATCH 7/7] hyp fixes --- .../conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml | 2 +- .../conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml | 2 +- core/objective/IWAEObjective.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml index 4708058..89ad595 100644 --- a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml +++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_10.yaml @@ -46,7 +46,7 @@ posterior: prior: training: epochs: 100 - lr: 0.0005 + lr: 0.0001 momentum: 0.95 seed: 1135 sigma: 0.005 diff --git a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml index 042b1a7..5fc4dc6 100644 --- a/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml +++ b/config/iwae_test_configs/conv15_sigma_0.01_kl_0.01_lr_0.1_n_20.yaml @@ -46,7 +46,7 @@ posterior: prior: training: epochs: 100 - lr: 0.0005 + lr: 0.0001 momentum: 0.95 seed: 1135 sigma: 0.005 diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index 5664bad..3f478d4 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -67,7 +67,7 @@ def calculate( log_ws = [] # list[k] of scalars temp = self.temperature - beta = min(1.0, (epoch / 40) ** 2) + beta = min(1.0, (epoch / 70) ** 2) for l in range(self.k): # sample w and compute log p(x|w)