Skip to content

Commit 8163d21

Browse files
The generic framework for solving approximate inference problem with BP (#17)
* Add PyTorch BP docs, examples, and tests * Ignore Python cache files * Add extra BP tests and update README * Apply BP fixes and update tests * Add tests for BP edge cases and UAI errors
1 parent 7acf58e commit 8163d21

20 files changed

Lines changed: 1164 additions & 1 deletion

README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,42 @@ BPDecoderPlus/
186186
└── belief_propagation_qec_plan.tex
187187
```
188188

189+
## PyTorch BP Module (UAI)
190+
191+
This repository also includes a PyTorch implementation of belief propagation for
192+
UAI factor graphs under `src/bpdecoderplus/pytorch_bp`.
193+
194+
### Python Setup
195+
196+
```bash
197+
pip install -e .
198+
```
199+
200+
### Quick Example
201+
202+
```python
203+
from bpdecoderplus.pytorch_bp import (
204+
read_model_file,
205+
BeliefPropagation,
206+
belief_propagate,
207+
compute_marginals,
208+
)
209+
210+
model = read_model_file("examples/simple_model.uai")
211+
bp = BeliefPropagation(model)
212+
state, info = belief_propagate(bp)
213+
print(info)
214+
print(compute_marginals(state, bp))
215+
```
216+
217+
### Examples and Tests
218+
219+
```bash
220+
python examples/simple_example.py
221+
python examples/evidence_example.py
222+
pytest tests/test_bp_basic.py tests/test_uai_parser.py tests/test_integration.py tests/testcase.py
223+
```
224+
189225
## Available Decoders
190226

191227
| Decoder | Symbol | Description |

docs/api_reference.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
## PyTorch BP API Reference
2+
3+
This reference documents the public API exported from `bpdecoderplus.pytorch_bp`.
4+
5+
### UAI Parsing
6+
7+
- `read_model_file(path, factor_eltype=torch.float64) -> UAIModel`
8+
Parse a UAI `.uai` model file.
9+
10+
- `read_model_from_string(content, factor_eltype=torch.float64) -> UAIModel`
11+
Parse a UAI model from an in-memory string.
12+
13+
- `read_evidence_file(path) -> Dict[int, int]`
14+
Parse a UAI `.evid` file and return evidence as 1-based indices.
15+
16+
### Data Structures
17+
18+
- `Factor(vars: List[int], values: torch.Tensor)`
19+
Container for a factor scope and its tensor.
20+
21+
- `UAIModel(nvars: int, cards: List[int], factors: List[Factor])`
22+
Holds all model metadata for BP.
23+
24+
### Belief Propagation
25+
26+
- `BeliefPropagation(uai_model: UAIModel)`
27+
Builds factor graph adjacency for BP.
28+
29+
- `initial_state(bp: BeliefPropagation) -> BPState`
30+
Initialize messages to uniform vectors.
31+
32+
- `collect_message(bp, state, normalize=True)`
33+
Update factor-to-variable messages in place.
34+
35+
- `process_message(bp, state, normalize=True, damping=0.2)`
36+
Update variable-to-factor messages in place.
37+
38+
- `belief_propagate(bp, max_iter=100, tol=1e-6, damping=0.2, normalize=True)`
39+
Run the full BP loop and return `(BPState, BPInfo)`.
40+
41+
- `compute_marginals(state, bp) -> Dict[int, torch.Tensor]`
42+
Compute marginal distributions after convergence.
43+
44+
- `apply_evidence(bp, evidence: Dict[int, int]) -> BeliefPropagation`
45+
Return a new BP object with evidence applied to factor tensors.

docs/mathematical_description.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
## Belief Propagation (BP) Overview
2+
3+
This document summarizes the BP message-passing rules implemented in
4+
`src/bpdecoderplus/pytorch_bp/belief_propagation.py` for discrete factor graphs. The approach
5+
mirrors the tensor-contraction perspective used in TensorInference.jl.
6+
See https://github.com/TensorBFS/TensorInference.jl for the Julia reference.
7+
8+
### Factor Graph Notation
9+
10+
- Variables are indexed by x_i with domain size d_i.
11+
- Factors are indexed by f and connect a subset of variables.
12+
- Each factor has a tensor (potential) phi_f defined over its variables.
13+
14+
### Messages
15+
16+
Factor to variable message:
17+
18+
mu_{f->x}(x) = sum_{all y in ne(f), y != x} phi_f(x, y, ...) * product_{y != x} mu_{y->f}(y)
19+
20+
Variable to factor message:
21+
22+
mu_{x->f}(x) = product_{g in ne(x), g != f} mu_{g->x}(x)
23+
24+
### Damping
25+
26+
To improve stability on loopy graphs, a damping update is applied:
27+
28+
mu_new = damping * mu_old + (1 - damping) * mu_candidate
29+
30+
### Convergence
31+
32+
We use an L1 difference threshold between consecutive factor->variable
33+
messages to determine convergence.
34+
35+
### Marginals
36+
37+
After convergence, variable marginals are computed as:
38+
39+
b(x) = (1 / Z) * product_{f in ne(x)} mu_{f->x}(x)
40+
41+
The normalization constant Z is obtained by summing the unnormalized vector.

docs/usage_guide.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
## PyTorch Belief Propagation Usage
2+
3+
This guide shows how to parse a UAI file, run BP, and apply evidence.
4+
The implementation follows the tensor-contraction viewpoint in
5+
TensorInference.jl: https://github.com/TensorBFS/TensorInference.jl
6+
7+
### Quick Start
8+
9+
```python
10+
from bpdecoderplus.pytorch_bp import (
11+
read_model_file,
12+
BeliefPropagation,
13+
belief_propagate,
14+
compute_marginals,
15+
)
16+
17+
model = read_model_file("examples/simple_model.uai")
18+
bp = BeliefPropagation(model)
19+
state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1)
20+
print(info)
21+
22+
marginals = compute_marginals(state, bp)
23+
print(marginals[1])
24+
```
25+
26+
### Evidence
27+
28+
```python
29+
from bpdecoderplus.pytorch_bp import read_model_file, read_evidence_file, apply_evidence
30+
from bpdecoderplus.pytorch_bp import BeliefPropagation, belief_propagate, compute_marginals
31+
32+
model = read_model_file("examples/simple_model.uai")
33+
evidence = read_evidence_file("examples/simple_model.evid")
34+
bp = apply_evidence(BeliefPropagation(model), evidence)
35+
state, info = belief_propagate(bp)
36+
marginals = compute_marginals(state, bp)
37+
```
38+
39+
### Tips
40+
41+
- For loopy graphs, use damping between 0.1 and 0.5.
42+
- Normalize messages to avoid numerical underflow.
43+
- Use float64 for consistent comparisons in tests.

examples/evidence_example.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from bpdecoderplus.pytorch_bp import (
2+
read_model_file,
3+
read_evidence_file,
4+
BeliefPropagation,
5+
belief_propagate,
6+
compute_marginals,
7+
apply_evidence,
8+
)
9+
10+
11+
def main():
12+
model = read_model_file("examples/simple_model.uai")
13+
evidence = read_evidence_file("examples/simple_model.evid")
14+
bp = apply_evidence(BeliefPropagation(model), evidence)
15+
state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1)
16+
print(info)
17+
18+
marginals = compute_marginals(state, bp)
19+
for var_idx, marginal in marginals.items():
20+
print(f"Variable {var_idx} marginal: {marginal}")
21+
22+
23+
if __name__ == "__main__":
24+
main()

examples/simple_example.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from bpdecoderplus.pytorch_bp import (
2+
read_model_file,
3+
BeliefPropagation,
4+
belief_propagate,
5+
compute_marginals,
6+
)
7+
8+
9+
def main():
10+
model = read_model_file("examples/simple_model.uai")
11+
bp = BeliefPropagation(model)
12+
state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1)
13+
print(info)
14+
15+
marginals = compute_marginals(state, bp)
16+
for var_idx, marginal in marginals.items():
17+
print(f"Variable {var_idx} marginal: {marginal}")
18+
19+
20+
if __name__ == "__main__":
21+
main()

examples/simple_model.evid

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 0 1

examples/simple_model.uai

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
MARKOV
2+
2
3+
2 2
4+
2
5+
1 0
6+
2 0 1
7+
2
8+
0.6 0.4
9+
4
10+
0.9 0.1 0.2 0.8

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ classifiers = [
2323
dependencies = [
2424
"stim>=1.12.0",
2525
"numpy>=1.24.0",
26+
"torch>=2.0.0",
2627
]
2728

2829
[project.optional-dependencies]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
torch

0 commit comments

Comments
 (0)