@@ -0,0 +1,62 @@
bound:
  delta: 0.025
  delta_test: 0.01
dist_init:
  seed: 110
factory:
  bounds:
  - kl
  - mcallester
  data_loader:
    name: cifar10
    params:
      dataset_path: ./data/cifar10
  losses:
  - nll_loss
  - scaled_nll_loss
  - 01_loss
  metrics:
  - accuracy_micro_metric
  - accuracy_macro_metric
  - f1_micro_metric
  - f1_macro_metric
  model:
    name: conv15
    params:
      dataset: cifar10
      in_channels: 3
  posterior_objective:
    name: bbb
    params:
      kl_penalty: 1.0
  prior_objective:
    name: iwae
    params:
      n: 10
      kl_penalty: 0.01
log_wandb: true
mcsamples: 3000
pmin: 5.0e-05
posterior:
  training:
    epochs: 1
    lr: 0.0001
    momentum: 0.9
    seed: 1135
prior:
  training:
    epochs: 100
    lr: 0.0001
    momentum: 0.95
    seed: 1135
sigma: 0.005
split_config:
  batch_size: 250
  dataset_loader_seed: 112
  seed: 111
  split_strategy:
    prior_percent: 0.7
    prior_type: learnt
    self_certified: true
    train_percent: 1.0
    val_percent: 0.0
@@ -0,0 +1,62 @@
bound:
  delta: 0.025
  delta_test: 0.01
dist_init:
  seed: 110
factory:
  bounds:
  - kl
  - mcallester
  data_loader:
    name: cifar10
    params:
      dataset_path: ./data/cifar10
  losses:
  - nll_loss
  - scaled_nll_loss
  - 01_loss
  metrics:
  - accuracy_micro_metric
  - accuracy_macro_metric
  - f1_micro_metric
  - f1_macro_metric
  model:
    name: conv15
    params:
      dataset: cifar10
      in_channels: 3
  posterior_objective:
    name: bbb
    params:
      kl_penalty: 1.0
  prior_objective:
    name: iwae
    params:
      n: 20
      kl_penalty: 0.01
log_wandb: true
mcsamples: 3000
pmin: 5.0e-05
posterior:
  training:
    epochs: 1
    lr: 0.0001
    momentum: 0.9
    seed: 1135
prior:
  training:
    epochs: 100
    lr: 0.0001
    momentum: 0.95
    seed: 1135
sigma: 0.005
split_config:
  batch_size: 250
  dataset_loader_seed: 112
  seed: 111
  split_strategy:
    prior_percent: 0.7
    prior_type: learnt
    self_certified: true
    train_percent: 1.0
    val_percent: 0.0
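
The two configs above differ only in the number of importance samples (n: 10 vs n: 20) for the IWAE prior objective. For orientation, a minimal sketch of how the prior_objective block maps onto the constructor added in this PR (direct construction shown for illustration; the repo routes this through ObjectiveFactory, and the config filename here is hypothetical):

import yaml
from core.objective import IWAEObjective

with open("config_iwae_n10.yaml") as f:       # hypothetical path to the first config above
    cfg = yaml.safe_load(f)

po = cfg["factory"]["prior_objective"]        # {'name': 'iwae', 'params': {'n': 10, 'kl_penalty': 0.01}}
objective = IWAEObjective(**po["params"])     # equivalent to IWAEObjective(kl_penalty=0.01, n=10)
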
6 changes: 5 additions & 1 deletion core/layer/AbstractProbLayer.py
@@ -19,6 +19,8 @@ class AbstractProbLayer(nn.Module, ABC):
    _bias_dist: AbstractVariable
    _prior_weight_dist: AbstractVariable
    _prior_bias_dist: AbstractVariable
    _sampled_weight: Tensor
    _sampled_bias: Tensor

    def probabilistic(self, mode: bool = True):
        """
@@ -64,4 +66,6 @@ def sample_from_distribution(self) -> tuple[Tensor, Tensor]:
            sampled_bias = self._bias_dist.mu if self._bias_dist else None
        else:
            raise ValueError("Only training with probabilistic mode is allowed")
        # cache the drawn parameters so objectives can evaluate log-densities at this exact sample
        self._sampled_weight = sampled_weight
        self._sampled_bias = sampled_bias
        return sampled_weight, sampled_bias
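
These cached tensors are the hook for the new objective: a forward pass in probabilistic mode re-samples weights on every call, so any later log-density evaluation must score the stored sample rather than drawing a new one. A minimal sketch of the idea, with hypothetical mu/sigma standing in for one layer's variational parameters:

import torch
import torch.distributions as dists

mu, sigma = torch.zeros(3), torch.ones(3)   # hypothetical variational parameters
q = dists.Normal(mu, sigma)
w = q.rsample()                             # analogous to what sample_from_distribution draws
log_q_w = q.log_prob(w).sum()               # what _log_post below recovers from the cache
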
115 changes: 115 additions & 0 deletions core/objective/IWAEObjective.py
@@ -0,0 +1,115 @@
import logging
import math
from typing import Dict, Optional

import torch
import torch.distributions as dists
import wandb
from torch import nn, Tensor

from core.model import bounded_call
from core.layer.utils import get_torch_layers
from core.objective import AbstractObjective


class IWAEObjective(AbstractObjective):
    """PAC-Bayes IWAE-style objective: averages k importance weights via logsumexp."""

    def __init__(self, kl_penalty: float, n: int, temperature: float = 1.0) -> None:
        self.kl_penalty = kl_penalty  # usually 1 / |D|
        self.temperature = temperature
        self.k = n  # number of importance samples per batch
        logging.info("IWAEObjective: k=%d, temperature=%.3f", self.k, self.temperature)

# -------- helpers to compute log p(w) and log q(w) -------------------
    @staticmethod
    def _log_prior(model: nn.Module, eps: float = 1e-6) -> Tensor:
        """Sum of log p(w) over all layers, evaluated at the cached samples."""
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        s = torch.zeros(1, device=device, dtype=dtype)

        for _, layer in get_torch_layers(model):
            s += dists.Normal(layer._prior_weight_dist.mu,
                              layer._prior_weight_dist.sigma + eps
                              ).log_prob(layer._sampled_weight).sum()
            if layer._sampled_bias is not None:  # bias-free layers contribute no term
                s += dists.Normal(layer._prior_bias_dist.mu,
                                  layer._prior_bias_dist.sigma + eps
                                  ).log_prob(layer._sampled_bias).sum()
        return s

    @staticmethod
    def _log_post(model: nn.Module, eps: float = 1e-6) -> Tensor:
        """Sum of log q(w) over all layers, evaluated at the cached samples."""
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        s = torch.zeros(1, device=device, dtype=dtype)

        for _, layer in get_torch_layers(model):
            s += dists.Normal(layer._weight_dist.mu,
                              layer._weight_dist.sigma + eps
                              ).log_prob(layer._sampled_weight).sum()
            if layer._sampled_bias is not None:  # bias-free layers contribute no term
                s += dists.Normal(layer._bias_dist.mu,
                                  layer._bias_dist.sigma + eps
                                  ).log_prob(layer._sampled_bias).sum()
        return s

# --------------------------------------------------------------------
def calculate(
self,
model: nn.Module,
data: Tensor,
target: Tensor,
epoch: int,
batch_idx: int,
dataset_size: int,
pmin: Optional[float] = None,
wandb_params: Optional[Dict] = None,
    ) -> Tensor:

        batch_size = data.size(0)
        scale = dataset_size / batch_size  # N / |B|: rescale the batch log-likelihood to the full dataset
        log_ws = []  # will collect k scalar log importance weights

        temp = self.temperature
        beta = min(1.0, (epoch / 70) ** 2)  # quadratic KL warm-up over the first 70 epochs

        for i in range(self.k):
            # each forward pass in probabilistic mode draws a fresh w ~ q and caches it on the layers
            logits = bounded_call(model, data, pmin) if pmin is not None else model(data)

if torch.isnan(logits).any() or torch.isinf(logits).any():
logging.warning(f"NaN/Inf in logits at epoch {epoch}, batch {batch_idx}")
logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits))

log_px = dists.Categorical(logits=logits).log_prob(target) # (batch,)
log_lik = scale * log_px.sum() # scalar

            # annealed log p(w) - log q(w) at the cached sample (enters the importance weight)
            kl = beta * (self._log_prior(model) - self._log_post(model))
            log_w = log_lik + temp * kl  # scalar log importance weight
            log_ws.append(log_w)

# -------------------- per-sample logging --------------------
if wandb_params and wandb_params.get("log_wandb", False):
tag = wandb_params["name_wandb"]
wandb.log({
f"{tag}/epoch": epoch,
f"{tag}/batch": batch_idx,
f"{tag}/sample": l,
f"{tag}/log_likelihood": log_lik.detach(),
f"{tag}/kl": kl.detach(),
f"{tag}/log_weight": log_w.detach(),
})

# ----------- PB-IWAE loss (one scalar) ---------------------------
log_ws_tensor = torch.stack(log_ws) # (k,)
loss = -(torch.logsumexp(log_ws_tensor, dim=0) - math.log(self.k))

# ----------- final logging --------------------------------------
if wandb_params and wandb_params.get("log_wandb", False):
wandb.log({f"{wandb_params['name_wandb']}/iwae_loss": loss})

if batch_idx == 0:
logging.info(
f"[Epoch {epoch:03d} | Batch {batch_idx:04d}] "
f"IWAE-loss {loss.item():.4f} "
f"| mean log_px {(log_px.mean()).item():.4f} "
f"| KL {kl.item():.2f}"
)
return loss
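
As a sanity check on the loss construction: logsumexp(log_ws) - log k is the log of the average importance weight, which by Jensen's inequality is at least the plain average of the log-weights (the single-sample ELBO-style bound), so more samples can only tighten the objective. A toy verification with made-up log-weights:

import math
import torch

log_ws = torch.tensor([-10.0, -8.0, -9.5])                     # toy log importance weights, k = 3
iwae = torch.logsumexp(log_ws, dim=0) - math.log(len(log_ws))  # the bound above, before the sign flip
elbo = log_ws.mean()
assert iwae >= elbo                                            # the IWAE estimate is never looser
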
1 change: 1 addition & 0 deletions core/objective/__init__.py
@@ -17,3 +17,4 @@
from core.objective.FQuadObjective import FQuadObjective
from core.objective.McAllesterObjective import McAllesterObjective
from core.objective.TolstikhinObjective import TolstikhinObjective
from core.objective.IWAEObjective import IWAEObjective
32 changes: 25 additions & 7 deletions core/training.py
@@ -8,7 +8,7 @@
import wandb
from core.distribution.utils import DistributionT, compute_kl
from core.model import bounded_call
from core.objective import AbstractObjective, IWAEObjective


def __raise_exception_on_invalid_value(value: torch.Tensor):
@@ -75,9 +75,13 @@ def train(
        None: The model (and its posterior) are updated in-place over the specified epochs.
    """
    criterion = torch.nn.NLLLoss()
    # switched from SGD with momentum to Adam; parameters["momentum"] is no longer used
    optimizer = torch.optim.Adam(model.parameters(), lr=parameters["lr"])
    dataset_size = len(train_loader.dataset)

if "seed" in parameters:
torch.manual_seed(parameters["seed"])
Expand All @@ -89,11 +93,25 @@ def train(
output = bounded_call(model, data, parameters["pmin"])
else:
output = model(data)
kl = compute_kl(posterior, prior)
loss = criterion(output, target)
objective_value = objective.calculate(loss, kl, parameters["num_samples"])
        if isinstance(objective, IWAEObjective):
            # the IWAE objective resamples weights and builds its loss internally
            objective_value = objective.calculate(
                model,
                data,
                target,
                epoch=epoch,
                batch_idx=_i,
                dataset_size=dataset_size,
                pmin=parameters.get("pmin", None),
                wandb_params=wandb_params,
            )
            with torch.no_grad():
                # recompute loss and KL for logging only
                loss = criterion(model(data), target)
                kl = compute_kl(posterior, prior)
        else:
            kl = compute_kl(posterior, prior)
            loss = criterion(output, target)
            objective_value = objective.calculate(loss, kl, parameters["num_samples"])
        __raise_exception_on_invalid_value(objective_value)
        objective_value.backward()
        # guard against large gradients from the importance-weighted objective
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        logging.info(
            f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl / parameters['num_samples']}"
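
For reference, the parameters dict consumed above touches at least the keys visible in this diff; a hedged sketch with values drawn from the posterior.training block of the configs, and num_samples as a stand-in (its true value is set elsewhere):

parameters = {
    "lr": 0.0001,          # posterior.training.lr
    "momentum": 0.9,       # still in the config, but unused now that Adam replaced SGD
    "seed": 1135,          # posterior.training.seed
    "pmin": 5.0e-05,       # optional: routes the forward pass through bounded_call
    "num_samples": 50000,  # hypothetical |D|, used by the classic objectives and KL/n logging
}
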
2 changes: 2 additions & 0 deletions scripts/utils/factory/ObjectiveFactory.py
@@ -5,6 +5,7 @@
    FQuadObjective,
    McAllesterObjective,
    TolstikhinObjective,
    IWAEObjective,
)
from scripts.utils.factory import AbstractFactory

@@ -17,3 +18,4 @@ def __init__(self) -> None:
        self.register_creator("fquad", FQuadObjective)
        self.register_creator("mcallester", McAllesterObjective)
        self.register_creator("tolstikhin", TolstikhinObjective)
        self.register_creator("iwae", IWAEObjective)