Merged

86 commits
6f42ce7
save code
Oct 17, 2025
36d2aa1
remove warmup
Oct 17, 2025
87c8e69
compile normalize grad + figsize factor on plot
Nov 4, 2025
d86808f
make remove duplicates optional
Nov 17, 2025
11b952b
fix prev_size dataset in split_data script
Dec 1, 2025
e0c9be4
format code
Dec 15, 2025
455c2d5
add map model as option
Dec 16, 2025
fa9baa4
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Dec 19, 2025
1ee0b5c
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Dec 19, 2025
4605d21
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Dec 19, 2025
2cbd4e4
fix merge
Dec 19, 2025
9d9a915
remove 3.14 as torch compile is not supported yet
Dec 19, 2025
716da6f
fix merge
Dec 19, 2025
b6dcefd
add missing keys to args dict in tests
Dec 19, 2025
0bcc1cc
batch function dataset
Jan 9, 2026
1e06197
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Jan 9, 2026
efe31de
use batch method
Jan 9, 2026
4695ab1
add visible_type to EBM class
Jan 13, 2026
0cae04c
change variable_type after conversion
Jan 13, 2026
10af4f2
change variable_type from binary to bernoulli
Jan 13, 2026
7607287
add visible_type
Jan 13, 2026
152803e
add categorical_to_bernoulli implementation
Jan 13, 2026
878f14e
fix variable_type
Jan 13, 2026
3520807
match dataset variable type with model visible type
Jan 13, 2026
1c75c62
sample bernoulli when variable_type is bernoulli
Jan 14, 2026
fc19aef
add log_scale option to PCA plot
Jan 14, 2026
2ca6df6
removed unused variable in non centered gradient
Jan 27, 2026
1f4d543
add conversion print + astype to dataset class
Jan 27, 2026
60c42f6
add __eq__ to class for easier comparison
Jan 27, 2026
852725d
add IIRBM and BGRBM to map_model
Jan 27, 2026
1cfd16e
add model_type and normalize_grad option to parser
Jan 27, 2026
8c69c83
add dataset weights arg
Jan 27, 2026
da1035c
fix binary to bernoulli and add ising to model match
Jan 27, 2026
c654f8f
make normalize_grad optional
Jan 27, 2026
8365a70
save result from get_eigenvalues_history in file to avoid repeating c…
Jan 27, 2026
1a3682b
change version number
Jan 27, 2026
ecbfb08
simplify imports
Jan 27, 2026
9da903c
fix: add __init__ to bernoulli_gaussian
Jan 29, 2026
bdf553e
clip grad
Feb 4, 2026
dc54a16
rework the main loop and add rbms restore script allowing to change m…
Feb 4, 2026
be952e4
new parser, keep the old functions for compatibility
Feb 4, 2026
8559d4e
save learning rate during training and remove the hyperparameters loa…
Feb 4, 2026
feb4e0a
util to handle optimizer declaration
Feb 4, 2026
9544f24
remove test for removed function
Feb 4, 2026
34ee17a
remove weights from init_parameters
Feb 4, 2026
b363fff
add learning_rate
Feb 4, 2026
7ed3c2b
margaret update
Feb 11, 2026
6703879
remove prefactor variance initialization w
Feb 18, 2026
7ecb673
formatting
Feb 18, 2026
cb86ee3
fix gzip
Feb 18, 2026
11541af
add Sampler class for more training modularity
Feb 18, 2026
040036f
add cossim optimizer
Feb 18, 2026
33d1588
put gradient modifications in one pipeline
Feb 18, 2026
ca25934
cd, pcd and rdm sampler
Feb 18, 2026
56fc7a3
use new training function and script
Feb 18, 2026
9b533e9
fix some compatibility for restore
Feb 18, 2026
0f53b0f
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Feb 19, 2026
4db095b
put all saving methods with np array
Feb 19, 2026
475b7d9
change sampler call during training to get_grad_conf and add post_gra…
Feb 20, 2026
f81979d
add check keys dict + save sampler name
Feb 20, 2026
63a19ed
cleaning code
nbereux Feb 23, 2026
3e2231e
cleaning code
nbereux Feb 23, 2026
4dcf606
fix ordering abstractmethod property
nbereux Feb 23, 2026
34b5833
remove obsolete code
nbereux Feb 23, 2026
2e0441a
remove old parser
nbereux Feb 23, 2026
cce1977
pre_grad_update sampler and model
nbereux Feb 23, 2026
b04f8a1
various changes in ising_gauss
AurelienDecelle Feb 24, 2026
1c94ad3
rename var IGRBM
AurelienDecelle Feb 24, 2026
66d819c
corrected minor bugs for free-energy computation of IGRBM
AurelienDecelle Feb 24, 2026
02145de
add str option device
nbereux Feb 24, 2026
97e220a
Merge branch 'papier_ptt_train' of github.com:DsysDML/rbms into papie…
nbereux Feb 24, 2026
4770d9f
remove save method
nbereux Feb 24, 2026
997db0a
add kwargs metrics
nbereux Feb 24, 2026
eb0b20e
type hinting
Feb 25, 2026
737bed8
metrics save and display for sampler
Feb 25, 2026
0135092
remove unused import
Feb 25, 2026
b3e5b57
type hint
Feb 25, 2026
5329755
add some safeguard against None
Feb 25, 2026
f3ebeb9
correct type hint for map
Feb 25, 2026
476671a
regularization is now handled by pre_grad_update
Feb 25, 2026
35881d5
remove compilation
Feb 26, 2026
d6068e9
remove unused code + type hint
Feb 26, 2026
424976b
fix format of default names when splitting data
Mar 2, 2026
3f7e213
save parallel chains as bool/int
Mar 4, 2026
de94167
fix read training_type
Mar 4, 2026
4a7f0e0
update h5py version
Mar 4, 2026
2 changes: 1 addition & 1 deletion .github/workflows/codecov.yaml
@@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.14
python-version: 3.13

- name: Install test dependencies
run: pip install pytest pytest-cov
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.12, 3.13, 3.14]
python-version: [3.12, 3.13]
steps:
- name: Checkout
uses: actions/checkout@v4
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "rbms"
version = "0.5.0"
version = "0.6.0"
authors = [
{name="Nicolas Béreux", email="nicolas.bereux@gmail.com"},
{name="Aurélien Decelle"},
@@ -19,12 +19,12 @@ maintainers = [
]
description = "Training and analyzing Restricted Boltzmann Machines in PyTorch"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.12, <3.14"
dependencies = [
"h5py>=3.12.0",
"h5py>=3.14.0",
"numpy>=2.0.0",
"matplotlib>=3.8.0",
"torch>=2.5.0",
"torch>=2.10.0",
"tqdm>=4.65.0",
]

@@ -85,4 +85,4 @@ docstring-code-format = false
[dependency-groups]
dev = [
"pytest>=8.4.1",
]
]
42 changes: 42 additions & 0 deletions rbms/__init__.py
@@ -0,0 +1,42 @@
from rbms.bernoulli_bernoulli.classes import BBRBM
from rbms.bernoulli_gaussian.classes import BGRBM
from rbms.dataset import load_dataset
from rbms.dataset.utils import convert_data
from rbms.io import load_model, load_params
from rbms.ising_ising.classes import IIRBM
from rbms.map_model import map_model
from rbms.plot import plot_image, plot_mult_PCA
from rbms.potts_bernoulli.classes import PBRBM
from rbms.utils import (
bernoulli_to_ising,
compute_log_likelihood,
get_categorical_configurations,
get_eigenvalues_history,
get_flagged_updates,
get_saved_updates,
ising_to_bernoulli,
)

__all__ = [
BBRBM,
BGRBM,
IIRBM,
PBRBM,
map_model,
bernoulli_to_ising,
ising_to_bernoulli,
compute_log_likelihood,
get_eigenvalues_history,
get_saved_updates,
get_flagged_updates,
get_categorical_configurations,
plot_mult_PCA,
plot_image,
load_params,
load_model,
load_dataset,
convert_data,
]


__version__ = "0.5.1"
11 changes: 10 additions & 1 deletion rbms/bernoulli_bernoulli/__init__.py
@@ -1,3 +1,12 @@
# ruff: noqa
from rbms.bernoulli_bernoulli.classes import BBRBM
from rbms.bernoulli_bernoulli.functional import *
from rbms.bernoulli_bernoulli.functional import (
compute_energy,
compute_energy_hiddens,
compute_energy_visibles,
compute_gradient,
init_chains,
init_parameters,
sample_hiddens,
sample_visibles,
)
66 changes: 45 additions & 21 deletions rbms/bernoulli_bernoulli/classes.py
@@ -1,4 +1,4 @@
from typing import Self
from __future__ import annotations

import numpy as np
import torch
@@ -15,17 +15,20 @@
_sample_visibles,
)
from rbms.classes import RBM
from rbms.custom_fn import check_keys_dict


class BBRBM(RBM):
"""Parameters of the Bernoulli-Bernoulli RBM"""

visible_type: str = "bernoulli"

def __init__(
self,
weight_matrix: Tensor,
vbias: Tensor,
hbias: Tensor,
device: torch.device | None = None,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
):
"""Initialize the parameters of the Bernoulli-Bernoulli RBM.
@@ -49,6 +52,7 @@ def __init__(
self.vbias = vbias.to(device=self.device, dtype=self.dtype)
self.hbias = hbias.to(device=self.device, dtype=self.dtype)
self.name = "BBRBM"
self.flags = []

def __add__(self, other):
return BBRBM(
@@ -64,7 +68,9 @@ def __mul__(self, other):
hbias=self.hbias * other,
)

def clone(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
def clone(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
if device is None:
device = self.device
if dtype is None:
@@ -102,7 +108,7 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor:
weight_matrix=self.weight_matrix,
)

def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0):
def compute_gradient(self, data, chains, centered=True):
_compute_gradient(
v_data=data["visible"],
mh_data=data["hidden_mag"],
@@ -114,8 +120,6 @@ def compute_gradient(data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0):
hbias=self.hbias,
weight_matrix=self.weight_matrix,
centered=centered,
lambda_l1=lambda_l1,
lambda_l2=lambda_l2,
)

def independent_model(self):
@@ -159,25 +163,28 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001):
)
return BBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias)

def named_parameters(self):
def named_parameters(self) -> dict[str, np.ndarray]:
return {
"weight_matrix": self.weight_matrix,
"vbias": self.vbias,
"hbias": self.hbias,
"weight_matrix": self.weight_matrix.cpu().numpy(),
"vbias": self.vbias.cpu().numpy(),
"hbias": self.hbias.cpu().numpy(),
}

@property
def num_hiddens(self):
return self.hbias.shape[0]

@property
def num_visibles(self):
return self.vbias.shape[0]

def parameters(self) -> list[Tensor]:
return [self.weight_matrix, self.vbias, self.hbias]

@property
def ref_log_z(self):
return (
torch.log1p(torch.exp(self.vbias)).sum() + self.num_hiddens() * np.log(2)
torch.log1p(torch.exp(self.vbias)).sum() + self.num_hiddens * np.log(2)
).item()

def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]:
@@ -199,25 +206,33 @@ def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]:
return chains

@staticmethod
def set_named_parameters(named_params: dict[str, Tensor]) -> Self:
def set_named_parameters(
named_params: dict[str, np.ndarray],
device: torch.device | str,
dtype: torch.dtype,
) -> BBRBM:
names = ["vbias", "hbias", "weight_matrix"]
for k in names:
if k not in named_params.keys():
raise ValueError(
f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}"""
)
check_keys_dict(d=named_params, names=names)
params = BBRBM(
weight_matrix=named_params.pop("weight_matrix"),
vbias=named_params.pop("vbias"),
hbias=named_params.pop("hbias"),
weight_matrix=torch.from_numpy(named_params.pop("weight_matrix")).to(
device=device, dtype=dtype
),
vbias=torch.from_numpy(named_params.pop("vbias")).to(
device=device, dtype=dtype
),
hbias=torch.from_numpy(named_params.pop("hbias")).to(
device=device, dtype=dtype
),
)
if len(named_params.keys()) > 0:
raise ValueError(
f"Too many keys in params dictionary. Remaining keys: {named_params.keys()}"
)
return params

def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
if device is not None:
self.device = device
if dtype is not None:
@@ -226,3 +241,12 @@ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.vbias = self.vbias.to(device=self.device, dtype=self.dtype)
self.hbias = self.hbias.to(device=self.device, dtype=self.dtype)
return self

def get_metrics(self, metrics):
return metrics

def post_grad_update(self):
pass

def pre_grad_update(self):
pass
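The new set_named_parameters signature above delegates key validation to check_keys_dict and round-trips parameters through NumPy arrays. A minimal sketch of that validation pattern, reconstructed from the inline loop this diff removes (the real body of rbms.custom_fn.check_keys_dict may differ):

```python
import numpy as np


def check_keys_dict(d: dict, names: list[str]) -> None:
    # Hypothetical re-implementation based on the inline loop removed
    # from BBRBM.set_named_parameters in this diff.
    for k in names:
        if k not in d:
            raise ValueError(
                f"Dictionary params missing key '{k}'\n"
                f"Provided keys: {list(d.keys())}\n"
                f"Expected keys: {names}"
            )


# Saved parameters now travel as CPU NumPy arrays, mirroring the
# updated named_parameters(); shapes here are illustrative only.
named_params = {
    "weight_matrix": np.zeros((3, 2), dtype=np.float32),
    "vbias": np.zeros(3, dtype=np.float32),
    "hbias": np.zeros(2, dtype=np.float32),
}
check_keys_dict(d=named_params, names=["vbias", "hbias", "weight_matrix"])
```

Validating before the torch.from_numpy(...).to(device=..., dtype=...) conversion means a corrupted save file fails loudly with the missing key named, rather than with an opaque KeyError mid-restore.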
4 changes: 0 additions & 4 deletions rbms/bernoulli_bernoulli/functional.py
@@ -116,8 +116,6 @@ def compute_gradient(
chains: dict[str, Tensor],
params: BBRBM,
centered: bool = True,
lambda_l1: float = 0.0,
lambda_l2: float = 0.0,
) -> None:
"""Compute the gradient for each of the parameters and attach it.

@@ -140,8 +138,6 @@
hbias=params.hbias,
weight_matrix=params.weight_matrix,
centered=centered,
lambda_l1=lambda_l1,
lambda_l2=lambda_l2,
)


28 changes: 5 additions & 23 deletions rbms/bernoulli_bernoulli/implement.py
@@ -1,6 +1,5 @@
import torch
from torch import Tensor
from torch.nn.functional import softmax


@torch.jit.script
@@ -59,7 +58,7 @@ def _compute_energy_hiddens(
return -field - log_term.sum(1)


@torch.jit.script
# @torch.jit.script
def _compute_gradient(
v_data: Tensor,
mh_data: Tensor,
@@ -71,13 +70,11 @@ def _compute_gradient(
hbias: Tensor,
weight_matrix: Tensor,
centered: bool = True,
lambda_l1: float = 0.0,
lambda_l2: float = 0.0,
) -> None:
w_data = w_data.view(-1, 1)
w_chain = w_chain.view(-1, 1)
# Turn the weights of the chains into normalized weights
chain_weights = softmax(-w_chain, dim=0)
chain_weights = w_chain / w_chain.sum()
w_data_norm = w_data.sum()

# Averages over data and generated samples
@@ -102,33 +99,18 @@
grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean)
grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix)
else:
v_data_centered = v_data
h_data_centered = mh_data
v_gen_centered = v_chain
h_gen_centered = h_chain

# Gradient
grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - (
(v_chain * chain_weights).T @ h_chain
)
grad_vbias = v_data_mean - v_gen_mean
grad_hbias = h_data_mean - h_gen_mean

if lambda_l1 > 0:
grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix)
grad_vbias -= lambda_l1 * torch.sign(vbias)
grad_hbias -= lambda_l1 * torch.sign(hbias)

if lambda_l2 > 0:
grad_weight_matrix -= 2 * lambda_l2 * weight_matrix
grad_vbias -= 2 * lambda_l2 * vbias
grad_hbias -= 2 * lambda_l2 * hbias

# Attach to the parameters

weight_matrix.grad.set_(grad_weight_matrix)
vbias.grad.set_(grad_vbias)
hbias.grad.set_(grad_hbias)
weight_matrix.grad = grad_weight_matrix
vbias.grad = grad_vbias
hbias.grad = grad_hbias


@torch.jit.script
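Two behavioral changes stand out in _compute_gradient: chain weights are now normalized by their sum rather than passed through softmax(-w_chain), and gradients are attached with plain assignment (p.grad = g) instead of grad.set_(). A NumPy sketch of the reweighted negative-phase moment, with illustrative shapes rather than the library's actual API:

```python
import numpy as np

rng = np.random.default_rng(0)
v_chain = rng.random((4, 3))   # 4 chains, 3 visible units
h_chain = rng.random((4, 2))   # 4 chains, 2 hidden units
w_chain = rng.random((4, 1))   # one positive weight per chain

# New scheme in this PR: plain normalization, so weights sum to 1
chain_weights = w_chain / w_chain.sum()

# Weighted negative-phase second moment over the chains, matching
# (v_chain * chain_weights).T @ h_chain in the diff
neg_second_moment = (v_chain * chain_weights).T @ h_chain
```

Unlike softmax(-w_chain), this keeps each weight proportional to its raw value, which implicitly assumes w_chain is non-negative with a non-zero sum.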
12 changes: 12 additions & 0 deletions rbms/bernoulli_gaussian/__init__.py
@@ -0,0 +1,12 @@
# ruff: noqa
from rbms.bernoulli_gaussian.classes import BGRBM
from rbms.bernoulli_gaussian.functional import (
compute_energy,
compute_energy_hiddens,
compute_energy_visibles,
compute_gradient,
init_chains,
init_parameters,
sample_hiddens,
sample_visibles,
)