diff --git a/.github/workflows/codecov.yaml b/.github/workflows/codecov.yaml index 445267a..4477207 100644 --- a/.github/workflows/codecov.yaml +++ b/.github/workflows/codecov.yaml @@ -15,7 +15,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.14 + python-version: 3.13 - name: Install test dependencies run: pip install pytest pytest-cov diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2128877..756ea63 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.12, 3.13, 3.14] + python-version: [3.12, 3.13] steps: - name: Checkout uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index 391cb99..16a6890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "rbms" -version = "0.5.0" +version = "0.6.0" authors = [ {name="Nicolas Béreux", email="nicolas.bereux@gmail.com"}, {name="Aurélien Decelle"}, @@ -19,12 +19,12 @@ maintainers = [ ] description = "Training and analyzing Restricted Boltzmann Machines in PyTorch" readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.12, <3.14" dependencies = [ - "h5py>=3.12.0", + "h5py>=3.14.0", "numpy>=2.0.0", "matplotlib>=3.8.0", - "torch>=2.5.0", + "torch>=2.10.0", "tqdm>=4.65.0", ] @@ -85,4 +85,4 @@ docstring-code-format = false [dependency-groups] dev = [ "pytest>=8.4.1", -] \ No newline at end of file +] diff --git a/rbms/__init__.py b/rbms/__init__.py index e69de29..0169025 100644 --- a/rbms/__init__.py +++ b/rbms/__init__.py @@ -0,0 +1,42 @@ +from rbms.bernoulli_bernoulli.classes import BBRBM +from rbms.bernoulli_gaussian.classes import BGRBM +from rbms.dataset import load_dataset +from rbms.dataset.utils import convert_data +from rbms.io import load_model, load_params +from rbms.ising_ising.classes import IIRBM +from rbms.map_model import map_model +from 
rbms.plot import plot_image, plot_mult_PCA +from rbms.potts_bernoulli.classes import PBRBM +from rbms.utils import ( + bernoulli_to_ising, + compute_log_likelihood, + get_categorical_configurations, + get_eigenvalues_history, + get_flagged_updates, + get_saved_updates, + ising_to_bernoulli, +) + +__all__ = [ + "BBRBM", + "BGRBM", + "IIRBM", + "PBRBM", + "map_model", + "bernoulli_to_ising", + "ising_to_bernoulli", + "compute_log_likelihood", + "get_eigenvalues_history", + "get_saved_updates", + "get_flagged_updates", + "get_categorical_configurations", + "plot_mult_PCA", + "plot_image", + "load_params", + "load_model", + "load_dataset", + "convert_data", +] + + +__version__ = "0.6.0" diff --git a/rbms/bernoulli_bernoulli/__init__.py b/rbms/bernoulli_bernoulli/__init__.py index 25d1da7..b2248e0 100644 --- a/rbms/bernoulli_bernoulli/__init__.py +++ b/rbms/bernoulli_bernoulli/__init__.py @@ -1,3 +1,12 @@ # ruff: noqa from rbms.bernoulli_bernoulli.classes import BBRBM -from rbms.bernoulli_bernoulli.functional import * +from rbms.bernoulli_bernoulli.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) diff --git a/rbms/bernoulli_bernoulli/classes.py b/rbms/bernoulli_bernoulli/classes.py index 5ce9c42..e9e7656 100644 --- a/rbms/bernoulli_bernoulli/classes.py +++ b/rbms/bernoulli_bernoulli/classes.py @@ -1,4 +1,4 @@ -from typing import Self +from __future__ import annotations import numpy as np import torch @@ -15,17 +15,20 @@ _sample_visibles, ) from rbms.classes import RBM +from rbms.custom_fn import check_keys_dict class BBRBM(RBM): """Parameters of the Bernoulli-Bernoulli RBM""" + visible_type: str = "bernoulli" + def __init__( self, weight_matrix: Tensor, vbias: Tensor, hbias: Tensor, - device: torch.device | None = None, + device: torch.device | str | None = None, dtype: torch.dtype | None = None, ): """Initialize the parameters of the Bernoulli-Bernoulli RBM. 
@@ -49,6 +52,7 @@ def __init__( self.vbias = vbias.to(device=self.device, dtype=self.dtype) self.hbias = hbias.to(device=self.device, dtype=self.dtype) self.name = "BBRBM" + self.flags = [] def __add__(self, other): return BBRBM( @@ -64,7 +68,9 @@ def __mul__(self, other): hbias=self.hbias * other, ) - def clone(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + def clone( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): if device is None: device = self.device if dtype is None: @@ -102,7 +108,7 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor: weight_matrix=self.weight_matrix, ) - def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): + def compute_gradient(self, data, chains, centered=True): _compute_gradient( v_data=data["visible"], mh_data=data["hidden_mag"], @@ -114,8 +120,6 @@ def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2 hbias=self.hbias, weight_matrix=self.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) def independent_model(self): @@ -159,25 +163,28 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): ) return BBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias) - def named_parameters(self): + def named_parameters(self) -> dict[str, np.ndarray]: return { - "weight_matrix": self.weight_matrix, - "vbias": self.vbias, - "hbias": self.hbias, + "weight_matrix": self.weight_matrix.cpu().numpy(), + "vbias": self.vbias.cpu().numpy(), + "hbias": self.hbias.cpu().numpy(), } + @property def num_hiddens(self): return self.hbias.shape[0] + @property def num_visibles(self): return self.vbias.shape[0] def parameters(self) -> list[Tensor]: return [self.weight_matrix, self.vbias, self.hbias] + @property def ref_log_z(self): return ( - torch.log1p(torch.exp(self.vbias)).sum() + self.num_hiddens() * np.log(2) + torch.log1p(torch.exp(self.vbias)).sum() + self.num_hiddens 
* np.log(2) ).item() def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: @@ -199,17 +206,23 @@ def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor return chains @staticmethod - def set_named_parameters(named_params: dict[str, Tensor]) -> Self: + def set_named_parameters( + named_params: dict[str, np.ndarray], + device: torch.device | str, + dtype: torch.dtype, + ) -> BBRBM: names = ["vbias", "hbias", "weight_matrix"] - for k in names: - if k not in named_params.keys(): - raise ValueError( - f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" - ) + check_keys_dict(d=named_params, names=names) params = BBRBM( - weight_matrix=named_params.pop("weight_matrix"), - vbias=named_params.pop("vbias"), - hbias=named_params.pop("hbias"), + weight_matrix=torch.from_numpy(named_params.pop("weight_matrix")).to( + device=device, dtype=dtype + ), + vbias=torch.from_numpy(named_params.pop("vbias")).to( + device=device, dtype=dtype + ), + hbias=torch.from_numpy(named_params.pop("hbias")).to( + device=device, dtype=dtype + ), ) if len(named_params.keys()) > 0: raise ValueError( @@ -217,7 +230,9 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> Self: ) return params - def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): if device is not None: self.device = device if dtype is not None: @@ -226,3 +241,12 @@ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = Non self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) return self + + def get_metrics(self, metrics): + return metrics + + def post_grad_update(self): + pass + + def pre_grad_update(self): + pass diff --git a/rbms/bernoulli_bernoulli/functional.py 
b/rbms/bernoulli_bernoulli/functional.py index 9f4f348..0b2ff7e 100644 --- a/rbms/bernoulli_bernoulli/functional.py +++ b/rbms/bernoulli_bernoulli/functional.py @@ -116,8 +116,6 @@ def compute_gradient( chains: dict[str, Tensor], params: BBRBM, centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: """Compute the gradient for each of the parameters and attach it. @@ -140,8 +138,6 @@ def compute_gradient( hbias=params.hbias, weight_matrix=params.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) diff --git a/rbms/bernoulli_bernoulli/implement.py b/rbms/bernoulli_bernoulli/implement.py index a2f4d2a..7a289a0 100644 --- a/rbms/bernoulli_bernoulli/implement.py +++ b/rbms/bernoulli_bernoulli/implement.py @@ -1,6 +1,5 @@ import torch from torch import Tensor -from torch.nn.functional import softmax @torch.jit.script @@ -59,7 +58,7 @@ def _compute_energy_hiddens( return -field - log_term.sum(1) -@torch.jit.script +# @torch.jit.script def _compute_gradient( v_data: Tensor, mh_data: Tensor, @@ -71,13 +70,11 @@ def _compute_gradient( hbias: Tensor, weight_matrix: Tensor, centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) # Turn the weights of the chains into normalized weights - chain_weights = softmax(-w_chain, dim=0) + chain_weights = w_chain / w_chain.sum() w_data_norm = w_data.sum() # Averages over data and generated samples @@ -102,11 +99,6 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean) grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) else: - v_data_centered = v_data - h_data_centered = mh_data - v_gen_centered = v_chain - h_gen_centered = h_chain - # Gradient grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - ( (v_chain * chain_weights).T @ h_chain @@ -114,21 +106,11 @@ def _compute_gradient( grad_vbias = v_data_mean - 
v_gen_mean grad_hbias = h_data_mean - h_gen_mean - if lambda_l1 > 0: - grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) - grad_vbias -= lambda_l1 * torch.sign(vbias) - grad_hbias -= lambda_l1 * torch.sign(hbias) - - if lambda_l2 > 0: - grad_weight_matrix -= 2 * lambda_l2 * weight_matrix - grad_vbias -= 2 * lambda_l2 * vbias - grad_hbias -= 2 * lambda_l2 * hbias - # Attach to the parameters - weight_matrix.grad.set_(grad_weight_matrix) - vbias.grad.set_(grad_vbias) - hbias.grad.set_(grad_hbias) + weight_matrix.grad = grad_weight_matrix + vbias.grad = grad_vbias + hbias.grad = grad_hbias @torch.jit.script diff --git a/rbms/bernoulli_gaussian/__init__.py b/rbms/bernoulli_gaussian/__init__.py new file mode 100644 index 0000000..3743484 --- /dev/null +++ b/rbms/bernoulli_gaussian/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa +from rbms.bernoulli_gaussian.classes import BGRBM +from rbms.bernoulli_gaussian.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py index 1c4a0b3..4c375dc 100644 --- a/rbms/bernoulli_gaussian/classes.py +++ b/rbms/bernoulli_gaussian/classes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import torch from torch import Tensor @@ -19,12 +22,14 @@ class BGRBM(RBM): """Bernoulli-Gaussian RBM with fixed hidden variance = 1/Nv, 0-1 visibles, hidden and visible biases""" + visible_type: str = "bernoulli" + def __init__( self, weight_matrix: Tensor, vbias: Tensor, hbias: Tensor, - device: torch.device | None = None, + device: torch.device | str | None = None, dtype: torch.dtype | None = None, ): if device is None: @@ -36,7 +41,8 @@ self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) self.vbias = vbias.to(device=self.device, 
dtype=self.dtype) self.hbias = hbias.to(device=self.device, dtype=self.dtype) - log_two_pi = torch.log(2.0 * torch.pi, dtype=vbias.dtype, device=vbias.device) + log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=dtype, device=device)) + self.const = ( 0.5 * float(weight_matrix.shape[1]) @@ -53,6 +59,7 @@ def __init__( ) self.name = "BGRBM" + self.flags = [] def __add__(self, other): # keep fixed variance policy; recompute eta from resulting vbias size @@ -76,7 +83,9 @@ def __mul__(self, other): ) return out - def clone(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + def clone( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): if device is None: device = self.device if dtype is None: @@ -113,7 +122,7 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor: const=self.const, ) - def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): + def compute_gradient(self, data, chains, centered=True): # backend should ignore grads on eta or treat it as const; we pass it for conditionals _compute_gradient( v_data=data["visible"], @@ -126,8 +135,6 @@ def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2 hbias=self.hbias, weight_matrix=self.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) def independent_model(self): @@ -180,24 +187,27 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): def named_parameters(self): return { - "weight_matrix": self.weight_matrix, - "vbias": self.vbias, - "hbias": self.hbias, + "weight_matrix": self.weight_matrix.cpu().numpy(), + "vbias": self.vbias.cpu().numpy(), + "hbias": self.hbias.cpu().numpy(), } - def num_hiddens(self): + @property + def num_hiddens(self) -> int: return self.hbias.shape[0] - def num_visibles(self): + @property + def num_visibles(self) -> int: return self.vbias.shape[0] def parameters(self) -> list[Tensor]: # keep trainables only 
return [self.weight_matrix, self.vbias, self.hbias] - def ref_log_z(self): - K = self.num_hiddens() - Nv = self.num_visibles() + @property + def ref_log_z(self) -> float: + K = self.num_hiddens + Nv = self.num_visibles logZ_v = torch.log1p(torch.exp(self.vbias)).sum() inv_gamma = 1.0 / float(Nv) quad = 0.5 * inv_gamma * torch.dot(self.hbias, self.hbias) @@ -223,7 +233,11 @@ def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor return chains @staticmethod - def set_named_parameters(named_params: dict[str, Tensor]) -> "BGRBM": + def set_named_parameters( + named_params: dict[str, np.ndarray], + device: torch.device | str, + dtype: torch.dtype, + ) -> BGRBM: names = ["vbias", "hbias", "weight_matrix"] for k in names: if k not in named_params: @@ -231,9 +245,15 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> "BGRBM": f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" ) params = BGRBM( - weight_matrix=named_params.pop("weight_matrix"), - vbias=named_params.pop("vbias"), - hbias=named_params.pop("hbias"), + weight_matrix=torch.from_numpy(named_params.pop("weight_matrix")).to( + device=device, dtype=dtype + ), + vbias=torch.from_numpy(named_params.pop("vbias")).to( + device=device, dtype=dtype + ), + hbias=torch.from_numpy(named_params.pop("hbias")).to( + device=device, dtype=dtype + ), ) if len(named_params) > 0: raise ValueError( @@ -242,7 +262,7 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> "BGRBM": return params def to( - self, device: torch.device | None = None, dtype: torch.dtype | None = None + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None ) -> "BGRBM": if device is not None: self.device = device @@ -252,3 +272,12 @@ def to( self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) return self + + def get_metrics(self, metrics): + return 
metrics + + def post_grad_update(self): + pass + + def pre_grad_update(self): + pass diff --git a/rbms/bernoulli_gaussian/functional.py b/rbms/bernoulli_gaussian/functional.py index 5434db6..94bf023 100644 --- a/rbms/bernoulli_gaussian/functional.py +++ b/rbms/bernoulli_gaussian/functional.py @@ -79,12 +79,10 @@ def compute_gradient( chains: dict[str, Tensor], params: BGRBM, centered: bool, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: _compute_gradient( v_data=data["visible"], - mh_data=data["hidden_mag"], # use conditional mean for positive phase + h_data=data["hidden_mag"], # use conditional mean for positive phase w_data=data["weights"], v_chain=chains["visible"], h_chain=chains["hidden_mag"], # negative phase from chain samples @@ -93,8 +91,6 @@ def compute_gradient( hbias=params.hbias, weight_matrix=params.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) diff --git a/rbms/bernoulli_gaussian/implement.py b/rbms/bernoulli_gaussian/implement.py index a6c81d3..0de1780 100644 --- a/rbms/bernoulli_gaussian/implement.py +++ b/rbms/bernoulli_gaussian/implement.py @@ -1,9 +1,7 @@ import torch from torch import Tensor -from torch.nn.functional import softmax -@torch.jit.script def _sample_hiddens( v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 ) -> tuple[Tensor, Tensor]: @@ -12,7 +10,6 @@ def _sample_hiddens( return h, mh -@torch.jit.script def _sample_visibles( h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 ) -> tuple[Tensor, Tensor]: @@ -21,7 +18,6 @@ def _sample_visibles( return v, mv -@torch.jit.script def _compute_energy( v: Tensor, h: Tensor, @@ -40,7 +36,6 @@ def _compute_energy( return -fields - interaction + quad -@torch.jit.script def _compute_energy_visibles( v: Tensor, vbias: Tensor, @@ -55,7 +50,6 @@ def _compute_energy_visibles( return -field - quad_term + const -@torch.jit.script def _compute_energy_hiddens( h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: 
Tensor ) -> Tensor: @@ -67,7 +61,6 @@ def _compute_energy_hiddens( return -field - log_term.sum(1) + quad -@torch.jit.script def _compute_gradient( v_data: Tensor, h_data: Tensor, @@ -79,12 +72,10 @@ def _compute_gradient( hbias: Tensor, weight_matrix: Tensor, centered: bool, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) - chain_weights = softmax(-w_chain, dim=0) + chain_weights = w_chain / w_chain.sum() w_data_norm = w_data.sum() v_data_mean = (v_data * w_data).sum(0) / w_data_norm @@ -108,11 +99,6 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean) grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) else: - v_data_centered = v_data - h_data_centered = h_data - v_gen_centered = v_chain - h_gen_centered = h_chain - # Gradient: h_data instead of mh_data grad_weight_matrix = ((v_data * w_data).T @ h_data) / w_data_norm - ( (v_chain * chain_weights).T @ h_chain @@ -120,23 +106,12 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean grad_hbias = h_data_mean - h_gen_mean - if lambda_l1 > 0: - grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) - grad_vbias -= lambda_l1 * torch.sign(vbias) - grad_hbias -= lambda_l1 * torch.sign(hbias) - - if lambda_l2 > 0: - grad_weight_matrix -= 2 * lambda_l2 * weight_matrix - grad_vbias -= 2 * lambda_l2 * vbias - grad_hbias -= 2 * lambda_l2 * hbias - # Attach to the parameters - weight_matrix.grad.set_(grad_weight_matrix) - vbias.grad.set_(grad_vbias) - hbias.grad.set_(grad_hbias) + weight_matrix.grad = grad_weight_matrix + vbias.grad = grad_vbias + hbias.grad = grad_hbias -@torch.jit.script def _init_chains( num_samples: int, weight_matrix: Tensor, diff --git a/rbms/classes.py b/rbms/classes.py index 88a21e3..3e5d9ba 100644 --- a/rbms/classes.py +++ b/rbms/classes.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Self +import numpy as np import 
torch from torch import Tensor @@ -13,21 +14,32 @@ class EBM(ABC): """An abstract class representing the parameters of an Energy-Based Model.""" name: str - device: torch.device + device: torch.device | str | None + visible_type: str + flags: list[str] @abstractmethod def __init__(self): ... @abstractmethod def __add__(self, other: EBM) -> EBM: - """Add the parameters of two RBMs. Useful for interpolation""" + """Add the parameters of two EBMs. Useful for interpolation""" ... @abstractmethod def __mul__(self, other: float) -> EBM: - """Multiplies the parameters of the RBM by a float.""" + """Multiplies the parameters of the EBM by a float.""" ... + def __eq__(self, other: object) -> bool: + if not isinstance(other, EBM): + return False + other_params = other.named_parameters() + for k, v in self.named_parameters().items(): + if not np.array_equal(other_params[k], v): + return False + return True + @abstractmethod def sample_visibles( self, chains: dict[str, Tensor], beta: float = 1.0 ) -> dict[str, Tensor]: ... @@ -62,7 +74,7 @@ def init_chains( weights: Tensor | None = None, start_v: Tensor | None = None, ) -> dict[str, Tensor]: - """Initialize a Markov chain for the RBM by sampling a uniform distribution on the visible layer + """Initialize a Markov chain for the EBM by sampling a uniform distribution on the visible layer and sampling the hidden layer according to the visible one. Args: @@ -83,8 +95,6 @@ def compute_gradient( data: dict[str, Tensor], chains: dict[str, Tensor], centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: """Compute the gradient for each of the parameters and attach it. @@ -107,16 +117,20 @@ def parameters(self) -> list[Tensor]: ... @abstractmethod - def named_parameters(self) -> dict[str, Tensor]: ... + def named_parameters(self) -> dict[str, np.ndarray]: ... @staticmethod @abstractmethod - def set_named_parameters(named_params: dict[str, Tensor]) -> EBM: ... 
+ def set_named_parameters( + named_params: dict[str, np.ndarray], + device: torch.device | str, + dtype: torch.dtype, + ) -> EBM: ... @abstractmethod def to( self, - device: torch.device | None = None, + device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> Self: """Move the parameters to the specified device and/or convert them to the specified data type. @@ -154,7 +168,7 @@ def clone( def init_parameters( num_hiddens: int, dataset: RBMDataset, - device: torch.device, + device: torch.device | str, dtype: torch.dtype, var_init: float = 1e-4, ) -> EBM: @@ -175,11 +189,13 @@ def init_parameters( """ ... + @property @abstractmethod def num_visibles(self) -> int: """Number of visible units""" ... + @property @abstractmethod def ref_log_z(self) -> float: """Reference log partition function with weights set to 0 (except for the visible bias).""" @@ -209,12 +225,43 @@ def init_grad(self) -> None: for p in self.parameters(): p.grad = torch.zeros_like(p) + @torch.compile def normalize_grad(self) -> None: norm_grad = torch.sqrt( torch.sum(torch.tensor([p.grad.square().sum() for p in self.parameters()])) ) for p in self.parameters(): p.grad /= norm_grad + # for p in self.parameters(): + # p.grad /= p.grad.norm() + + def clip_grad(self, max_norm=5): + for p in self.parameters(): + if p.grad is not None: + grad_norm = p.grad.norm() + if grad_norm > max_norm: + p.grad /= grad_norm + p.grad *= max_norm + + def save_flags(self, flags: list[str]) -> list[str]: + if len(self.flags) > 0: + for elt in self.flags: + flags.append(elt) + self.flags = [] + return flags + + @abstractmethod + def get_metrics(self, metrics: dict[str, float]) -> dict[str, float]: ... + + @abstractmethod + def pre_grad_update(self) -> None: ... + + @abstractmethod + def post_grad_update(self) -> None: ... + + @property + @abstractmethod + def effective_number_variables(self) -> float: ... 
class RBM(EBM): @@ -257,6 +304,7 @@ def compute_energy_hiddens(self, h: Tensor) -> Tensor: """ ... + @property @abstractmethod def num_hiddens(self) -> int: """Number of hidden units""" @@ -272,3 +320,54 @@ def sample_state(self, chains, n_steps, beta=1.0): new_chains = self.sample_visibles(chains=new_chains, beta=beta) new_chains = self.sample_hiddens(chains=new_chains, beta=beta) return new_chains + + @property + def effective_number_variables(self) -> float: + return np.sqrt(self.num_visibles * self.num_hiddens) + + +class Sampler(ABC): + name: str + flags: list[str] + + @abstractmethod + def __init__(self): ... + + @abstractmethod + def get_conf_grad(self, batch: Tensor) -> dict[str, Tensor]: ... + + @abstractmethod + def sample(self, num_steps: int | None, **kwargs) -> None: ... + + def save_flags(self, flags: list[str]) -> list[str]: + if len(self.flags) > 0: + for elt in self.flags: + flags.append(elt) + self.flags = [] + return flags + + @abstractmethod + def named_parameters(self) -> dict[str, np.ndarray]: ... + + @staticmethod + @abstractmethod + def set_named_parameters( + named_params: dict[str, np.ndarray], + map_model: dict[str, type[EBM]], + device: torch.device | str, + dtype: torch.dtype, + ) -> Sampler: ... + + @abstractmethod + def pre_grad_update(self) -> None: ... + + @abstractmethod + def post_grad_update(self, params: EBM) -> None: ... + + @abstractmethod + def get_metrics_display( + self, metrics: dict[str, float], **kwargs + ) -> dict[str, float]: ... + + @abstractmethod + def get_metrics_save(self) -> dict[str, np.ndarray] | None: ... 
diff --git a/rbms/correlations.py b/rbms/correlations.py index 7bf65a3..8bdc7ee 100644 --- a/rbms/correlations.py +++ b/rbms/correlations.py @@ -48,10 +48,8 @@ def compute_2b_correlations( ) if full_mat: res = torch.triu(res, 1) + torch.tril(res).T - return res / torch.sqrt( - torch.diag(res).unsqueeze(1) @ torch.diag(res).unsqueeze(0) - ) - return torch.corrcoef(data) + return res # / torch.sqrt(torch.diag(res).unsqueeze(1) @ torch.diag(res).unsqueeze(0)) + return torch.corrcoef(data.T) @torch.jit.script @@ -102,7 +100,7 @@ def compute_3b_correlations( res = _3b_batched( centered_data=centered_data, weights=weights.unsqueeze(1), - batcu_size=batch_size, + batch_size=batch_size, ) if full_mat: res = _3b_full_mat(res) diff --git a/rbms/custom_fn.py b/rbms/custom_fn.py index d721ce6..eabb39e 100644 --- a/rbms/custom_fn.py +++ b/rbms/custom_fn.py @@ -39,3 +39,11 @@ def log2cosh(x: Tensor) -> Tensor: Tensor: Output tensor. """ return torch.abs(x) + torch.log1p(torch.exp(-2 * torch.abs(x))) + + +def check_keys_dict(d: dict, names: list[str]): + for k in names: + if k not in d.keys(): + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {d.keys()}\n Expected keys: {names}""" + ) diff --git a/rbms/dataset/__init__.py b/rbms/dataset/__init__.py index 86324d4..5191420 100644 --- a/rbms/dataset/__init__.py +++ b/rbms/dataset/__init__.py @@ -15,7 +15,8 @@ def load_dataset( subset_labels: list[int] | None = None, use_weights: bool = False, alphabet="protein", - device: str = "cpu", + remove_duplicates: bool = False, + device: torch.device | str = "cpu", dtype: torch.dtype = torch.float32, ) -> tuple[RBMDataset, RBMDataset | None]: return_datasets = [] @@ -45,7 +46,7 @@ def load_dataset( variable_type = "categorical" # Select subset of dataset w.r.t. 
labels if subset_labels is not None and labels is not None: - data, labels = get_subset_labels(data, labels, subset_labels) + data, labels = get_subset_labels(data, labels, np.asarray(subset_labels)) if weights is None: weights = np.ones(data.shape[0]) @@ -54,10 +55,15 @@ def load_dataset( if labels is None: labels = -np.ones(data.shape[0]) - # Remove duplicates and internally shuffle the dataset - unique_ind = get_unique_indices(torch.from_numpy(data)).cpu().numpy() + if remove_duplicates: + # Remove duplicates and internally shuffle the dataset + unique_ind = get_unique_indices(torch.from_numpy(data)).cpu().numpy() + else: + unique_ind = np.arange(data.shape[0]) idx = torch.randperm(unique_ind.shape[0]) + if unique_ind.shape[0] < data.shape[0]: + print(f"N_samples: {data.shape[0]} -> {unique_ind.shape[0]}") data = data[unique_ind[idx]] labels = labels[unique_ind[idx]] weights = weights[unique_ind[idx]] diff --git a/rbms/dataset/dataset_class.py b/rbms/dataset/dataset_class.py index f0c678c..d73aa69 100644 --- a/rbms/dataset/dataset_class.py +++ b/rbms/dataset/dataset_class.py @@ -1,12 +1,15 @@ +from __future__ import annotations import gzip import textwrap -from typing import Self, Union -from rbms.dataset.utils import convert_data +from typing import Union import numpy as np import torch +from torch import Tensor from torch.utils.data import Dataset -from tqdm import tqdm +from tqdm.autonotebook import tqdm + +from rbms.dataset.utils import convert_data class RBMDataset(Dataset): @@ -20,7 +23,7 @@ def __init__( names: np.ndarray, dataset_name: str, variable_type: str, - device: str = "cuda", + device: torch.device | str = "cuda", dtype: torch.dtype = torch.float32, ) -> None: # names should stay as a np array as its dtype is object @@ -117,22 +120,30 @@ def get_gzip_entropy(self, mean_size: int = 50, num_samples: int = 100): for i in pbar: en[i] = len( gzip.compress( - (self.data[torch.randperm(self.data.shape[0])[:num_samples]]).astype( - int - ) + 
(self.data[torch.randperm(self.data.shape[0])[:num_samples]]) + .cpu() + .numpy() + .astype(int) ) ) return np.mean(en) def match_model_variable_type(self, visible_type: str): self.data = convert_data[self.variable_type][visible_type](self.data) + if self.variable_type != visible_type: + print(f"Converting from '{self.variable_type}' to '{visible_type}'") + print(self.data) + self.variable_type = visible_type + + def astype(self, target_variable_type: str): + return convert_data[self.variable_type][target_variable_type](self.data) def split_train_test( self, rng: np.random.Generator, train_size: float, test_size: float | None = None, - ) -> tuple[Self, Self | None]: + ) -> tuple[RBMDataset, RBMDataset]: num_samples = self.data.shape[0] if test_size is None: test_size = 1.0 - train_size @@ -172,4 +183,21 @@ def split_train_test( device=self.device, dtype=self.dtype, ) + else: + raise ValueError("Could not split in train test") return train_dataset, test_dataset + + def batch(self, batch_size: int) -> dict[str, Tensor]: + rand_idx = torch.randperm(len(self))[:batch_size] + sampled_batch: dict[str, Tensor] = { + "data": self.data[rand_idx], + "weights": self.weights[rand_idx], + "labels": self.labels[rand_idx], + } + # sampled_batch = self[rand_idx[:batch_size]] + match self.variable_type: + case "bernoulli": + sampled_batch["data"] = torch.bernoulli(sampled_batch["data"]) + case _: + pass + return sampled_batch diff --git a/rbms/dataset/fasta_utils.py b/rbms/dataset/fasta_utils.py index 1123e2f..2f17295 100644 --- a/rbms/dataset/fasta_utils.py +++ b/rbms/dataset/fasta_utils.py @@ -4,14 +4,14 @@ import numpy as np import torch -ArrayLike = tuple[np.ndarray, list] +# ArrayLike = tuple[np.ndarray, list] TOKENS_PROTEIN = "-ACDEFGHIKLMNPQRSTVWY" TOKENS_RNA = "-ACGU" TOKENS_DNA = "-ACGT" -def get_tokens(alphabet: str): +def get_tokens(alphabet: str) -> str: """Load the vocabulary associated to the alphabet type. 
Args: alphabet (str): alphabet type (one of 'protein', 'rna', 'dna'). @@ -44,7 +44,7 @@ def encode_sequence(sequence: str, tokens: str) -> np.ndarray: return np.array([letter_map[letter] for letter in sequence]) -def decode_sequence(sequence: ArrayLike, tokens: str) -> str: +def decode_sequence(sequence: np.ndarray, tokens: str) -> str: """Takes a numeric sequence in input an returns the string encoding. Args: @@ -98,8 +98,8 @@ def import_from_fasta(fasta_name: Union[str, Path]) -> tuple[np.ndarray, np.ndar def write_fasta( fname: str, - headers: ArrayLike, - sequences: ArrayLike, + headers: np.ndarray, + sequences: np.ndarray, numeric_input: bool = False, remove_gaps: bool = False, alphabet: str = "protein", @@ -137,7 +137,7 @@ def write_fasta( def compute_weights( - data: ArrayLike, th: float = 0.8, device: torch.device = "cpu" + data: np.ndarray, th: float = 0.8, device: torch.device | str = "cpu" ) -> np.ndarray: """Computes the weight to be assigned to each sequence 's' in 'data' as 1 / n_clust, where 'n_clust' is the number of sequences that have a sequence identity with 's' >= th. @@ -151,20 +151,20 @@ def compute_weights( np.ndarray: Array with the weights of the sequences. 
""" device = torch.device(device) - data = torch.tensor(data, device=device) - assert len(data.shape) == 2, "'data' must be a 2-dimensional array" - _, L = data.shape + data_tensor = torch.from_numpy(data).to(device=device) + assert len(data_tensor) == 2, "'data' must be a 2-dimensional array" + _, L = data_tensor.shape def get_sequence_weight(s: torch.Tensor, data: torch.Tensor, L: int, th: float): seq_id = torch.sum(s == data, dim=1) / L n_clust = torch.sum(seq_id >= th) return 1.0 / n_clust - weights = torch.vstack([get_sequence_weight(s, data, L, th) for s in data]) + weights = torch.vstack([get_sequence_weight(s, data_tensor, L, th) for s in data]) return weights.cpu().numpy() -def validate_alphabet(sequences: ArrayLike, tokens: str): +def validate_alphabet(sequences: np.ndarray, tokens: str): all_char = "".join(sequences) tokens_data = "".join(sorted(set(all_char))) for c in tokens_data: diff --git a/rbms/dataset/load_fasta.py b/rbms/dataset/load_fasta.py index d5794bd..7bfab0b 100644 --- a/rbms/dataset/load_fasta.py +++ b/rbms/dataset/load_fasta.py @@ -1,7 +1,7 @@ from pathlib import Path import numpy as np -# import torch +import torch # from rbms.custom_fn import one_hot from rbms.dataset.fasta_utils import ( @@ -17,7 +17,7 @@ def load_FASTA( filename: str | Path, use_weights: bool = False, alphabet: str = "protein", - device="cuda", + device: torch.device | str = "cuda", ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Load a dataset from a FASTA file. 
diff --git a/rbms/dataset/load_h5.py b/rbms/dataset/load_h5.py index 954edd6..9d26444 100644 --- a/rbms/dataset/load_h5.py +++ b/rbms/dataset/load_h5.py @@ -2,14 +2,16 @@ import h5py import numpy as np +import torch + from rbms.dataset.fasta_utils import compute_weights def load_HDF5( filename: str | Path, use_weights: bool = False, - device: str = "cuda", -) -> tuple[np.ndarray, np.ndarray | None, str]: + device: torch.device | str = "cuda", +) -> tuple[np.ndarray, np.ndarray | None, str, np.ndarray]: """Load a dataset from an HDF5 file. Args: @@ -19,7 +21,7 @@ def load_HDF5( Tuple[np.ndarray, np.ndarray]: The dataset and labels. """ labels = None - variable_type = "binary" + variable_type = "bernoulli" with h5py.File(filename, "r") as f: if "samples" not in f.keys(): raise ValueError( @@ -28,10 +30,10 @@ def load_HDF5( dataset = np.array(f["samples"][()]) if "variable_type" not in f.keys(): print( - f"No variable_type found in the hdf5 file keys: {f.keys()}. Assuming 'binary'." + f"No variable_type found in the hdf5 file keys: {f.keys()}. Assuming 'bernoulli'." ) print( - "Set a 'variable_type' with value 'binary', 'categorical' or 'continuous' in the hdf5 archive to remove this message" + "Set a 'variable_type' with value 'bernoulli', 'ising', 'categorical' or 'continuous' in the hdf5 archive to remove this message" ) else: variable_type = f["variable_type"][()].decode() diff --git a/rbms/dataset/parser.py b/rbms/dataset/parser.py index 1344392..0537576 100644 --- a/rbms/dataset/parser.py +++ b/rbms/dataset/parser.py @@ -51,6 +51,12 @@ def add_args_dataset(parser: argparse.ArgumentParser) -> argparse.ArgumentParser default="protein", help="(Defaults to protein). Type of encoding for the sequences. 
Choose among ['protein', 'rna', 'dna'] or a user-defined string of tokens.", ) + dataset_args.add_argument( + "--remove_duplicates", + default=False, + action="store_true", + help="Remove duplicates from the dataset before splitting.", + ) dataset_args.add_argument( "--seed", default=None, diff --git a/rbms/dataset/utils.py b/rbms/dataset/utils.py index e8d7f64..98f4faf 100644 --- a/rbms/dataset/utils.py +++ b/rbms/dataset/utils.py @@ -1,7 +1,10 @@ +from collections.abc import Callable import numpy as np import torch from torch import Tensor +from rbms.custom_fn import one_hot + def get_subset_labels( data: np.ndarray, labels: np.ndarray, subset_labels: np.ndarray @@ -47,29 +50,73 @@ def ising_to_bernoulli(x): return (x + 1) / 2 -def bernoulli_to_categorical(x): - pass - - def categorical_to_bernoulli(x): - pass + return one_hot(x.long()).reshape(x.shape[0], -1) -convert_data = { +convert_data: dict[str, dict[str, Callable[[Tensor], Tensor]]] = { "bernoulli": { "bernoulli": (lambda x: x), "ising": (lambda x: bernoulli_to_ising(x)), - "categorical": (lambda x: bernoulli_to_categorical(x)), - # "continuous": lambda x: raise ValueError("Cannot convert from 'bernoulli' to 'continuous' data.") + "categorical": (lambda x: x), + "continuous": (lambda x: x), }, "ising": { "bernoulli": (lambda x: ising_to_bernoulli(x)), "ising": (lambda x: x), - "categorical": (lambda x: bernoulli_to_categorical(ising_to_bernoulli(x))), + "categorical": (lambda x: ising_to_bernoulli(x)), }, "categorical": { "bernoulli": (lambda x: categorical_to_bernoulli(x)), "ising": (lambda x: bernoulli_to_ising(categorical_to_bernoulli(x))), "categorical": (lambda x: x), }, + "continuous": {"bernoulli": (lambda x: x), "continuous": (lambda x: x)}, } + + +def get_covariance_matrix( + data: Tensor, + weights: Tensor | None = None, + num_extract: int | None = None, + center: bool = True, + device: torch.device = torch.device("cpu"), +) -> Tensor: + """Returns the covariance matrix of the data. 
If weights is specified, the weighted covariance matrix is computed. + + Args: + data (Tensor): Data. + weights (Tensor, optional): Weights of the data. Defaults to None. + num_extract (int, optional): Number of data to extract to compute the covariance matrix. Defaults to None. + center (bool): Center the data. Defaults to True. + device (torch.device): Device. Defaults to 'cpu'. + dtype (torch.dtype): DType. Defaults to torch.float32. + + Returns: + Tensor: Covariance matrix of the dataset. + """ + num_data = len(data) + num_classes = int(data.max().item() + 1) + + if weights is None: + weights = torch.ones(num_data) + weights = weights.to(device=device, dtype=torch.float32) + + if num_extract is not None: + idxs = np.random.choice(a=np.arange(num_data), size=(num_extract,), replace=False) + data = data[idxs] + weights = weights[idxs] + num_data = num_extract + + if num_classes != 2: + data = data.to(device=device, dtype=torch.int32) + data_oh = one_hot(data, num_classes=num_classes).reshape(num_data, -1) + else: + data_oh = data.to(device=device, dtype=torch.float32) + + norm_weights = weights.reshape(-1, 1) / weights.sum() + data_mean = (data_oh * norm_weights).sum(0, keepdim=True) + cov_matrix = ((data_oh * norm_weights).mT @ data_oh) - int(center) * ( + data_mean.mT @ data_mean + ) + return cov_matrix diff --git a/rbms/io.py b/rbms/io.py index b92957f..24a4b1c 100644 --- a/rbms/io.py +++ b/rbms/io.py @@ -3,17 +3,19 @@ import torch from torch import Tensor -from rbms.classes import EBM +from rbms.classes import EBM, Sampler from rbms.map_model import map_model from rbms.utils import restore_rng_state +@torch.compiler.disable def save_model( filename: str, params: EBM, chains: dict[str, Tensor], num_updates: int, time: float, + learning_rate: Tensor, flags: list[str] = [], ) -> None: """Save the current state of the model. 
@@ -35,7 +37,7 @@ def save_model( # Save the parameters of the model params_ckpt = checkpoint.create_group("params") for n, p in named_params.items(): - params_ckpt[n] = p.detach().cpu().numpy() + params_ckpt[n] = p # This is for retrocompatibility purpose checkpoint[n] = params_ckpt[n] # Save current random state @@ -46,7 +48,7 @@ def save_model( checkpoint["numpy_rng_arg3"] = np.random.get_state()[3] checkpoint["numpy_rng_arg4"] = np.random.get_state()[4] checkpoint["time"] = time - + checkpoint["learning_rate"] = learning_rate.cpu().numpy() # Update the parallel chains to resume training if "parallel_chains" in f.keys(): f["parallel_chains"][...] = chains["visible"].cpu().numpy() @@ -65,9 +67,9 @@ def save_model( def load_params( filename: str, index: int, - device: torch.device, + device: torch.device | str, dtype: torch.dtype, - map_model: dict[str, EBM] = map_model, + map_model: dict[str, type[EBM]] = map_model, ) -> EBM: """Load the parameters of the RBM from the specified archive at the given update index. @@ -84,21 +86,19 @@ def load_params( params = {} with h5py.File(filename, "r") as f: for k in f[last_file_key]["params"].keys(): - params[k] = torch.from_numpy(f[last_file_key]["params"][k][()]).to( - device=device, dtype=dtype - ) + params[k] = f[last_file_key]["params"][k][()] model_type = f["model_type"][()].decode() - return map_model[model_type].set_named_parameters(params) + return map_model[model_type].set_named_parameters(params, device=device, dtype=dtype) def load_model( filename: str, index: int, - device: torch.device, + device: torch.device | str, dtype: torch.dtype, restore: bool = False, - map_model: dict[str, EBM] = map_model, -) -> tuple[EBM, dict[str, Tensor], float, dict]: + map_model: dict[str, type[EBM]] = map_model, +) -> tuple[EBM, dict[str, Tensor], float]: """Load a RBM from a h5 archive. 
Args: @@ -111,10 +111,9 @@ def load_model( Returns: Tuple[EBM, dict[str, Tensor], float, dict]: A tuple containing the loaded RBM parameters, - the parallel chains, the time taken, and the model's hyperparameters. + the parallel chains and the time taken """ last_file_key = f"update_{index}" - hyperparameters = dict() with h5py.File(filename, "r") as f: visible = torch.from_numpy(f["parallel_chains"][()]).to( device=device, dtype=dtype @@ -122,21 +121,6 @@ def load_model( # Elapsed time start = np.array(f[last_file_key]["time"]).item() - # Hyperparameters - if "hyperparameters" in f.keys(): - hyperparameters["batch_size"] = int(f["hyperparameters"]["batch_size"][()]) - hyperparameters["gibbs_steps"] = int(f["hyperparameters"]["gibbs_steps"][()]) - hyperparameters["learning_rate"] = float( - f["hyperparameters"]["learning_rate"][()] - ) - hyperparameters["L1"] = float(f["hyperparameters"]["L1"][()]) - hyperparameters["L2"] = float(f["hyperparameters"]["L2"][()]) - if "seed" in f["hyperparameters"].keys(): - hyperparameters["seed"] = int(f["hyperparameters"]["seed"][()]) - if "train_size" in f["hyperparameters"].keys(): - hyperparameters["train_size"] = float( - f["hyperparameters"]["train_size"][()] - ) params = load_params( filename=filename, index=index, device=device, dtype=dtype, map_model=map_model ) @@ -144,4 +128,25 @@ def load_model( if restore: restore_rng_state(filename=filename, index=index) - return (params, perm_chains, start, hyperparameters) + return (params, perm_chains, start) + + +def save_sampler(filename: str, sampler: Sampler, update: int): + named_params = sampler.named_parameters() + metrics = sampler.get_metrics_save() + name = sampler.name + with h5py.File(filename, "a") as f: + if "sampler" not in f.keys(): + f.create_group("sampler") + f["sampler"]["name"] = name + + # Save the parameters of the model + for n, p in named_params.items(): + if n in f["sampler"].keys(): + f["sampler"][n][...] 
= p + else: + f["sampler"][n] = p + + if metrics is not None: + for n, p in metrics.items(): + f[f"update_{update}"][n] = p diff --git a/rbms/ising_gaussian/classes.py b/rbms/ising_gaussian/classes.py index 3ba5042..93c3a2c 100644 --- a/rbms/ising_gaussian/classes.py +++ b/rbms/ising_gaussian/classes.py @@ -1,9 +1,13 @@ -from typing import List, Optional +from __future__ import annotations + +from typing import List import numpy as np import torch from torch import Tensor +from rbms.classes import RBM +from rbms.custom_fn import check_keys_dict, log2cosh from rbms.ising_gaussian.implement import ( _compute_energy, _compute_energy_hiddens, @@ -14,26 +18,27 @@ _sample_hiddens, _sample_visibles, ) -from rbms.classes import RBM class IGRBM(RBM): - """Ising-Gaussian RBM with fixed hidden variance = 1/Nv, \pm 1 visibles, without any bias""" + """Ising-Gaussian RBM with fixed hidden variance = 1/Nv, +- 1 visibles, without any bias""" + + visible_type: str = "ising" def __init__( self, weight_matrix: Tensor, vbias: Tensor, hbias: Tensor, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, ): if device is None: device = weight_matrix.device if dtype is None: dtype = weight_matrix.dtype - self.device, self.dtype = device, dtype + self.device, self.dtype = device, dtype self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) self.vbias = vbias.to(device=self.device, dtype=self.dtype) self.hbias = hbias.to(device=self.device, dtype=self.dtype) @@ -41,39 +46,42 @@ def __init__( log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=dtype, device=device)) const = ( 0.5 - * float(self.weight_matrix[1]) + * float(self.weight_matrix.shape[1]) * ( - torch.log( - torch.tensor(float(self.weight_matrix[0]), dtype=dtype, device=device) + -torch.log( + torch.tensor( + float(self.weight_matrix.shape[0]), dtype=dtype, device=device + ) ) - - log_two_pi + + 
log_two_pi ) ) self.const = const self.name = "IGRBM" + self.flags = [] def __add__(self, other): - out = IGRBM( + return IGRBM( weight_matrix=self.weight_matrix + other.weight_matrix, vbias=self.vbias + other.vbias, hbias=self.hbias + other.hbias, device=self.device, dtype=self.dtype, ) - return out def __mul__(self, other): - out = IGRBM( + return IGRBM( weight_matrix=self.weight_matrix * other, vbias=self.vbias * other, hbias=self.hbias * other, device=self.device, dtype=self.dtype, ) - return out def clone( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, ): if device is None: device = self.device @@ -99,13 +107,17 @@ def compute_energy_hiddens(self, h: Tensor) -> Tensor: def compute_energy_visibles(self, v: Tensor) -> Tensor: return _compute_energy_visibles( - v=v, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + v=v, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + const=self.const, ) - def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): + def compute_gradient(self, data, chains, centered=True): _compute_gradient( v_data=data["visible"], - h_data=data["hidden_mag"], + mh_data=data["hidden_mag"], w_data=data["weights"], v_chain=chains["visible"], h_chain=chains["hidden_mag"], @@ -114,15 +126,13 @@ def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2 hbias=self.hbias, weight_matrix=self.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) def independent_model(self): return IGRBM( weight_matrix=torch.zeros_like(self.weight_matrix), vbias=self.vbias, - hbias=torch.zeros_like(self.hbias), + hbias=self.hbias, # torch.zeros_like(self.hbias), device=self.device, dtype=self.dtype, ) @@ -168,25 +178,31 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): def named_parameters(self): 
return { - "weight_matrix": self.weight_matrix, - "vbias": self.vbias, - "hbias": self.hbias, + "weight_matrix": self.weight_matrix.cpu().numpy(), + "vbias": self.vbias.cpu().numpy(), + "hbias": self.hbias.cpu().numpy(), } + @property def num_hiddens(self): return self.hbias.shape[0] + @property def num_visibles(self): return self.vbias.shape[0] def parameters(self) -> List[Tensor]: return [self.weight_matrix, self.vbias, self.hbias] + @property def ref_log_z(self): - K = self.num_hiddens() - logZ_v = torch.log1p(torch.exp(self.vbias)).sum() - quad = 0.5 * torch.dot(self.hbias, self.hbias) / float(self.num_visibles()) - log_norm = 0.5 * K * np.log(2.0 * np.pi) - 0.5 * K * np.log(float(self.num_visibles())) + K = self.num_hiddens + # logZ_v = torch.log1p(torch.exp(self.vbias)).sum() + logZ_v = log2cosh(self.vbias).sum() + quad = 0.5 * torch.dot(self.hbias, self.hbias) / float(self.num_visibles) + log_norm = 0.5 * K * np.log(2.0 * np.pi) - 0.5 * K * np.log( + float(self.num_visibles) + ) return (logZ_v + quad + log_norm).item() def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: @@ -208,7 +224,11 @@ def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor return chains @staticmethod - def set_named_parameters(named_params: dict[str, Tensor]) -> "IGRBM": + def set_named_parameters( + named_params: dict[str, np.ndarray], + device: torch.device | str, + dtype: torch.dtype, + ) -> IGRBM: names = ["vbias", "hbias", "weight_matrix"] for k in names: if k not in named_params: @@ -216,9 +236,15 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> "IGRBM": f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" ) params = IGRBM( - weight_matrix=named_params.pop("weight_matrix"), - vbias=named_params.pop("vbias"), - hbias=named_params.pop("hbias"), + weight_matrix=torch.from_numpy(named_params.pop("weight_matrix")).to( + device=device, dtype=dtype + ), + 
vbias=torch.from_numpy(named_params.pop("vbias")).to( + device=device, dtype=dtype + ), + hbias=torch.from_numpy(named_params.pop("hbias")).to( + device=device, dtype=dtype + ), ) if len(named_params) > 0: raise ValueError( @@ -227,7 +253,9 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> "IGRBM": return params def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, ) -> "IGRBM": if device is not None: self.device = device @@ -237,3 +265,12 @@ def to( self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) return self + + def get_metrics(self, metrics): + return metrics + + def post_grad_update(self): + pass + + def pre_grad_update(self): + pass diff --git a/rbms/ising_gaussian/functional.py b/rbms/ising_gaussian/functional.py index d6be08b..8a7b3d6 100644 --- a/rbms/ising_gaussian/functional.py +++ b/rbms/ising_gaussian/functional.py @@ -58,6 +58,7 @@ def compute_energy_visibles(v: Tensor, params: IGRBM) -> Tensor: vbias=params.vbias, hbias=params.hbias, weight_matrix=params.weight_matrix, + const=params.const, ) @@ -75,8 +76,6 @@ def compute_gradient( chains: dict[str, Tensor], params: IGRBM, centered: bool, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: _compute_gradient( v_data=data["visible"], @@ -89,8 +88,6 @@ def compute_gradient( hbias=params.hbias, weight_matrix=params.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) diff --git a/rbms/ising_gaussian/implement.py b/rbms/ising_gaussian/implement.py index 0d440b1..2f208aa 100644 --- a/rbms/ising_gaussian/implement.py +++ b/rbms/ising_gaussian/implement.py @@ -3,18 +3,20 @@ import torch from torch import Tensor from torch.nn.functional import softmax +from rbms.custom_fn import log2cosh -@torch.jit.script def _sample_hiddens( v: Tensor, weight_matrix: 
Tensor, hbias: Tensor, beta: float = 1.0 ) -> Tuple[Tensor, Tensor]: mh = hbias + (v @ weight_matrix) - h = torch.randn_like(mh) + mh + h = ( + torch.randn_like(mh) / torch.sqrt(torch.ones_like(mh) * weight_matrix.shape[0]) + + mh + ) return h, mh -@torch.jit.script def _sample_visibles( h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 ) -> Tuple[Tensor, Tensor]: @@ -24,7 +26,6 @@ def _sample_visibles( return v, mv -@torch.jit.script def _compute_energy( v: Tensor, h: Tensor, @@ -42,7 +43,6 @@ def _compute_energy( return -fields - interaction + quad -@torch.jit.script def _compute_energy_visibles( v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor, const: Tensor ) -> Tensor: @@ -52,21 +52,20 @@ def _compute_energy_visibles( return -field - quad_term + const -@torch.jit.script def _compute_energy_hiddens( h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor ) -> Tensor: field = h @ hbias exponent = vbias + (h @ weight_matrix.T) - log_term = torch.where(exponent < 10, torch.log1p(torch.exp(exponent)), exponent) + # log_term = torch.where(exponent < 10, torch.log1p(torch.exp(exponent)), exponent) + log_term = log2cosh(exponent) quad = 0.5 * float(weight_matrix.shape[0]) * (h * h).sum(1) return -field - log_term.sum(1) + quad -@torch.jit.script def _compute_gradient( v_data: Tensor, - h_data: Tensor, + mh_data: Tensor, w_data: Tensor, v_chain: Tensor, h_chain: Tensor, @@ -75,8 +74,6 @@ def _compute_gradient( hbias: Tensor, weight_matrix: Tensor, centered: bool, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) @@ -85,13 +82,13 @@ def _compute_gradient( v_data_mean = (v_data * w_data).sum(0) / w_data_norm torch.clamp_(v_data_mean, min=1e-4, max=(1.0 - 1e-4)) - h_data_mean = (h_data * w_data).sum(0) / w_data_norm + h_data_mean = (mh_data * w_data).sum(0) / w_data_norm v_gen_mean = v_chain.mean(0) torch.clamp_(v_gen_mean, min=1e-4, max=(1.0 - 1e-4)) if centered: 
v_data_centered = v_data - v_data_mean - h_data_centered = h_data - h_data_mean + h_data_centered = mh_data - h_data_mean v_gen_centered = v_chain - v_data_mean h_gen_centered = h_chain - h_data_mean @@ -106,11 +103,11 @@ def _compute_gradient( ) # No training on biases else: v_data_centered = v_data - h_data_centered = h_data + h_data_centered = mh_data v_gen_centered = v_chain h_gen_centered = h_chain - grad_weight_matrix = ((v_data * w_data).T @ h_data) / w_data_norm - ( + grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - ( (v_chain * chain_weights).T @ h_chain ) @@ -121,22 +118,11 @@ def _compute_gradient( hbias.shape[0], device=hbias.device, dtype=hbias.dtype ) # No training on biases - if lambda_l1 > 0: - grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) - grad_vbias -= lambda_l1 * torch.sign(vbias) - grad_hbias -= lambda_l1 * torch.sign(hbias) - - if lambda_l2 > 0: - grad_weight_matrix -= 2 * lambda_l2 * weight_matrix - grad_vbias -= 2 * lambda_l2 * vbias - grad_hbias -= 2 * lambda_l2 * hbias + weight_matrix.grad = grad_weight_matrix + vbias.grad = grad_vbias + hbias.grad = grad_hbias - weight_matrix.grad.set_(grad_weight_matrix) - vbias.grad.set_(grad_vbias) - hbias.grad.set_(grad_hbias) - -@torch.jit.script def _init_chains( num_samples: int, weight_matrix: Tensor, @@ -152,7 +138,12 @@ def _init_chains( raise ValueError(f"Got negative num_samples arg: {num_samples}") if start_v is None: - mv = torch.ones(size=(num_samples, weight_matrix.shape[0]), device=device, dtype=dtype) / 2 + mv = ( + torch.ones( + size=(num_samples, weight_matrix.shape[0]), device=device, dtype=dtype + ) + / 2 + ) v = torch.bernoulli(mv) * 2 - 1 else: mv = torch.zeros_like(start_v, device=device, dtype=dtype) diff --git a/rbms/ising_ising/__init__.py b/rbms/ising_ising/__init__.py index d467f56..39fa13e 100644 --- a/rbms/ising_ising/__init__.py +++ b/rbms/ising_ising/__init__.py @@ -1,3 +1,12 @@ # ruff: noqa from rbms.ising_ising.classes import IIRBM 
-from rbms.ising_ising.functional import * +from rbms.ising_ising.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) diff --git a/rbms/ising_ising/classes.py b/rbms/ising_ising/classes.py index 068ceb0..dce007a 100644 --- a/rbms/ising_ising/classes.py +++ b/rbms/ising_ising/classes.py @@ -1,11 +1,11 @@ -from typing import Self +from __future__ import annotations import numpy as np import torch from torch import Tensor from rbms.classes import RBM -from rbms.custom_fn import log2cosh +from rbms.custom_fn import check_keys_dict, log2cosh from rbms.ising_ising.implement import ( _compute_energy, _compute_energy_hiddens, @@ -21,12 +21,14 @@ class IIRBM(RBM): """Parameters of the Ising-Ising RBM""" + visible_type: str = "ising" + def __init__( self, weight_matrix: Tensor, vbias: Tensor, hbias: Tensor, - device: torch.device | None = None, + device: torch.device | str | None = None, dtype: torch.dtype | None = None, ): """Initialize the parameters of the Ising-Ising RBM. 
@@ -50,6 +52,7 @@ def __init__( self.vbias = vbias.to(device=self.device, dtype=self.dtype) self.hbias = hbias.to(device=self.device, dtype=self.dtype) self.name = "IIRBM" + self.flags = [] def __add__(self, other): return IIRBM( @@ -65,7 +68,9 @@ def __mul__(self, other): hbias=self.hbias * other, ) - def clone(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + def clone( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): if device is None: device = self.device if dtype is None: @@ -103,7 +108,7 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor: weight_matrix=self.weight_matrix, ) - def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): + def compute_gradient(self, data, chains, centered=True): _compute_gradient( v_data=data["visible"], mh_data=data["hidden_mag"], @@ -115,8 +120,6 @@ def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2 hbias=self.hbias, weight_matrix=self.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) def independent_model(self): @@ -162,20 +165,23 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): def named_parameters(self): return { - "weight_matrix": self.weight_matrix, - "vbias": self.vbias, - "hbias": self.hbias, + "weight_matrix": self.weight_matrix.cpu().numpy(), + "vbias": self.vbias.cpu().numpy(), + "hbias": self.hbias.cpu().numpy(), } + @property def num_hiddens(self): return self.hbias.shape[0] + @property def num_visibles(self): return self.vbias.shape[0] def parameters(self) -> list[Tensor]: return [self.weight_matrix, self.vbias, self.hbias] + @property def ref_log_z(self): return (log2cosh(self.vbias).sum() + log2cosh(self.hbias).sum()).item() @@ -198,17 +204,23 @@ def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor return chains @staticmethod - def set_named_parameters(named_params: dict[str, Tensor]) -> Self: + 
def set_named_parameters( + named_params: dict[str, np.ndarray], + device: torch.device | str, + dtype: torch.dtype, + ) -> IIRBM: names = ["vbias", "hbias", "weight_matrix"] - for k in names: - if k not in named_params.keys(): - raise ValueError( - f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" - ) + check_keys_dict(d=named_params, names=names) params = IIRBM( - weight_matrix=named_params.pop("weight_matrix"), - vbias=named_params.pop("vbias"), - hbias=named_params.pop("hbias"), + weight_matrix=torch.from_numpy(named_params.pop("weight_matrix")).to( + device=device, dtype=dtype + ), + vbias=torch.from_numpy(named_params.pop("vbias")).to( + device=device, dtype=dtype + ), + hbias=torch.from_numpy(named_params.pop("hbias")).to( + device=device, dtype=dtype + ), ) if len(named_params.keys()) > 0: raise ValueError( @@ -216,7 +228,9 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> Self: ) return params - def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): if device is not None: self.device = device if dtype is not None: @@ -225,3 +239,12 @@ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = Non self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) return self + + def get_metrics(self, metrics): + return metrics + + def post_grad_update(self): + pass + + def pre_grad_update(self): + pass diff --git a/rbms/ising_ising/functional.py b/rbms/ising_ising/functional.py index 68a4794..886952c 100644 --- a/rbms/ising_ising/functional.py +++ b/rbms/ising_ising/functional.py @@ -116,8 +116,6 @@ def compute_gradient( chains: dict[str, Tensor], params: IIRBM, centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: """Compute the gradient for each of 
the parameters and attach it. @@ -140,8 +138,6 @@ def compute_gradient( hbias=params.hbias, weight_matrix=params.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) diff --git a/rbms/ising_ising/implement.py b/rbms/ising_ising/implement.py index 9e0c5e4..72bfcc4 100644 --- a/rbms/ising_ising/implement.py +++ b/rbms/ising_ising/implement.py @@ -1,11 +1,9 @@ import torch from torch import Tensor -from torch.nn.functional import softmax from rbms.custom_fn import log2cosh -@torch.jit.script def _sample_hiddens( v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 ) -> tuple[Tensor, Tensor]: @@ -15,7 +13,6 @@ def _sample_hiddens( return h, mh -@torch.jit.script def _sample_visibles( h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 ) -> tuple[Tensor, Tensor]: @@ -25,7 +22,6 @@ def _sample_visibles( return v, mv -@torch.jit.script def _compute_energy( v: Tensor, h: Tensor, @@ -43,7 +39,6 @@ def _compute_energy( return -fields - interaction -@torch.jit.script def _compute_energy_visibles( v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor ) -> Tensor: @@ -53,7 +48,6 @@ def _compute_energy_visibles( return -field - log_term.sum(1) -@torch.jit.script def _compute_energy_hiddens( h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor ) -> Tensor: @@ -63,7 +57,6 @@ def _compute_energy_hiddens( return -field - log_term.sum(1) -@torch.jit.script def _compute_gradient( v_data: Tensor, mh_data: Tensor, @@ -75,13 +68,11 @@ def _compute_gradient( hbias: Tensor, weight_matrix: Tensor, centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) # Turn the weights of the chains into normalized weights - chain_weights = softmax(-w_chain, dim=0) + chain_weights = w_chain / w_chain.sum() w_data_norm = w_data.sum() # Averages over data and generated samples @@ -106,11 +97,6 @@ def _compute_gradient( grad_vbias = v_data_mean - 
v_gen_mean - (grad_weight_matrix @ h_data_mean) grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) else: - v_data_centered = v_data - h_data_centered = mh_data - v_gen_centered = v_chain - h_gen_centered = h_chain - # Gradient grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - ( (v_chain * chain_weights).T @ h_chain @@ -118,24 +104,12 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean grad_hbias = h_data_mean - h_gen_mean - if lambda_l1 > 0: - grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) - grad_vbias -= lambda_l1 * torch.sign(vbias) - grad_hbias -= lambda_l1 * torch.sign(hbias) - - if lambda_l2 > 0: - grad_weight_matrix -= 2 * lambda_l2 * weight_matrix - grad_vbias -= 2 * lambda_l2 * vbias - grad_hbias -= 2 * lambda_l2 * hbias - # Attach to the parameters - - weight_matrix.grad.set_(grad_weight_matrix) - vbias.grad.set_(grad_vbias) - hbias.grad.set_(grad_hbias) + weight_matrix.grad = grad_weight_matrix + vbias.grad = grad_vbias + hbias.grad = grad_hbias -@torch.jit.script def _init_chains( num_samples: int, weight_matrix: Tensor, diff --git a/rbms/map_model.py b/rbms/map_model.py index c2a236b..f339dfd 100644 --- a/rbms/map_model.py +++ b/rbms/map_model.py @@ -1,7 +1,14 @@ from rbms.bernoulli_bernoulli.classes import BBRBM +from rbms.bernoulli_gaussian.classes import BGRBM from rbms.classes import EBM -from rbms.potts_bernoulli.classes import PBRBM from rbms.ising_gaussian.classes import IGRBM -from rbms.bernoulli_gaussian.classes import BGRBM +from rbms.ising_ising.classes import IIRBM +from rbms.potts_bernoulli.classes import PBRBM -map_model: dict[str, EBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM, "BGRBM": BGRBM, "IGRBM": IGRBM} +map_model: dict[str, type[EBM]] = { + "BBRBM": BBRBM, + "PBRBM": PBRBM, + "BGRBM": BGRBM, + "IGRBM": IGRBM, + "IIRBM": IIRBM, +} diff --git a/rbms/optim.py b/rbms/optim.py new file mode 100644 index 0000000..fc39f34 --- /dev/null +++ b/rbms/optim.py @@ -0,0 +1,109 @@ 
+import numpy as np +import torch + +# from ptt.optim.cossim import SGD_cossim +from torch import Tensor +from torch.optim import SGD, Optimizer + +from rbms.classes import EBM + + +class SGD_cossim(SGD): + def __init__( + self, + params, + lr=0.001, + max_lr=0.001, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + *, + maximize=True, + foreach=None, + differentiable=False, + fused=None, + ): + super().__init__( + params, + lr, + momentum, + dampening, + weight_decay, + nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + fused=fused, + ) + self.prev_grad = torch.concatenate([p.grad.flatten() for p in params]).flatten() + self.max_lr = max_lr + + def step(self, closure=None): + for group in self.param_groups: + params = group["params"] + learning_rate = group["lr"] + curr_grad = torch.concatenate([p.grad.flatten() for p in params]).flatten() + cosine_similarity = curr_grad @ self.prev_grad + if cosine_similarity > 1e-6: + learning_rate *= 1.002 + elif cosine_similarity < -1e-6: + learning_rate *= 0.998 + group["lr"] = min(self.max_lr, learning_rate) + self.prev_grad = curr_grad.clone() + return super().step(closure) + + +def setup_optim(optim: str, args: dict, params: EBM) -> list[Optimizer]: + match args["optim"]: + case "sgd": + optim_class = SGD + case "cossim": + optim_class = SGD_cossim + case _: + print(f"Unrecognized optimizer {args['optim']}, falling back to SGD.") + optim_class = SGD + learning_rate = args["learning_rate"] + max_lr = args["max_lr"] + if args["scale_lr"]: + learning_rate /= np.sqrt(params.effective_number_variables) + max_lr /= np.sqrt(params.effective_number_variables) + + if args["mult_optim"]: + if not isinstance(learning_rate, Tensor): + learning_rate = torch.tensor([learning_rate] * len(params.parameters())) + optimizer = [ + optim_class( + [p], + lr=learning_rate[i], + maximize=True, + ) + for i, p in enumerate(params.parameters()) + ] + else: + if not isinstance(learning_rate, Tensor): 
+ learning_rate = torch.tensor([learning_rate]) + optimizer = [ + optim_class( + params.parameters(), + lr=learning_rate[0], + maximize=True, + ) + ] + for opt in optimizer: + if isinstance(opt, SGD_cossim): + opt.max_lr = max_lr + + if args["optim"] == "nag": + optimizer = [ + SGD( + opt.param_groups[0]["params"], + lr=opt.param_groups[0]["lr"], + maximize=True, + momentum=0.9, + nesterov=True, + ) + for opt in optimizer + ] + + return optimizer diff --git a/rbms/parser.py b/rbms/parser.py index 320df5d..7354ef9 100644 --- a/rbms/parser.py +++ b/rbms/parser.py @@ -48,18 +48,7 @@ def add_args_saves(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: default=50, help="(Defaults to 50). Number of models to save during the training.", ) - save_args.add_argument( - "--acc_ptt", - type=float, - default=None, - help="(Defaults to 0.25). Minimum PTT acceptance to save configurations for ptt file.", - ) - save_args.add_argument( - "--acc_ll", - type=float, - default=None, - help="(Defaults to 0.7). Minimum PTT acceptance to save configurations for ll file.", - ) + save_args.add_argument( "--spacing", type=str, @@ -79,12 +68,7 @@ def add_args_saves(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser -def add_args_rbm(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add an argument group to the parser for the general hyperparameters of a RBM - - Args: - parser (argparse.ArgumentParser): argparse.ArgumentParser: - """ +def add_args_init_rbm(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: rbm_args = parser.add_argument_group("RBM") rbm_args.add_argument( "--num_hiddens", @@ -93,60 +77,131 @@ def add_args_rbm(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help="(Defaults to 100). Number of hidden units.", ) rbm_args.add_argument( - "--batch_size", + "--num_chains", type=int, default=None, - help="(Defaults to 2000). Minibatch size.", + help="(Defaults to 2000). 
Number of parallel chains.", ) rbm_args.add_argument( + "--model_type", + type=str, + default=None, + help="(Defaults to None). Model to use. If None is provided, will be a RBM with the same visible type as the dataset and binary hiddens. If restore, this argument is ignored.", + ) + return parser + + +def add_sampling_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + sampling_args = parser.add_argument_group("Sampling") + sampling_args.add_argument( "--gibbs_steps", type=int, default=None, help="(Defaults to 100). Number of gibbs steps to perform for each gradient update.", ) - rbm_args.add_argument( - "--learning_rate", + sampling_args.add_argument( + "--beta", + default=None, type=float, + help="(Defaults to 1.0). The inverse temperature of the RBM", + ) + return parser + + +def add_grad_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + grad_args = parser.add_argument_group("Gradient") + grad_args.add_argument( + "--L1", default=None, - help="(Defaults to 0.01). Learning rate.", + type=float, + help="(Defaults to 0.0). Lambda parameter for the L1 regularization.", ) - rbm_args.add_argument( - "--num_chains", + grad_args.add_argument( + "--L2", + default=None, + type=float, + help="(Defaults to 0.0). Lambda parameter for the L2 regularization.", + ) + grad_args.add_argument( + "--no_center", + default=False, + action="store_true", + help="(Defaults to False). Use the non-centered gradient.", + ) + grad_args.add_argument( + "--max_norm_grad", + default=None, + type=float, + help="(Defaults to None). Maximum norm of the gradient before update.", + ) + grad_args.add_argument( + "--normalize_grad", + default=False, + action="store_true", + help="(Defaults to False). 
Normalize the gradient before update.", + ) + return parser + + +def add_args_train(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + train_args = parser.add_argument_group("Train") + train_args.add_argument( + "--batch_size", type=int, default=None, - help="(Defaults to 2000). Number of parallel chains.", + help="(Defaults to 2000). Minibatch size.", ) - rbm_args.add_argument( + train_args.add_argument( + "--learning_rate", + type=float, + default=None, + help="(Defaults to 0.01). Learning rate.", + ) + train_args.add_argument( "--num_updates", default=None, type=int, help="(Defaults to 10 000). Number of gradient updates to perform.", ) - rbm_args.add_argument( - "--beta", + train_args.add_argument( + "--optim", default=None, type=str, help="(Defaults to sgd). Optimizer to use." + ) + train_args.add_argument( + "--mult_optim", + action="store_true", + default=False, + help="(Defaults to False). Use a different optimizer for each param group.", + ) + train_args.add_argument( + "--training_type", + type=str, default=None, + help="(Defaults to 'pcd'). Type of the training, should be one of {'pcd', 'cd', 'rdm'}.", + ) + train_args.add_argument( + "--max_lr", type=float, - help="(Defaults to 1.0). The inverse temperature of the RBM", + default=None, + help="(Defaults to 10). Maximum learning rate when adaptative learning rate is used.", ) - rbm_args.add_argument( - "--restore", - default=False, + train_args.add_argument( + "--scale_lr", action="store_true", - help="(Defaults to False). Restore the training", + default=False, + help="Set it to scale learning rate with the number of variables of the system", ) - rbm_args.add_argument( - "--no_center", + train_args.add_argument( + "--restore", default=False, action="store_true", - help="(Defaults to False). Use the non-centered gradient.", + help="(Defaults to False). Restore the training", ) - rbm_args.add_argument( - "--training_type", - type=str, - default = "pcd", - help="(Defaults to 'pcd'). 
Type of the training, should be one of {'pcd', 'cd', 'rdm'}." + train_args.add_argument( + "--update", + default=None, + type=int, + help="(Defaults to None). Update to restore from, if None the last is selected.", ) - rbm_args.add_argument("--model_type", type=str, default=None, help="(Defaults to None). Model to use. If None is provided, will be a RBM with the same visible type as the dataset and binary hiddens. If restore, this argument is ignored.") return parser @@ -219,6 +274,10 @@ def match_args_dtype(args: dict[str, Any]) -> dict[str, Any]: "no_center": False, "L1": 0.0, "L2": 0.0, + "max_norm_grad": -1, + "optim": "sgd", + "max_lr": 10, + "training_type": "pcd", } diff --git a/rbms/partition_function/ais.py b/rbms/partition_function/ais.py index 89cf494..d75df8e 100644 --- a/rbms/partition_function/ais.py +++ b/rbms/partition_function/ais.py @@ -59,7 +59,7 @@ def compute_partition_function_ais(num_chains: int, num_beta: int, params: EBM) # Compute the reference log partition function ## Here the case where all the weights are 0 - log_z_init = params.ref_log_z() + log_z_init = params.ref_log_z params_ref = params.independent_model() chains = params_ref.init_chains(num_samples=num_chains) diff --git a/rbms/partition_function/exact.py b/rbms/partition_function/exact.py index af383cf..3329276 100644 --- a/rbms/partition_function/exact.py +++ b/rbms/partition_function/exact.py @@ -15,7 +15,7 @@ def compute_partition_function_rbm(params: RBM, all_config: Tensor) -> float: float: Exact log partition function. 
""" n_dim_config = all_config.shape[1] - n_visible, n_hidden = params.num_visibles(), params.num_hiddens() + n_visible, n_hidden = params.num_visibles, params.num_hiddens if n_dim_config == n_hidden: energy = params.compute_energy_hiddens(h=all_config) elif n_dim_config == n_visible: @@ -30,7 +30,7 @@ def compute_partition_function_rbm(params: RBM, all_config: Tensor) -> float: def compute_partition_function(params: EBM, all_config: Tensor) -> float: if isinstance(params, RBM): return compute_partition_function_rbm(params=params, all_config=all_config) - n_visible = params.num_visibles() + n_visible = params.num_visibles n_dim_config = all_config.shape[1] if n_dim_config == n_visible: energy = params.compute_energy_visibles(v=all_config) diff --git a/rbms/plot.py b/rbms/plot.py index e7514c3..9dfc49b 100644 --- a/rbms/plot.py +++ b/rbms/plot.py @@ -151,6 +151,7 @@ def plot_one_PCA( labels: list[str] | None = None, dir1: int = 0, dir2: int = 1, + log_scale: bool = False, ): label_1 = None label_2 = None @@ -181,6 +182,7 @@ def plot_one_PCA( s=size_scat, zorder=0, alpha=0.3, + rasterized=True, ) _, bins_x, _ = ax_hist_x.hist( data1[:, dir1], @@ -214,6 +216,7 @@ def plot_one_PCA( marker="o", alpha=1, linewidth=0.4, + rasterized=True, ) ax_hist_x.hist( data2[:, dir1], @@ -236,6 +239,9 @@ def plot_one_PCA( orientation="horizontal", lw=1, ) + if log_scale: + ax_hist_x.semilogy() + ax_hist_y.semilogx() if labels is not None: ax_hist_x.legend(fontsize=12, bbox_to_anchor=(1, 1)) @@ -245,6 +251,8 @@ def plot_mult_PCA( data2: np.ndarray | None = None, labels: list[str] | None = None, n_dir: int = 2, + figsize_factor=4, + log_scale: bool = False, ): if data2 is not None: if data2.shape[1] < data1.shape[1]: @@ -267,7 +275,9 @@ def plot_mult_PCA( else ((data1.shape[1] // 2) // max_cols) + 1 ) - fig, ax = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows)) + fig, ax = plt.subplots( + n_rows, n_cols, figsize=(figsize_factor * n_cols, figsize_factor * n_rows) + ) for i 
in range(n_rows): for j in range(n_cols): @@ -281,6 +291,7 @@ def plot_mult_PCA( labels=labels if curr_plot_idx == 0 else None, dir1=curr_plot_idx * 2, dir2=curr_plot_idx * 2 + 1, + log_scale=log_scale, ) else: ax[*indexes].set_axis_off() diff --git a/rbms/potts_bernoulli/__init__.py b/rbms/potts_bernoulli/__init__.py index e69de29..8aa8542 100644 --- a/rbms/potts_bernoulli/__init__.py +++ b/rbms/potts_bernoulli/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa +from rbms.potts_bernoulli.classes import PBRBM +from rbms.potts_bernoulli.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) diff --git a/rbms/potts_bernoulli/classes.py b/rbms/potts_bernoulli/classes.py index 2691786..0e506f3 100644 --- a/rbms/potts_bernoulli/classes.py +++ b/rbms/potts_bernoulli/classes.py @@ -1,8 +1,13 @@ +from __future__ import annotations + +from typing import override + import numpy as np import torch from torch import Tensor from rbms.classes import RBM +from rbms.custom_fn import check_keys_dict from rbms.potts_bernoulli.implement import ( _compute_energy, _compute_energy_hiddens, @@ -12,18 +17,21 @@ _init_parameters, _sample_hiddens, _sample_visibles, + _zero_sum_gauge, ) class PBRBM(RBM): """Parameters of the Potts-Bernoulli RBM""" + visible_type: str = "categorical" + def __init__( self, weight_matrix: Tensor, vbias: Tensor, hbias: Tensor, - device: torch.device | None = None, + device: torch.device | str | None = None, dtype: torch.dtype | None = None, ): """Initialize the parameters of the Potts-Bernoulli RBM. @@ -32,7 +40,7 @@ def __init__( weight_matrix (Tensor): The weight matrix of the RBM. vbias (Tensor): The visible bias of the RBM. hbias (Tensor): The hidden bias of the RBM. - device (Optional[torch.device], optional): The device for the parameters. + device (torch.device | str | None, optional): The device for the parameters. 
Defaults to the device of `weight_matrix`. dtype (Optional[torch.dtype], optional): The data type for the parameters. Defaults to the data type of `weight_matrix`. @@ -47,6 +55,7 @@ def __init__( self.vbias = vbias.to(device=self.device, dtype=self.dtype) self.hbias = hbias.to(device=self.device, dtype=self.dtype) self.name = "PBRBM" + self.flags = [] def __add__(self, other): return PBRBM( @@ -63,7 +72,9 @@ def __mul__(self, other): ) @torch.jit.export - def clone(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + def clone( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): if device is None: device = self.device if dtype is None: @@ -101,7 +112,7 @@ def compute_energy_visibles(self, v): weight_matrix=self.weight_matrix, ) - def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): + def compute_gradient(self, data, chains, centered=True): _compute_gradient( v_data=data["visible"], mh_data=data["hidden_mag"], @@ -113,8 +124,6 @@ def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2 hbias=self.hbias, weight_matrix=self.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) def independent_model(self): @@ -157,32 +166,36 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): var_init=var_init, ) params = PBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias) + params.set_zero_sum_gauge() return params def named_parameters(self): return { - "weight_matrix": self.weight_matrix, - "vbias": self.vbias, - "hbias": self.hbias, + "weight_matrix": self.weight_matrix.cpu().numpy(), + "vbias": self.vbias.cpu().numpy(), + "hbias": self.hbias.cpu().numpy(), } - def num_hiddens(self): + @property + def num_hiddens(self) -> int: return self.hbias.shape[0] + @property def num_states(self) -> int: """Number of colors for the Potts variables""" return self.weight_matrix.shape[1] - def num_visibles(self): + 
@property + def num_visibles(self) -> int: return self.vbias.shape[0] def parameters(self) -> list[Tensor]: return [self.weight_matrix, self.vbias, self.hbias] + @property def ref_log_z(self): return ( - self.num_hiddens() * np.log(2) - + self.num_visibles() * np.log(self.num_states()) + self.num_hiddens * np.log(2) + self.num_visibles * np.log(self.num_states) ).item() def sample_hiddens(self, chains, beta=1): @@ -198,17 +211,23 @@ def sample_visibles(self, chains, beta=1): return chains @staticmethod - def set_named_parameters(named_params): + def set_named_parameters( + named_params: dict[str, np.ndarray], + device: torch.device | str, + dtype: torch.dtype, + ) -> PBRBM: names = ["vbias", "hbias", "weight_matrix"] - for k in names: - if k not in named_params.keys(): - raise ValueError( - f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" - ) + check_keys_dict(d=named_params, names=names) params = PBRBM( - weight_matrix=named_params.pop("weight_matrix"), - vbias=named_params.pop("vbias"), - hbias=named_params.pop("hbias"), + weight_matrix=torch.from_numpy(named_params.pop("weight_matrix")).to( + device=device, dtype=dtype + ), + vbias=torch.from_numpy(named_params.pop("vbias")).to( + device=device, dtype=dtype + ), + hbias=torch.from_numpy(named_params.pop("hbias")).to( + device=device, dtype=dtype + ), ) if len(named_params.keys()) > 0: raise ValueError( @@ -216,7 +235,9 @@ def set_named_parameters(named_params): ) return params - def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): if device is not None: self.device = device if dtype is not None: @@ -225,3 +246,29 @@ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = Non self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) 
return self + + @override + @torch.compile + def normalize_grad(self) -> None: + norm_factor = torch.sqrt( + self.weight_matrix.square().sum() + + self.vbias.square().sum() + + self.hbias.square().sum() + ) + self.weight_matrix.grad /= norm_factor + self.vbias.grad /= norm_factor + self.hbias.grad /= norm_factor + + def set_zero_sum_gauge(self): + self.vbias, self.hbias, self.weight_matrix = _zero_sum_gauge( + vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + ) + + def get_metrics(self, metrics): + return metrics + + def post_grad_update(self): + self.set_zero_sum_gauge() + + def pre_grad_update(self): + pass diff --git a/rbms/potts_bernoulli/functional.py b/rbms/potts_bernoulli/functional.py index 341bea3..4f31f20 100644 --- a/rbms/potts_bernoulli/functional.py +++ b/rbms/potts_bernoulli/functional.py @@ -13,6 +13,7 @@ _init_parameters, _sample_hiddens, _sample_visibles, + _zero_sum_gauge, ) @@ -116,8 +117,6 @@ def compute_gradient( chains: dict[str, Tensor], params: PBRBM, centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ) -> None: """Compute the gradient for each of the parameters and attach it. @@ -140,8 +139,6 @@ def compute_gradient( hbias=params.hbias, weight_matrix=params.weight_matrix, centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, ) @@ -217,3 +214,14 @@ def init_parameters( var_init=var_init, ) return PBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias) + + +def ensure_zero_sum_gauge(params: PBRBM) -> None: + """Ensure the weight matrix has a zero-sum gauge. + + Args: + params (PBRBM): The parameters of the RBM. 
+ """ + params.vbias, params.hbias, params.weight_matrix = _zero_sum_gauge( + vbias=params.vbias, hbias=params.hbias, weight_matrix=params.weight_matrix + ) diff --git a/rbms/potts_bernoulli/implement.py b/rbms/potts_bernoulli/implement.py index 6b641aa..b4be369 100644 --- a/rbms/potts_bernoulli/implement.py +++ b/rbms/potts_bernoulli/implement.py @@ -5,7 +5,6 @@ from rbms.custom_fn import one_hot -@torch.jit.script def _sample_hiddens( v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 ) -> tuple[Tensor, Tensor]: @@ -20,7 +19,6 @@ def _sample_hiddens( return h, mh -@torch.jit.script def _sample_visibles( h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 ) -> tuple[Tensor, Tensor]: @@ -37,7 +35,6 @@ def _sample_visibles( return v, mv -@torch.jit.script def _compute_energy( v: Tensor, h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor ): @@ -53,7 +50,6 @@ def _compute_energy( return -fields - interaction -@torch.jit.script def _compute_energy_visibles( v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor ): @@ -71,7 +67,6 @@ def _compute_energy_visibles( return -field - log_term.sum(1) -@torch.jit.script def _compute_energy_hiddens( h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor ): @@ -81,7 +76,6 @@ def _compute_energy_hiddens( return -field - lse -@torch.jit.script def _compute_gradient( v_data: Tensor, mh_data: Tensor, @@ -93,8 +87,6 @@ def _compute_gradient( hbias: Tensor, weight_matrix: Tensor, centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, ): w_data = w_data.view(-1, 1, 1) w_chain = w_chain.view(-1, 1, 1) @@ -150,22 +142,17 @@ def _compute_gradient( - torch.tensordot(v_data_mean, grad_weight_matrix, dims=[[0, 1], [0, 1]]) ) else: - v_data_centered = v_data_one_hot - h_data_centered = mh_data - v_gen_centered = v_gen_one_hot - h_gen_centered = h_chain - # Gradient grad_weight_matrix = ( torch.tensordot( - v_data_centered, - h_data_centered, + v_data_one_hot, + 
mh_data, dims=[[0], [0]], ) / v_data.shape[0] - torch.tensordot( - v_gen_centered, - h_gen_centered, + v_gen_one_hot, + h_chain, dims=[[0], [0]], ) / v_chain.shape[0] @@ -174,18 +161,9 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean grad_hbias = h_data_mean - h_gen_mean - if lambda_l1 > 0: - grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) - grad_vbias -= lambda_l1 * torch.sign(vbias) - grad_hbias -= lambda_l1 * torch.sign(hbias) - - if lambda_l2 > 0: - grad_weight_matrix -= 2 * lambda_l2 * weight_matrix - grad_vbias -= 2 * lambda_l2 * vbias - grad_hbias -= 2 * lambda_l2 * hbias - weight_matrix.grad.set_(grad_weight_matrix) - vbias.grad.set_(grad_vbias) - hbias.grad.set_(grad_hbias) + weight_matrix.grad = grad_weight_matrix + vbias.grad = grad_vbias + hbias.grad = grad_hbias def _init_chains( @@ -240,5 +218,12 @@ def _init_parameters( ) * var_init ) - # print(torch.svd(weight_matrix.reshape(-1, weight_matrix.shape[-1])).S) + return vbias, hbias, weight_matrix + + +def _zero_sum_gauge(vbias: Tensor, hbias: Tensor, weight_matrix: Tensor): + mean_W = weight_matrix.mean(1, keepdim=True) + weight_matrix -= mean_W + hbias += mean_W.squeeze().sum(0) + vbias -= vbias.mean(1, keepdim=True) return vbias, hbias, weight_matrix diff --git a/rbms/potts_bernoulli/tools.py b/rbms/potts_bernoulli/tools.py deleted file mode 100644 index 482da43..0000000 --- a/rbms/potts_bernoulli/tools.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np -import torch -from torch import Tensor - -from rbms.custom_fn import one_hot - - -def get_covariance_matrix( - data: Tensor, - weights: Tensor | None = None, - num_extract: int | None = None, - center: bool = True, - device: torch.device = torch.device("cpu"), - dtype: torch.dtype = torch.float32, -) -> Tensor: - """Returns the covariance matrix of the data. If weights is specified, the weighted covariance matrix is computed. - - Args: - data (Tensor): Data. - weights (Tensor, optional): Weights of the data. 
Defaults to None. - num_extract (int, optional): Number of data to extract to compute the covariance matrix. Defaults to None. - center (bool): Center the data. Defaults to True. - device (torch.device): Device. Defaults to 'cpu'. - dtype (torch.dtype): DType. Defaults to torch.float32. - - Returns: - Tensor: Covariance matrix of the dataset. - """ - num_data = len(data) - num_classes = int(data.max().item() + 1) - - if weights is None: - weights = torch.ones(num_data) - weights = weights.to(device=device, dtype=torch.float32) - - if num_extract is not None: - idxs = np.random.choice(a=np.arange(num_data), size=(num_extract,), replace=False) - data = data[idxs] - weights = weights[idxs] - num_data = num_extract - - if num_classes != 2: - data = data.to(device=device, dtype=torch.int32) - data_oh = one_hot(data, num_classes=num_classes).reshape(num_data, -1) - else: - data_oh = data.to(device=device, dtype=torch.float32) - - norm_weights = weights.reshape(-1, 1) / weights.sum() - data_mean = (data_oh * norm_weights).sum(0, keepdim=True) - cov_matrix = ((data_oh * norm_weights).mT @ data_oh) - int(center) * ( - data_mean.mT @ data_mean - ) - return cov_matrix diff --git a/rbms/potts_bernoulli/utils.py b/rbms/potts_bernoulli/utils.py deleted file mode 100644 index 73d286c..0000000 --- a/rbms/potts_bernoulli/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -from rbms.potts_bernoulli.classes import PBRBM - - -def ensure_zero_sum_gauge(params: PBRBM) -> None: - """Ensure the weight matrix has a zero-sum gauge. - - Args: - params (PBRBM): The parameters of the RBM. 
- """ - mean_W = params.weight_matrix.mean(1, keepdim=True) - params.weight_matrix -= mean_W - params.hbias += mean_W.squeeze().sum(0) - params.vbias -= params.vbias.mean(1, keepdim=True) diff --git a/rbms/pre_grad.py b/rbms/pre_grad.py new file mode 100644 index 0000000..ac8c89c --- /dev/null +++ b/rbms/pre_grad.py @@ -0,0 +1,74 @@ +import torch +from torch.optim import Optimizer + + +class L1Regularization(torch.nn.Module): + def __init__(self, optimizer: list[Optimizer], lambda_l1: float, *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer = optimizer + self.lambda_l1 = lambda_l1 + + def forward(self, input): + for opt in self.optimizer: + for p in opt.param_groups[0]["params"]: + p.grad -= self.lambda_l1 * torch.sign(p) + + +class L2Regularization(torch.nn.Module): + def __init__(self, optimizer: list[Optimizer], lambda_l2: float, *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer = optimizer + self.lambda_l2 = lambda_l2 + + def forward(self, input): + for opt in self.optimizer: + for p in opt.param_groups[0]["params"]: + p.grad -= self.lambda_l2 * p + + +class ClipGradNorm(torch.nn.Module): + def __init__(self, optimizer: list[Optimizer], max_grad_norm, *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer = optimizer + self.max_grad_norm = max_grad_norm + + def forward(self, input): + for opt in self.optimizer: + torch.nn.utils.clip_grad_norm_( + opt.param_groups[0]["params"], max_norm=self.max_grad_norm + ) + + +class NormalizeGrad(torch.nn.Module): + def __init__(self, optimizer: list[Optimizer], *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer = optimizer + + def forward(self, input): + for opt in self.optimizer: + norm_grad = torch.nn.utils.get_total_norm( + [p.grad for p in opt.param_groups[0]["params"] if p.grad is not None] + ) + for p in opt.param_groups[0]["params"]: + p.grad /= norm_grad + + +def build_pre_grad_update( + optimizer: list[Optimizer], + lambda_l1: float, + 
lambda_l2: float, + normalize_grad: bool, + max_grad_norm: float, + **kwargs, +): + return torch.compile( + torch.nn.Sequential( + *[L1Regularization(optimizer=optimizer, lambda_l1=lambda_l1)] + * (lambda_l1 > 0), + *[L2Regularization(optimizer=optimizer, lambda_l2=lambda_l2)] + * (lambda_l2 > 0), + *[NormalizeGrad(optimizer=optimizer)] * normalize_grad, + *[ClipGradNorm(optimizer=optimizer, max_grad_norm=max_grad_norm)] + * (max_grad_norm > 0), + ) + ) diff --git a/rbms/sampler/__init__.py b/rbms/sampler/__init__.py new file mode 100644 index 0000000..76f8917 --- /dev/null +++ b/rbms/sampler/__init__.py @@ -0,0 +1,5 @@ +from rbms.sampler.cd import CD +from rbms.sampler.pcd import PCD +from rbms.sampler.rdm import RDM + +__all__ = [CD, PCD, RDM] diff --git a/rbms/sampler/cd.py b/rbms/sampler/cd.py new file mode 100644 index 0000000..5945b01 --- /dev/null +++ b/rbms/sampler/cd.py @@ -0,0 +1,83 @@ +import numpy as np +import torch +from torch import Tensor + +from rbms.classes import EBM, Sampler + + +class CD(Sampler): + def __init__(self, params: EBM, num_steps: int, beta: float = 1, **kwargs): + self.name = "CD" + self.params = params + self.beta = beta + self.num_steps = num_steps + self.chains = self.params.init_chains(2) + self.flags = [] + + def get_conf_grad(self, batch: Tensor) -> dict[str, Tensor]: + self.sample(num_steps=None, batch=batch) + return self.chains + + def sample(self, num_steps: int | None, **kwargs) -> None: + batch = kwargs["batch"] + self.chains = self.params.init_chains(num_samples=batch.shape[0], start_v=batch) + self.chains = self.params.sample_state( + chains=self.chains, n_steps=self.num_steps, beta=self.beta + ) + + @torch.compiler.disable + def named_parameters(self): + params_dict = self.params.named_parameters() + params_dict["model_type"] = np.asarray(self.params.name, dtype="T") + params_dict["sampler_type"] = np.asarray(self.name, dtype="T") + params_dict["beta"] = np.asarray(self.beta) + params_dict["num_steps"] = 
np.asarray(self.num_steps) + match self.params.visible_type: + case "bernoulli": + chains_save = self.chains["visible"].bool().cpu().numpy() + case "ising" | "categorical": + chains_save = self.chains["visible"].to(torch.int16).cpu().numpy() + case _: + chains_save = self.chains["visible"].cpu().numpy() + params_dict["parallel_chains"] = chains_save + return params_dict + + @staticmethod + def set_named_parameters( + named_params: dict[str, np.ndarray], + map_model: dict[str, type[EBM]], + device: torch.device | str, + dtype: torch.dtype, + ): + names = ["model_type", "beta", "num_steps"] + for k in names: + if k not in named_params.keys(): + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + model_type = str(named_params.pop("model_type")) + beta = float(named_params.pop("beta")) + num_steps = int(named_params.pop("num_steps")) + chains_visible = torch.from_numpy(named_params.pop("parallel_chains")).to( + device=device, dtype=dtype + ) + # There should only remain the keys for the model loading + params = map_model[model_type].set_named_parameters( + named_params=named_params, device=device, dtype=dtype + ) + chains = params.init_chains(chains_visible.shape[0], start_v=chains_visible) + sampler = CD(params=params, num_steps=num_steps, beta=beta) + sampler.chains = chains + return sampler + + def post_grad_update(self, params: EBM): + self.params = params + + def get_metrics_display(self, metrics, **kwargs): + return metrics + + def get_metrics_save(self): + return None + + def pre_grad_update(self): + pass diff --git a/rbms/sampling/gibbs.py b/rbms/sampler/gibbs.py similarity index 100% rename from rbms/sampling/gibbs.py rename to rbms/sampler/gibbs.py diff --git a/rbms/sampler/pcd.py b/rbms/sampler/pcd.py new file mode 100644 index 0000000..7a52eb4 --- /dev/null +++ b/rbms/sampler/pcd.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import numpy as np +import torch +from 
torch import Tensor + +from rbms.classes import EBM, Sampler + + +class PCD(Sampler): + def __init__( + self, + params: EBM, + chains: dict[str, Tensor], + num_steps: int, + beta: float = 1, + **kwargs, + ): + self.name = "PCD" + self.chains = chains + self.params = params + self.beta = beta + self.num_steps = num_steps + self.flags = [] + + def get_conf_grad(self, batch: Tensor): + self.sample(num_steps=None) + return self.chains + + def sample(self, num_steps: int | None, **kwargs): + self.chains = self.params.sample_state( + chains=self.chains, n_steps=self.num_steps, beta=self.beta + ) + + @torch.compiler.disable + def named_parameters(self): + params_dict = self.params.named_parameters() + params_dict["model_type"] = np.asarray(self.params.name, dtype="T") + params_dict["sampler_type"] = np.asarray(self.name, dtype="T") + match self.params.visible_type: + case "bernoulli": + chains_save = self.chains["visible"].bool().cpu().numpy() + case "ising" | "categorical": + chains_save = self.chains["visible"].to(torch.int16).cpu().numpy() + case _: + chains_save = self.chains["visible"].cpu().numpy() + params_dict["parallel_chains"] = chains_save + params_dict["beta"] = np.asarray(self.beta) + params_dict["num_steps"] = np.asarray(self.num_steps) + return params_dict + + @staticmethod + def set_named_parameters( + named_params: dict[str, np.ndarray], + map_model: dict[str, type[EBM]], + device: torch.device | str, + dtype: torch.dtype, + ) -> PCD: + names = ["model_type", "chains", "beta", "num_steps"] + for k in names: + if k not in named_params.keys(): + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + model_type = str(named_params.pop("model_type")) + chains_visible = torch.from_numpy(named_params.pop("parallel_chains")).to( + device=device, dtype=dtype + ) + beta = float(named_params.pop("beta")) + num_steps = int(named_params.pop("num_steps")) + # There should only remain the 
keys for the model loading + params = map_model[model_type].set_named_parameters( + named_params=named_params, device=device, dtype=dtype + ) + chains = params.init_chains(chains_visible.shape[0], start_v=chains_visible) + + return PCD(params=params, chains=chains, num_steps=num_steps, beta=beta) + + def post_grad_update(self, params: EBM): + self.params = params + + def get_metrics_display(self, metrics, **kwargs): + return metrics + + def get_metrics_save(self): + return None + + def pre_grad_update(self): + pass diff --git a/rbms/sampling/pt.py b/rbms/sampler/pt.py similarity index 100% rename from rbms/sampling/pt.py rename to rbms/sampler/pt.py diff --git a/rbms/sampler/rdm.py b/rbms/sampler/rdm.py new file mode 100644 index 0000000..f8dd845 --- /dev/null +++ b/rbms/sampler/rdm.py @@ -0,0 +1,90 @@ +import numpy as np +import torch +from torch import Tensor + +from rbms.classes import EBM, Sampler + + +class RDM(Sampler): + def __init__( + self, params: EBM, num_chains: int, num_steps: int, beta: float = 1, **kwargs + ): + self.name = "RDM" + self.params = params + self.beta = beta + self.num_chains = num_chains + self.num_steps = num_steps + self.chains = self.params.init_chains(num_chains) + self.flags = [] + + def sample(self, num_steps: int | None, **kwargs): + chains = self.params.init_chains(num_samples=self.num_chains) + chains = self.params.sample_state( + chains=chains, n_steps=self.num_steps, beta=self.beta + ) + return chains + + def get_conf_grad(self, batch: Tensor): + self.sample(num_steps=None) + return self.chains + + @torch.compiler.disable + def named_parameters(self): + params_dict = self.params.named_parameters() + params_dict["model_type"] = np.asarray(self.params.name, dtype="T") + params_dict["sampler_type"] = np.asarray(self.name, dtype="T") + params_dict["num_chains"] = np.asarray(self.num_chains) + params_dict["beta"] = np.asarray(self.beta) + params_dict["num_steps"] = np.asarray(self.num_steps) + match self.params.visible_type: + 
case "bernoulli": + chains_save = self.chains["visible"].bool().cpu().numpy() + case "ising" | "categorical": + chains_save = self.chains["visible"].to(torch.int16).cpu().numpy() + case _: + chains_save = self.chains["visible"].cpu().numpy() + params_dict["parallel_chains"] = chains_save + return params_dict + + @staticmethod + def set_named_parameters( + named_params: dict[str, np.ndarray], + map_model: dict[str, type[EBM]], + device: torch.device | str, + dtype: torch.dtype, + ): + names = ["model_type", "num_chains", "beta", "num_steps"] + for k in names: + if k not in named_params.keys(): + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + model_type = str(named_params.pop("model_type")) + num_chains = int(named_params.pop("num_chains")) + beta = float(named_params.pop("beta")) + num_steps = int(named_params.pop("num_steps")) + chains_visible = torch.from_numpy(named_params.pop("parallel_chains")).to( + device=device, dtype=dtype + ) + # There should only remain the keys for the model loading + params = map_model[model_type].set_named_parameters( + named_params=named_params, device=device, dtype=dtype + ) + chains = params.init_chains(chains_visible.shape[0], start_v=chains_visible) + sampler = RDM( + params=params, num_chains=num_chains, num_steps=num_steps, beta=beta + ) + sampler.chains = chains + return sampler + + def post_grad_update(self, params: EBM): + self.params = params + + def get_metrics_display(self, metrics, **kwargs): + return metrics + + def get_metrics_save(self): + return None + + def pre_grad_update(self): + pass diff --git a/rbms/scripts/entrypoint.py b/rbms/scripts/entrypoint.py index e3fea49..98ec07c 100644 --- a/rbms/scripts/entrypoint.py +++ b/rbms/scripts/entrypoint.py @@ -4,13 +4,14 @@ def main(): - # Get the directory of the current script SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) # Check if the first positional argument is provided if 
len(sys.argv) < 2: - print("Error: No command provided. Use 'train' or 'pt_sampling'.") + print( + "Error: No command provided. Use 'train', 'restore', 'split' or 'pt_sampling'." + ) sys.exit(1) # Assign the first positional argument to a variable @@ -23,9 +24,13 @@ def main(): case "pt_sampling": SCRIPT = "pt_sampling.py" case "split": - SCRIPT = "split_data.py" + SCRIPT = "split_data.py" + case "restore": + SCRIPT = "restore.py" case _: - print(f"Error: Invalid command '{COMMAND}'. Use 'train', 'split' or 'pt_sampling'.") + print( + f"Error: Invalid command '{COMMAND}'. Use 'train', 'restore', 'split' or 'pt_sampling'." + ) sys.exit(1) # Run the corresponding Python script with the remaining optional arguments diff --git a/rbms/scripts/pt_sampling.py b/rbms/scripts/pt_sampling.py index 180613d..143d802 100644 --- a/rbms/scripts/pt_sampling.py +++ b/rbms/scripts/pt_sampling.py @@ -2,7 +2,7 @@ import h5py -from rbms.classes import RBM +from rbms.classes import RBM, EBM from rbms.io import load_params from rbms.map_model import map_model from rbms.parser import add_args_pytorch, match_args_dtype @@ -11,9 +11,7 @@ def create_parser(): - parser = argparse.ArgumentParser( - "Parallel Tempering sampling on the provided model" - ) + parser = argparse.ArgumentParser("Parallel Tempering sampling on the provided model") parser.add_argument("-i", "--filename", type=str, help="Model to use for sampling") parser.add_argument( "-o", "--out_file", type=str, help="Path to save the samples after generation" @@ -63,7 +61,7 @@ def run_pt( save_index: bool, device, dtype, - map_model: dict[str, RBM] = map_model, + map_model: dict[str, type[EBM]] = map_model, ): check_file_existence(out_file) diff --git a/rbms/scripts/restore.py b/rbms/scripts/restore.py new file mode 100644 index 0000000..bd976d5 --- /dev/null +++ b/rbms/scripts/restore.py @@ -0,0 +1,215 @@ +# import argparse + +# import h5py +# import torch + +# from rbms import get_saved_updates +# from rbms.dataset import 
load_dataset +# from rbms.map_model import map_model +# from rbms.optim import setup_optim +# from rbms.parser import ( +# add_args_pytorch, +# add_args_saves, +# add_args_train, +# add_grad_args, +# add_sampling_args, +# match_args_dtype, +# remove_argument, +# ) +# from rbms.training.pcd import train +# from rbms.training.utils import get_checkpoints, restore_training + + +# def create_parser_restore(): +# parser = argparse.ArgumentParser( +# description="Restore the training of a Restricted Boltzmann Machine" +# ) +# dataset_args = parser.add_argument_group("Dataset") +# dataset_args.add_argument( +# "-d", +# "--dataset", +# type=str, +# required=True, +# help="Path to a data file (type should be .h5 or .fasta)", +# ) +# dataset_args.add_argument( +# "--test_dataset", +# type=str, +# required=False, +# default=None, +# help="Path to test dataset file (type should be .h5 or .fasta)", +# ) +# parser = add_args_train(parser) +# parser = add_grad_args(parser) +# parser.add_argument( +# "--update", +# default=None, +# type=int, +# help="(Defaults to None). Which update to restore from. 
If None, the last update is used.", +# ) +# remove_argument(parser, "no_center") +# remove_argument(parser, "normalize_grad") + +# parser = add_sampling_args(parser) +# parser = add_args_saves(parser) +# parser = add_args_pytorch(parser) +# remove_argument(parser, "use_torch") +# return parser + + +# def recover_args( +# args: dict, +# ) -> tuple[ +# dict[str, str], +# dict[str, int | float], +# dict[str, bool | float], +# dict[str, int | float], +# dict[str, str | torch.dtype], +# ]: +# with h5py.File(args["filename"], "r") as f: +# # dataset +# args_dataset = { +# "dataset_name": args["dataset"], +# "test_dataset_name": args["test_dataset"], +# } +# dataset = f["dataset_args"] +# if "subset_labels" in dataset.keys(): +# args_dataset["subset_labels"] = dataset["subset_labels"][()] +# else: +# args_dataset["subset_labels"] = None +# args_dataset["train_size"] = dataset["train_size"][()].item() +# args_dataset["test_size"] = dataset["test_size"][()].item() + +# args_dataset["use_weights"] = dataset["use_weights"][()].item() +# args_dataset["alphabet"] = dataset["alphabet"][()].decode() +# args_dataset["remove_duplicates"] = dataset["remove_duplicates"][()].item() +# args_dataset["seed"] = dataset["seed"][()].item() + +# # grad +# args_grad = {} +# grad = f["grad_args"] +# ## Default args +# args_grad["no_center"] = grad["no_center"][()].item() +# args_grad["normalize_grad"] = grad["normalize_grad"][()].item() +# ## Can be overriden +# args_grad["max_norm_grad"] = args["max_norm_grad"] +# if args_grad["max_norm_grad"] is None: +# args_grad["max_norm_grad"] = grad["max_norm_grad"][()].item() +# args_grad["L1"] = args["L1"] +# if args_grad["L1"] is None: +# args_grad["L1"] = grad["L1"][()].item() +# args_grad["L2"] = args["L2"] +# if args_grad["L2"] is None: +# args_grad["L2"] = grad["L2"][()].item() + +# # sampling +# args_sampling = {} +# sampling = f["sampling_args"] +# args_sampling["gibbs_steps"] = args["gibbs_steps"] +# if args_sampling["gibbs_steps"] is None: +# 
args_sampling["gibbs_steps"] = sampling["gibbs_steps"][()].item() +# args_sampling["beta"] = args["beta"] +# if args_sampling["beta"] is None: +# args_sampling["beta"] = sampling["beta"][()].item() + +# # train +# args_train = {} +# train_args = f["train_args"] +# args_train["optim"] = args["optim"] +# args_train["num_updates"] = args["num_updates"] +# if args_train["optim"] is None: +# args_train["optim"] = train_args["optim"][()].decode() +# args_train["learning_rate"] = args["learning_rate"] +# if args_train["learning_rate"] is None: +# args_train["learning_rate"] = train_args["learning_rate"][()] +# args_train["batch_size"] = args["batch_size"] +# if args_train["batch_size"] is None: +# args_train["batch_size"] = train_args["batch_size"][()].item() +# args_train["update"] = args["update"] +# if args_train["update"] is None: +# args_train["update"] = get_saved_updates(args["filename"])[-1] +# args_train["mult_optim"] = args["mult_optim"] +# args_train["training_type"] = args["training_type"] +# if args_train["training_type"] is None: +# args_train["training_type"] = train_args["training_type"][()].decode() +# if args_train["max_lr"] is None: +# args_train["max_lr"] = train_args["max_lr"][()].item() +# args_train["scale_lr"] = args["scale_lr"] + +# # Torch +# args_torch = {} +# args_torch["device"] = args["device"] +# args_torch["dtype"] = args["dtype"] + +# # save +# args_save = {} +# args_save["filename"] = args["filename"] +# save = f["save_args"] +# args_save["n_save"] = args["n_save"] +# if args_save["n_save"] is None: +# args_save["n_save"] = save["n_save"][()].item() +# args_save["spacing"] = args["spacing"] +# if args_save["spacing"] is None: +# args_save["spacing"] = save["spacing"][()] +# return (args_dataset, args_save, args_train, args_grad, args_sampling, args_torch) + + +# def main(): +# torch.set_float32_matmul_precision("high") +# torch.backends.cudnn.benchmark = True +# parser = create_parser_restore() +# args = parser.parse_args() +# args = 
vars(args) +# args = match_args_dtype(args) +# args_dataset, args_save, args_train, args_grad, args_sampling, args_torch = ( +# recover_args(args) +# ) +# checkpoints = get_checkpoints( +# num_updates=args_train["num_updates"], +# n_save=args_save["n_save"], +# spacing=args_save["spacing"], +# ) +# train_dataset, test_dataset = load_dataset( +# dataset_name=args_dataset["dataset_name"], +# test_dataset_name=args_dataset["test_dataset_name"], +# subset_labels=args_dataset["subset_labels"], +# use_weights=args_dataset["use_weights"], +# alphabet=args_dataset["alphabet"], +# remove_duplicates=args_dataset["remove_duplicates"], +# **args_torch, +# ) +# ( +# params, +# parallel_chains, +# target_update, +# elapsed_time, +# train_dataset, +# test_dataset, +# ) = restore_training( +# train_dataset=train_dataset, +# test_dataset=test_dataset, +# args_save=args_save, +# args_train=args_train, +# args_dataset=args_dataset, +# args_torch=args_torch, +# map_model=map_model, +# ) +# optimizer = setup_optim(args_train["optim"], args_train, params) +# train( +# train_dataset=train_dataset, +# test_dataset=test_dataset, +# params=params, +# parallel_chains=parallel_chains, +# optimizer=optimizer, +# curr_update=target_update, +# elapsed_time=elapsed_time, +# checkpoints=checkpoints, +# args_save=args_save, +# args_train=args_train, +# args_grad=args_grad, +# args_sampling=args_sampling, +# ) + + +# if __name__ == "__main__": +# main() diff --git a/rbms/scripts/split_data.py b/rbms/scripts/split_data.py index 589f60e..b49dbda 100644 --- a/rbms/scripts/split_data.py +++ b/rbms/scripts/split_data.py @@ -42,6 +42,12 @@ def create_parser(): default="protein", help="(Defaults to protein). Type of encoding for the sequences. 
Choose among ['protein', 'rna', 'dna'] or a user-defined string of tokens.", ) + parser.add_argument( + "--remove_duplicates", + action="store_true", + default=False, + help="Remove duplicates from the dataset before splitting.", + ) return parser @@ -50,7 +56,8 @@ def split_data_train_test( output_train_file: str | None = None, output_test_file: str | None = None, train_size=0.6, - seed: int = None, + remove_duplicates: bool = False, + seed: int | None = None, alphabet: str = "protein", ): dset_name = Path(input_file) @@ -58,12 +65,17 @@ def split_data_train_test( dataset, _ = load_dataset(input_file, None, alphabet=alphabet) - print("Removing duplicates...") prev_size = dataset.data.shape[0] - unique_ind = get_unique_indices(dataset.data) - data = dataset.data[unique_ind] - names = dataset.names[unique_ind] - labels = dataset.labels[unique_ind] + if remove_duplicates: + print("Removing duplicates...") + unique_ind = get_unique_indices(dataset.data) + data = dataset.data[unique_ind] + names = dataset.names[unique_ind] + labels = dataset.labels[unique_ind] + else: + data = dataset.data + names = dataset.names + labels = dataset.labels curr_size = data.shape[0] print(f" Dataset size: {prev_size} -> {curr_size} samples") @@ -99,12 +111,12 @@ def split_data_train_test( if output_train_file is None: output_train_file = ( ".".join(str(dset_name).split(".")[:-1]) - + f"_train={train_size}.{file_format}" + + f"_train={train_size:.1f}.{file_format}" ) if output_test_file is None: output_test_file = ( ".".join(str(dset_name).split(".")[:-1]) - + f"_test={1 - train_size}.{file_format}" + + f"_test={1 - train_size:.1f}.{file_format}" ) match file_format: @@ -141,6 +153,7 @@ def main(): output_train_file=args["out_train"], output_test_file=args["out_test"], train_size=args["train_size"], + remove_duplicates=args["remove_duplicates"], seed=args["seed"], alphabet=args["alphabet"], ) diff --git a/rbms/scripts/train_rbm.py b/rbms/scripts/train_rbm.py index 976264c..1c65c36 100644 
--- a/rbms/scripts/train_rbm.py +++ b/rbms/scripts/train_rbm.py @@ -3,38 +3,47 @@ import h5py import torch +from rbms import get_saved_updates from rbms.dataset import load_dataset from rbms.dataset.parser import add_args_dataset from rbms.map_model import map_model +from rbms.optim import setup_optim from rbms.parser import ( + add_args_init_rbm, add_args_pytorch, - add_args_rbm, - add_args_regularization, add_args_saves, + add_args_train, + add_grad_args, + add_sampling_args, default_args, match_args_dtype, remove_argument, + set_args_default, ) +from rbms.sampler import CD, PCD, RDM +from rbms.training.implement import _init_training, _restore_training from rbms.training.pcd import train from rbms.training.utils import get_checkpoints -def create_parser(): +def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Train a Restricted Boltzmann Machine") parser = add_args_dataset(parser) - parser = add_args_rbm(parser) - parser = add_args_regularization(parser) + parser = add_args_init_rbm(parser) + parser = add_args_train(parser) + parser = add_sampling_args(parser) + parser = add_grad_args(parser) parser = add_args_saves(parser) parser = add_args_pytorch(parser) remove_argument(parser, "use_torch") return parser -def train_rbm(args: dict): - if args["num_updates"] is None: - args["num_updates"] = default_args["num_updates"] +def main(args, map_model=map_model): checkpoints = get_checkpoints( - num_updates=args["num_updates"], n_save=args["n_save"], spacing=args["spacing"] + num_updates=args["num_updates"], + n_save=args["n_save"], + spacing=args["spacing"], ) train_dataset, test_dataset = load_dataset( dataset_name=args["dataset"], @@ -42,44 +51,157 @@ def train_rbm(args: dict): subset_labels=args["subset_labels"], use_weights=args["use_weights"], alphabet=args["alphabet"], + remove_duplicates=args["remove_duplicates"], device=args["device"], dtype=args["dtype"], ) - print(train_dataset) - if args["restore"]: - with 
h5py.File(args["filename"], "r") as f: - model_type = f["model_type"][()].decode() - else: - model_type = args["model_type"] - if model_type is None: - match train_dataset.visible_type: - case "binary": - model_type = "BBRBM" - case "categorical": - model_type = "PBRBM" - case _: - raise NotImplementedError() - print(model_type) - train( + flags = ["checkpoint"] + if not args["restore"]: + args = set_args_default(args, default_args=default_args) + _init_training( + train_dataset=train_dataset, + seed=args["seed"], + train_size=args["train_size"], + test_size=1 - args["train_size"], + num_hiddens=args["num_hiddens"], + num_chains=args["num_chains"], + model_type=args["model_type"], + filename=args["filename"], + n_save=args["n_save"], + spacing=args["spacing"], + batch_size=args["batch_size"], + optim=args["optim"], + mult_optim=args["mult_optim"], + training_type=args["training_type"], + learning_rate=args["learning_rate"], + max_lr=args["max_lr"], + gibbs_steps=args["gibbs_steps"], + beta=args["beta"], + centered=not (args["no_center"]), + L1=args["L1"], + L2=args["L2"], + normalize_grad=args["normalize_grad"], + max_norm_grad=args["max_norm_grad"], + subset_labels=args["subset_labels"], + use_weights=args["use_weights"], + alphabet=args["alphabet"], + remove_duplicates=args["remove_duplicates"], + dtype=args["dtype"], + device=args["device"], + flags=flags, + map_model=map_model, + ) + args["update"] = 1 + + args = load_args_from_filename(args) + print(args) + args = set_args_default(args, default_args) + if args["update"] is None: + args["update"] = get_saved_updates(args["filename"])[-1] + ( + params, + parallel_chains, + target_update, + elapsed_time, + train_dataset, + test_dataset, + ) = _restore_training( + filename=args["filename"], train_dataset=train_dataset, test_dataset=test_dataset, - model_type=model_type, - args=args, + num_updates=args["num_updates"], + target_update=args["update"], + seed=args["seed"], + train_size=args["train_size"], + 
test_size=args["test_size"], + device=args["device"], dtype=args["dtype"], - checkpoints=checkpoints, map_model=map_model, - default_args=default_args, ) + optimizer = setup_optim(args["optim"], args, params) + from rbms.pre_grad import build_pre_grad_update + + pre_grad_update = build_pre_grad_update( + optimizer=optimizer, + lambda_l1=args["L1"], + lambda_l2=args["L2"], + normalize_grad=args["normalize_grad"], + max_grad_norm=args["max_norm_grad"], + ) + + match args["training_type"]: + case "pcd": + sampler = PCD( + params=params, + chains=parallel_chains, + num_steps=args["gibbs_steps"], + beta=args["beta"], + ) + case "cd": + sampler = CD(params=params, num_steps=args["gibbs_steps"], beta=args["beta"]) + case "rdm": + sampler = RDM( + params=params, + num_chains=parallel_chains["visible"].shape[0], + num_steps=args["gibbs_steps"], + beta=args["beta"], + ) + + case _: + raise ValueError(f"No training type {args['training_type']} supported.") + + train( + train_dataset=train_dataset, + test_dataset=test_dataset, + params=params, + sampler=sampler, + optimizer=optimizer, + batch_size=args["batch_size"], + centered=not (args["no_center"]), + curr_update=args["update"], + pre_grad_update=pre_grad_update, + elapsed_time=elapsed_time, + checkpoints=checkpoints, + num_updates=args["num_updates"], + filename=args["filename"], + ) + + +def load_args_from_filename(args: dict): + with h5py.File(args["filename"], "r") as f: + if args["gibbs_steps"] is None: + args["gibbs_steps"] = f["sampling_args"]["gibbs_steps"][()].item() + if args["beta"] is None: + args["beta"] = f["sampling_args"]["beta"][()].item() + if args["optim"] is None: + args["optim"] = str(f["train_args"]["optim"][()]) + if args["batch_size"] is None: + args["batch_size"] = f["train_args"]["batch_size"][()].item() + if args["training_type"] is None: + args["training_type"] = str(f["train_args"]["training_type"][()].decode()) + args["no_center"] = f["grad_args"]["no_center"][()].item() + args["seed"] = 
f["dataset_args"]["seed"][()].item() + args["train_size"] = f["dataset_args"]["train_size"][()].item() + args["test_size"] = f["dataset_args"]["test_size"][()].item() + if args["L1"] is None: + args["L1"] = f["grad_args"]["L1"][()].item() + if args["L2"] is None: + args["L2"] = f["grad_args"]["L2"][()].item() + if args["normalize_grad"] is None: + args["normalize_grad"] = f["grad_args"]["normalize_grad"][()].item() + if args["max_norm_grad"] is None: + args["max_norm_grad"] = f["grad_args"]["max_norm_grad"][()].item() + + return args -def main(): + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") torch.backends.cudnn.benchmark = True parser = create_parser() args = parser.parse_args() args = vars(args) + # args = set_args_default(args, default_args=default_args) args = match_args_dtype(args) - train_rbm(args=args) - - -if __name__ == "__main__": - main() + main(args=args) diff --git a/rbms/training/implement.py b/rbms/training/implement.py new file mode 100644 index 0000000..9e798dd --- /dev/null +++ b/rbms/training/implement.py @@ -0,0 +1,194 @@ +import h5py +import numpy as np +import torch + +from rbms.classes import EBM +from rbms.dataset.dataset_class import RBMDataset +from rbms.io import load_model, save_model +from rbms.map_model import map_model +from rbms.utils import get_saved_updates +from torch import Tensor + + +def _init_training( + train_dataset: RBMDataset, + seed: int, + train_size: float, + test_size: float, + num_hiddens: int, + num_chains: int, + model_type: str, + filename: str, + n_save: int, + spacing: str, + batch_size: int, + optim: str, + mult_optim: bool, + training_type: str, + learning_rate: float, + max_lr: float, + gibbs_steps: int, + beta: float, + centered: bool, + L1: float, + L2: float, + normalize_grad: bool, + max_norm_grad: float, + subset_labels: list, + use_weights: bool, + alphabet: str, + remove_duplicates: bool, + dtype: torch.dtype, + device: torch.device | str, + flags: list[str], + map_model: 
dict[str, type[EBM]] = map_model, +): + if model_type is None: + match train_dataset.variable_type: + case "bernoulli": + model_type = "BBRBM" + case "categorical": + model_type = "PBRBM" + case "ising": + model_type = "IIRBM" + case _: + raise NotImplementedError() + + train_dataset.match_model_variable_type(map_model[model_type].visible_type) + # Setup dataset + num_visibles = train_dataset.get_num_visibles() + + # Setup RBM + params = map_model[model_type].init_parameters( + num_hiddens=num_hiddens, + dataset=train_dataset, + device=device, + dtype=dtype, + ) + + # Permanent chains + parallel_chains = params.init_chains(num_samples=num_chains) + parallel_chains = params.sample_state(chains=parallel_chains, n_steps=gibbs_steps) + + # Save hyperparameters + if mult_optim: + lr = torch.tensor([learning_rate] * len(params.parameters())) + else: + lr = torch.tensor([learning_rate]) + + with h5py.File(filename, "w") as file_model: + hyperparameters = file_model.create_group("hyperparameters") + hyperparameters["num_visibles"] = num_visibles + hyperparameters["num_hiddens"] = num_hiddens + hyperparameters["num_chains"] = num_chains + hyperparameters["filename"] = str(filename) + + save_model( + filename=filename, + params=params, + chains=parallel_chains, + num_updates=1, + time=0.0, + flags=flags, + learning_rate=lr, + ) + + with h5py.File(filename, "a") as f: + dataset = f.create_group("dataset_args") + if subset_labels is not None: + dataset["subset_labels"] = subset_labels + dataset["use_weights"] = use_weights + dataset["train_size"] = train_size + dataset["test_size"] = test_size + dataset["alphabet"] = np.asarray(alphabet, dtype="T") + dataset["remove_duplicates"] = remove_duplicates + dataset["seed"] = seed + + grad = f.create_group("grad_args") + grad["no_center"] = not (centered) + grad["normalize_grad"] = normalize_grad + grad["max_norm_grad"] = max_norm_grad + grad["L1"] = L1 + grad["L2"] = L2 + + sampling = f.create_group("sampling_args") + 
sampling["gibbs_steps"] = gibbs_steps + sampling["beta"] = beta + + train_args = f.create_group("train_args") + train_args["optim"] = np.asarray(optim, dtype="T") + train_args["batch_size"] = batch_size + train_args["learning_rate"] = lr + train_args["training_type"] = np.asarray(training_type, dtype="T") + train_args["max_lr"] = max_lr + + save_args = f.create_group("save_args") + save_args["n_save"] = n_save + save_args["spacing"] = np.asarray(spacing, dtype="T") + + +def _restore_training( + filename: str, + train_dataset: RBMDataset, + test_dataset: RBMDataset | None, + num_updates: int, + target_update: int, + seed: int, + train_size: float, + test_size: float, + device: str, + dtype: torch.dtype, + map_model: dict[str, type[EBM]] = map_model, +) -> tuple[EBM, dict[str, Tensor], int, float, RBMDataset, RBMDataset]: + # Retrieve the the number of training updates already performed on the model + print(f"Restoring training from update {target_update}") + + if num_updates <= target_update: + raise RuntimeError( + f"The parameter /'num_updates/' ({num_updates}) must be greater than the previous number of updates ({target_update})." 
+ ) + + params, parallel_chains, elapsed_time = load_model( + filename, + target_update, + device=device, + dtype=dtype, + restore=True, + map_model=map_model, + ) + + # Delete all updates after the current one + saved_updates = get_saved_updates(filename) + if saved_updates[-1] > target_update: + to_delete = saved_updates[saved_updates > target_update] + with h5py.File(filename, "a") as f: + print("Deleting:") + for upd in to_delete: + print(f" - {upd}") + del f[f"update_{upd}"] + + if test_dataset is None: + print("Splitting dataset") + train_dataset, test_dataset = train_dataset.split_train_test( + rng=np.random.default_rng(seed), + train_size=train_size, + test_size=test_size, + ) + print("Train dataset:") + print(train_dataset) + print("Test dataset:") + print(test_dataset) + + # Initialize gradients for the parameters + params.init_grad() + + train_dataset.match_model_variable_type(params.visible_type) + test_dataset.match_model_variable_type(params.visible_type) + return ( + params, + parallel_chains, + target_update, + elapsed_time, + train_dataset, + test_dataset, + ) diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py index db938e5..2cae466 100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -1,179 +1,113 @@ import time -import h5py import numpy as np import torch -from torch import Tensor -from torch.optim import SGD, Optimizer +from torch.optim import Optimizer +from tqdm.autonotebook import tqdm -from rbms.classes import EBM +from rbms.classes import EBM, Sampler from rbms.dataset.dataset_class import RBMDataset -from rbms.io import save_model -from rbms.map_model import map_model -from rbms.parser import default_args, set_args_default -from rbms.potts_bernoulli.classes import PBRBM -from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge -from rbms.training.utils import initialize_model_archive, setup_training -from rbms.utils import check_file_existence, log_to_csv - - -def fit_batch_pcd( - batch: tuple[Tensor, Tensor], - 
parallel_chains: dict[str, Tensor], - params: EBM, - gibbs_steps: int, - beta: float, - centered: bool = True, - lambda_l1: float = 0.0, - lambda_l2: float = 0.0, -) -> tuple[dict[str, Tensor], dict]: - """Sample the EBM and compute the gradient. - - Args: - batch (Tuple[Tensor, Tensor]): Dataset samples and associated weights. - parallel_chains (dict[str, Tensor]): Parallel chains used for gradient computation. - params (EBM): Parameters of the EBM. - gibbs_steps (int): Number of Gibbs steps to perform. - beta (float): Inverse temperature. - - Returns: - Tuple[dict[str, Tensor], dict]: A tuple containing the updated chains and the logs. - """ - v_data, w_data = batch - # Initialize batch - curr_batch = params.init_chains( - num_samples=v_data.shape[0], - weights=w_data, - start_v=v_data, - ) - # sample permanent chains - parallel_chains = params.sample_state( - chains=parallel_chains, n_steps=gibbs_steps, beta=beta - ) - params.compute_gradient( - data=curr_batch, - chains=parallel_chains, - centered=centered, - lambda_l1=lambda_l1, - lambda_l2=lambda_l2, - ) - params.normalize_grad() - logs = {} - return parallel_chains, logs +from rbms.io import save_model, save_sampler +from rbms.training.utils import EarlyStopper +@torch.compile(dynamic=True, disable=True) +@torch.no_grad def train( train_dataset: RBMDataset, - test_dataset: RBMDataset | None, - model_type: str, - args: dict, - dtype: torch.dtype, + test_dataset: RBMDataset, + params: EBM, + sampler: Sampler, + optimizer: list[Optimizer], + # early_stopper: EarlyStopper | None, + batch_size: int, + centered: bool, + curr_update: int, + pre_grad_update: torch.nn.Sequential, + elapsed_time: float, checkpoints: np.ndarray, - optim: Optimizer = SGD, - map_model: dict[str, EBM] = map_model, - default_args: dict = default_args, -) -> None: - """Train an EBM. - - Args: - dataset (RBMDataset): The training dataset. - test_dataset (RBMDataset): The test dataset (not used). 
- model_type (str): Type of RBM used (BBRBM or PBRBM) - args (dict): A dictionary of training arguments. - dtype (torch.dtype): The data type for the parameters. - checkpoints (np.ndarray): An array of checkpoints for saving model states. - """ - - if not (args["overwrite"]): - check_file_existence(args["filename"]) - - # Create a first archive with the initialized model - if not (args["restore"]): - initialize_model_archive( - args=args, - model_type=model_type, - train_dataset=train_dataset, - test_dataset=test_dataset, - dtype=dtype, - ) - ( - params, - parallel_chains, - args, - num_updates, - start, - elapsed_time, - log_filename, - pbar, - train_dataset, - test_dataset, - ) = setup_training( - args, - map_model=map_model, - train_dataset=train_dataset, - test_dataset=test_dataset, + num_updates: int, + filename: str, +): + pbar = tqdm( + initial=curr_update, + total=num_updates, + colour="red", + dynamic_ncols=True, + ascii="-#", ) - args = set_args_default(args=args, default_args=default_args) - - optimizer = optim(params.parameters(), lr=args["learning_rate"], maximize=True) - - # Continue the training - with torch.no_grad(): - for idx in range(num_updates + 1, args["num_updates"] + 1): - rand_idx = torch.randperm(len(train_dataset))[: args["batch_size"]] - batch = (train_dataset.data[rand_idx], train_dataset.weights[rand_idx]) - if args["training_type"] == "rdm": - parallel_chains = params.init_chains(parallel_chains["visible"].shape[0]) - elif args["training_type"] == "cd": - parallel_chains = params.init_chains( - batch[0].shape[0], weights=batch[1], start_v=batch[0] - ) - optimizer.zero_grad(set_to_none=False) - - parallel_chains, logs = fit_batch_pcd( - batch=batch, - parallel_chains=parallel_chains, + pbar.set_description(f"Training {params.name}") + + start = time.perf_counter() + + for idx in range(curr_update + 1, num_updates + 1): + batch = train_dataset.batch(batch_size) + data, weights = batch["data"], batch["weights"] + + for opt in optimizer: 
+ opt.zero_grad(set_to_none=False) + + # Initialize batch + curr_batch = params.init_chains( + num_samples=data.shape[0], + weights=weights, + start_v=data, + ) + parallel_chains = sampler.get_conf_grad(batch=data) + + params.compute_gradient( + data=curr_batch, + chains=parallel_chains, + centered=centered, + ) + # Do a bunch of modification on the gradient + + pre_grad_update(input=None) + params.pre_grad_update() + sampler.pre_grad_update() + + for opt in optimizer: + opt.step() + + params.post_grad_update() + sampler.post_grad_update(params=params) + + # Get flags for save + flags = [] + flags = params.save_flags(flags) + flags = sampler.save_flags(flags) + if idx in checkpoints or idx == num_updates: + flags.append("checkpoint") + + if len(flags) > 0: + names_params = ( + list(params.named_parameters().keys()) if len(optimizer) > 1 else ["all"] + ) + learning_rates = np.asarray([opt.param_groups[0]["lr"] for opt in optimizer]) + + metrics = {} + metrics = sampler.get_metrics_display( + metrics, train_dataset=train_dataset, test_dataset=test_dataset + ) + pbar.write(f"=========== Update {idx} ===========") + for k, v in metrics.items(): + pbar.write(f"{k}: {v}") + pbar.write("learning rate :") + for i in range(len(optimizer)): + pbar.write(f" - {names_params[i]} : {learning_rates[i]:.6f}") + + # pbar.write(metrics) + curr_time = time.perf_counter() - start + learning_rate = torch.tensor([opt.param_groups[0]["lr"] for opt in optimizer]) + save_model( + filename=filename, params=params, - gibbs_steps=args["gibbs_steps"], - beta=args["beta"], - centered=not (args["no_center"]), - lambda_l1=args["L1"], - lambda_l2=args["L2"], + chains=parallel_chains, + num_updates=idx, + time=curr_time + elapsed_time, + learning_rate=learning_rate, + flags=flags, ) - optimizer.step() - if isinstance(params, PBRBM): - ensure_zero_sum_gauge(params) - - # Save current model if necessary - if idx in checkpoints: - curr_time = time.time() - start - save_model( - 
filename=args["filename"], - params=params, - chains=parallel_chains, - num_updates=idx, - time=curr_time + elapsed_time, - flags=["checkpoint"], - ) - - # Save some logs - learning_rates = np.array([optimizer.param_groups[0]["lr"]]) - with h5py.File(args["filename"], "a") as f: - if "learning_rate" in f.keys(): - learning_rates = np.append(f["learning_rate"][()], learning_rates) - del f["learning_rate"] - f["learning_rate"] = learning_rates - if hasattr(optimizer, "cosine_similarity"): - if "cosine_similarities" in f.keys(): - cosine_similarities = np.append( - f["cosine_similarities"][()], - optimizer.cosine_similarity, - ) - del f["cosine_similarities"] - f["cosine_similarities"] = cosine_similarities - - if args["log"]: - log_to_csv(logs, log_file=log_filename) - pbar.set_postfix_str(f"lr: {optimizer.param_groups[0]['lr']:.6f}") - # Update progress bar - pbar.update(1) + + save_sampler(filename, sampler, idx) + pbar.update(1) diff --git a/rbms/training/utils.py b/rbms/training/utils.py index a54e5bc..fe6bec8 100644 --- a/rbms/training/utils.py +++ b/rbms/training/utils.py @@ -1,167 +1,4 @@ -import pathlib -import time -from typing import Any - -import h5py import numpy as np -import torch -from torch import Tensor -from tqdm import tqdm - -from rbms.classes import EBM -from rbms.const import LOG_FILE_HEADER -from rbms.dataset.dataset_class import RBMDataset -from rbms.io import load_model, save_model -from rbms.map_model import map_model -from rbms.parser import default_args, set_args_default -from rbms.potts_bernoulli.classes import PBRBM -from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge -from rbms.utils import get_saved_updates - - -def setup_training( - args: dict, - train_dataset: RBMDataset, - test_dataset: RBMDataset | None = None, - map_model: dict[str, EBM] = map_model, -) -> tuple[ - EBM, - dict[str, Tensor], - dict[str, Any], - int, - float, - float, - pathlib.Path, - tqdm, - RBMDataset, - RBMDataset, -]: - # Retrieve the the number of 
training updates already performed on the model - updates = get_saved_updates(filename=args["filename"]) - num_updates = updates[-1] - if args["num_updates"] <= num_updates: - raise RuntimeError( - f"The parameter /'num_updates/' ({args['num_updates']}) must be greater than the previous number of updates ({num_updates})." - ) - - params, parallel_chains, elapsed_time, hyperparameters = load_model( - args["filename"], - num_updates, - device=args["device"], - dtype=args["dtype"], - restore=True, - map_model=map_model, - ) - - # Hyperparameters - for k, v in hyperparameters.items(): - if args[k] is None: - args[k] = v - - if test_dataset is None: - train_dataset, test_dataset = train_dataset.split_train_test( - rng=np.random.default_rng(args["seed"]), - train_size=args["train_size"], - test_size=args["test_size"], - ) - - # Open the log file if it exists - log_filename = pathlib.Path(args["filename"]).parent / pathlib.Path( - f"log-{pathlib.Path(args['filename']).stem}.csv" - ) - args["log"] = log_filename.exists() - - # Progress bar - pbar = tqdm( - initial=num_updates, - total=args["num_updates"], - colour="red", - dynamic_ncols=True, - ascii="-#", - ) - pbar.set_description(f"Training {params.name}") - - # Initialize gradients for the parameters - params.init_grad() - - # Start recording training time - start = time.time() - - return ( - params, - parallel_chains, - args, - num_updates, - start, - elapsed_time, - log_filename, - pbar, - train_dataset, - test_dataset, - ) - - -def create_machine( - filename: str, - params: EBM, - num_visibles: int, - num_hiddens: int, - num_chains: int, - batch_size: int, - gibbs_steps: int, - learning_rate: float, - train_size: float, - log: bool, - flags: list[str], - seed: int, - L1: float, - L2: float, -) -> None: - """Create a RBM and save it to a new file. - - Args: - filename (str): The name of the file to save the RBM. - params (RBM): Initialized parameters. - num_visibles (int): Number of visible units. 
- num_hiddens (int): Number of hidden units. - num_chains (int): Number of parallel chains for gradient computation. - batch_size (int): Size of the data batch. - gibbs_steps (int): Number of Gibbs steps to perform. - learning_rate (float): Learning rate for training. - log (bool): Whether to enable logging. - L1 (float): Lambda parameter for L1 regularization. - L2 (float): Lambda parameter for L2 regularization. - """ - # Permanent chains - parallel_chains = params.init_chains(num_samples=num_chains) - parallel_chains = params.sample_state(chains=parallel_chains, n_steps=gibbs_steps) - with h5py.File(filename, "w") as file_model: - hyperparameters = file_model.create_group("hyperparameters") - hyperparameters["num_hiddens"] = num_hiddens - hyperparameters["num_visibles"] = num_visibles - hyperparameters["num_chains"] = num_chains - hyperparameters["batch_size"] = batch_size - hyperparameters["gibbs_steps"] = gibbs_steps - hyperparameters["filename"] = str(filename) - hyperparameters["learning_rate"] = learning_rate - hyperparameters["train_size"] = train_size - hyperparameters["seed"] = seed - hyperparameters["L1"] = L1 - hyperparameters["L2"] = L2 - - save_model( - filename=filename, - params=params, - chains=parallel_chains, - num_updates=1, - time=0.0, - flags=flags, - ) - if log: - filename = pathlib.Path(filename) - log_filename = filename.parent / pathlib.Path(f"log-{filename.stem}.csv") - with open(log_filename, "w", encoding="utf-8") as log_file: - log_file.write(",".join(LOG_FILE_HEADER) + "\n") def get_checkpoints(num_updates: int, n_save: int, spacing: str = "exp") -> np.ndarray: @@ -191,43 +28,19 @@ def get_checkpoints(num_updates: int, n_save: int, spacing: str = "exp") -> np.n return checkpoints -def initialize_model_archive( - args: dict, - model_type: str, - train_dataset: RBMDataset, - test_dataset: RBMDataset | None, - dtype: torch.dtype, - flags: list[str] = ["checkpoint"], -): - num_visibles = train_dataset.get_num_visibles() - args = 
set_args_default(args=args, default_args=default_args) - rng = np.random.default_rng(args["seed"]) - if test_dataset is None: - train_dataset, _ = train_dataset.split_train_test( - rng, args["train_size"], args["test_size"] - ) - params = map_model[model_type].init_parameters( - num_hiddens=args["num_hiddens"], - dataset=train_dataset, - device=args["device"], - dtype=dtype, - ) - - if isinstance(params, PBRBM): - ensure_zero_sum_gauge(params) - create_machine( - filename=args["filename"], - params=params, - num_visibles=num_visibles, - num_hiddens=args["num_hiddens"], - num_chains=args["num_chains"], - batch_size=args["batch_size"], - gibbs_steps=args["gibbs_steps"], - learning_rate=args["learning_rate"], - train_size=args["train_size"], - log=args["log"], - flags=flags, - seed=args["seed"], - L1=args["L1"], - L2=args["L2"], - ) +class EarlyStopper: + def __init__(self, patience=1, min_delta=0): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.min_validation_loss = float("inf") + + def early_stop(self, validation_loss): + if validation_loss < self.min_validation_loss: + self.min_validation_loss = validation_loss + self.counter = 0 + elif validation_loss > (self.min_validation_loss + self.min_delta): + self.counter += 1 + if self.counter >= self.patience: + return True + return False diff --git a/rbms/utils.py b/rbms/utils.py index 5cc647d..d9ba483 100644 --- a/rbms/utils.py +++ b/rbms/utils.py @@ -13,7 +13,7 @@ from rbms.ising_ising.classes import IIRBM -def get_eigenvalues_history(filename: str): +def get_eigenvalues_history(filename: str, backend="cpu"): """ Extracts the history of eigenvalues of the RBM's weight matrix. @@ -25,23 +25,36 @@ def get_eigenvalues_history(filename: str): - gradient_updates (np.ndarray): Array of gradient update steps. - eigenvalues (np.ndarray): Eigenvalues along training. 
""" - with h5py.File(filename, "r") as f: - gradient_updates = [] - eigenvalues = [] - for key in f.keys(): - if "update_" in key: - weight_matrix = f[key]["params"]["weight_matrix"][()] - weight_matrix = weight_matrix.reshape(-1, weight_matrix.shape[-1]) + saved_updates = get_saved_updates(filename) + eigenvalues = [] + for upd in saved_updates: + compute = False + with h5py.File(filename, "a") as f: + if "singular_values" not in f[f"update_{upd}"]: + compute = True + weight_matrix = f[f"update_{upd}"]["params"]["weight_matrix"][()] + + if compute: + weight_matrix = weight_matrix.reshape(-1, weight_matrix.shape[-1]) + if backend == "gpu": + eig = ( + torch.svd( + torch.from_numpy(weight_matrix).to(device="cuda"), + compute_uv=False, + ) + .S.cpu() + .numpy() + ) + else: eig = np.linalg.svd(weight_matrix, compute_uv=False) - eigenvalues.append(eig.reshape(*eig.shape, 1)) - gradient_updates.append(int(key.split("_")[1])) - - # Sort the results - sorting = np.argsort(gradient_updates) - gradient_updates = np.array(gradient_updates)[sorting] - eigenvalues = np.array(np.hstack(eigenvalues).T)[sorting] + with h5py.File(filename, "a") as f: + f[f"update_{upd}"]["singular_values"] = eig - return gradient_updates, eigenvalues + with h5py.File(filename, "a") as f: + eig = f[f"update_{upd}"]["singular_values"][()] + eigenvalues.append(eig.reshape(*eig.shape, 1)) + eigenvalues = np.array(np.hstack(eigenvalues).T) + return saved_updates, eigenvalues def get_saved_updates(filename: str) -> np.ndarray: @@ -297,19 +310,19 @@ def get_flagged_updates(filename: str, flag: str) -> np.ndarray: if flag in f[key]["flags"]: if f[key]["flags"][flag][()]: flagged_updates.append(update) - flagged_updates = np.sort(np.array(flagged_updates)) + flagged_updates = np.sort(np.array(flagged_updates, dtype=int)) return flagged_updates def bernoulli_to_ising(params: BBRBM) -> IIRBM: weight_matrix = 0.25 * params.weight_matrix - vbias = 0.5 * params.vbias + weight_matrix.sum(axis=1) - hbias = 0.5 * 
params.hbias + weight_matrix.sum(axis=0) + vbias = 0.5 * params.vbias + weight_matrix.sum(dim=1) + hbias = 0.5 * params.hbias + weight_matrix.sum(dim=0) return IIRBM(vbias=vbias, hbias=hbias, weight_matrix=weight_matrix) def ising_to_bernoulli(params: IIRBM) -> BBRBM: weight_matrix = 4.0 * params.weight_matrix - vbias = 2.0 * params.vbias - 2.0 * params.weight_matrix.sum(axis=1) - hbias = 2.0 * params.hbias - 2.0 * params.weight_matrix.sum(axis=0) + vbias = 2.0 * params.vbias - 2.0 * params.weight_matrix.sum(dim=1) + hbias = 2.0 * params.hbias - 2.0 * params.weight_matrix.sum(dim=0) return BBRBM(vbias=vbias, hbias=hbias, weight_matrix=weight_matrix) diff --git a/tests/conftest.py b/tests/conftest.py index c203b5a..6ad857e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,6 +156,8 @@ def sample_args(tmp_path): "L1": 0.0, "L2": 1.0, "training_type": "pcd", + "optim": "sgd", + "remove_duplicates": False, "model_type": None, } diff --git a/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py b/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py index dd580f6..e69de29 100644 --- a/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py +++ b/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py @@ -1,124 +0,0 @@ -from pathlib import Path - -import pytest -import torch - -from rbms.bernoulli_bernoulli.classes import BBRBM -from rbms.const import LOG_FILE_HEADER -from rbms.io import load_model -from rbms.training.utils import create_machine - - -# Helper function to create a temporary HDF5 file for testing -def create_temp_hdf5_file(tmp_path, sample_params_class_bbrbm, sample_chains_bbrbm): - filename = tmp_path / "test_model.h5" - create_machine( - filename, - params=sample_params_class_bbrbm, - chains=sample_chains_bbrbm, - num_updates=1, - time=0.0, - ) - return filename - - -def test_create_load_machine(tmp_path, sample_params_class_bbrbm): - filename = tmp_path / "test_model.h5" - device = 
torch.device("cpu") - dtype = torch.float32 - create_machine( - filename=str(filename), - params=sample_params_class_bbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=1.0, - L2=0.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") - assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, BBRBM) - assert isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED - - -def test_create_load_machine_dtype(tmp_path, sample_params_class_bbrbm): - filename = tmp_path / "test_model.h5" - device = torch.device("cpu") - dtype = torch.float64 - create_machine( - filename=str(filename), - params=sample_params_class_bbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=0.0, - L2=1.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if 
the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") - assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, BBRBM) - assert isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED - assert chains["weights"].shape == (pytest.NUM_CHAINS,) - assert chains["visible"].shape == (pytest.NUM_CHAINS, pytest.NUM_VISIBLES) diff --git a/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py b/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py index 77ae445..e69de29 100644 --- a/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py +++ b/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py @@ -1,118 +0,0 @@ -from pathlib import Path - -import pytest -import torch - -from rbms.const import LOG_FILE_HEADER -from rbms.io import load_model -from rbms.potts_bernoulli.classes import PBRBM -from rbms.training.utils import create_machine - - -# Helper function to create a temporary HDF5 file for testing -def create_temp_hdf5_file(tmp_path, sample_params, sample_chains): - filename = tmp_path / "test_model.h5" - create_machine( - filename, params=sample_params, chains=sample_chains, num_updates=1, time=0 - ) - return filename - - -def test_create_load_machine(tmp_path, sample_params_class_pbrbm): - filename = tmp_path / "test_model.h5" - device = torch.device("cpu") - dtype = torch.float32 - create_machine( - filename=str(filename), - 
params=sample_params_class_pbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=1.0, - L2=0.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") - assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, PBRBM) - assert isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED - - -def test_create_load_machine_dtype(tmp_path, sample_params_class_pbrbm): - filename = tmp_path / "test_model.h5" - device = torch.device("cpu") - dtype = torch.float64 - create_machine( - filename=str(filename), - params=sample_params_class_pbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=0.0, - L2=1.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") 
- assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, PBRBM) - assert isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED diff --git a/tests/unit_test/test_utils.py b/tests/unit_test/test_utils.py index d4fc94c..f63b017 100644 --- a/tests/unit_test/test_utils.py +++ b/tests/unit_test/test_utils.py @@ -229,7 +229,15 @@ def test_save_model(tmp_path, sample_params_class_bbrbm, sample_chains_bbrbm): num_updates = 1 time = 0.0 - save_model(str(filename), params, chains, num_updates, time, ["flag_1", "flag_2"]) + save_model( + str(filename), + params, + chains, + num_updates, + time, + torch.tensor([0.01, 0.01, 0.01]), + ["flag_1", "flag_2"], + ) with h5py.File(filename, "r") as f: assert "update_1" in f.keys() diff --git a/tests/use_cases/test_bbrbm.py b/tests/use_cases/test_bbrbm.py index 96bc062..88b7eb5 100644 --- a/tests/use_cases/test_bbrbm.py +++ b/tests/use_cases/test_bbrbm.py @@ -61,6 +61,8 @@ def test_use_case_train_bbrbm(): "L1": 0.0, "L2": 1.0, "training_type": "pcd", + "optim": "sgd", + "remove_duplicates": False, "model_type": "BBRBM", } train_rbm(args) diff --git a/tests/use_cases/test_pbrbm.py b/tests/use_cases/test_pbrbm.py index b15696e..650476a 100644 --- a/tests/use_cases/test_pbrbm.py +++ b/tests/use_cases/test_pbrbm.py @@ -61,6 +61,8 @@ def test_use_case_train_pbrbm_no_weights(): "L1": 0.0, "L2": 1.0, "training_type": "pcd", + "optim": 
"sgd", + "remove_duplicates": False, "model_type": "PBRBM", } train_rbm(args) @@ -138,6 +140,8 @@ def test_use_case_train_pbrbm_weights(): "L1": 1.0, "L2": 0.0, "training_type": "pcd", + "optim": "sgd", + "remove_duplicates": False, "model_type": "PBRBM", } train_rbm(args)