Skip to content

Commit d3de0a0

Browse files
Use omeco ordering in tropical MPE.
Update docs/imports and improve validation coverage.
1 parent ba30f97 commit d3de0a0

File tree

15 files changed

+185
-94
lines changed

15 files changed

+185
-94
lines changed

tropical_in_new/README.md

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# Tropical Tensor Network for MPE
22

33
This folder contains an independent implementation of tropical tensor network
4-
contraction for Most Probable Explanation (MPE). It does not depend on the
5-
`bpdecoderplus` package; all code lives under `tropical_in_new/src`.
4+
contraction for Most Probable Explanation (MPE). It uses `omeco` for contraction
5+
order optimization and does not depend on the `bpdecoderplus` package; all code
6+
lives under `tropical_in_new/src`.
7+
8+
`omeco` provides high-quality contraction order heuristics (greedy and
9+
simulated annealing). Install it alongside Torch to run the examples and tests.
610

711
## Structure
812

@@ -20,7 +24,8 @@ tropical_in_new/
2024
├── tests/
2125
│ ├── test_primitives.py
2226
│ ├── test_contraction.py
23-
│ └── test_mpe.py
27+
│ ├── test_mpe.py
28+
│ └── test_utils.py
2429
├── examples/
2530
│ └── asia_network/
2631
│ ├── main.py
@@ -34,5 +39,13 @@ tropical_in_new/
3439
## Quick Start
3540

3641
```bash
42+
pip install -r tropical_in_new/requirements.txt
3743
python tropical_in_new/examples/asia_network/main.py
3844
```
45+
46+
## Notes on omeco
47+
48+
`omeco` is a Rust-backed Python package. If a prebuilt wheel is not available
49+
for your Python version, you will need a Rust toolchain with `cargo` on PATH to
50+
build it from source. See the omeco repository for details:
51+
https://github.com/GiggleLiu/omeco

tropical_in_new/docs/api_reference.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ Public APIs exported from `tropical_in_new/src`.
3434
- `build_network(factors: Iterable[Factor]) -> list[TensorNode]`
3535
Convert factors into log-domain tensors with scopes.
3636

37-
- `choose_order(nodes, heuristic="min_fill") -> list[int]`
38-
Select variable elimination order (min-fill / min-degree).
37+
- `choose_order(nodes, heuristic="omeco") -> list[int]`
38+
Select variable elimination order using `omeco`.
3939

4040
- `build_contraction_tree(order, nodes) -> ContractionTree`
4141
Construct a contraction plan from order and nodes.

tropical_in_new/docs/usage_guide.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,19 @@
33
This guide shows how to parse a UAI model and compute MPE using the
44
tropical tensor network implementation.
55

6+
### Install Dependencies
7+
8+
```bash
9+
pip install -r tropical_in_new/requirements.txt
10+
```
11+
12+
`omeco` provides contraction order optimization. If a prebuilt wheel is not
13+
available for your Python version, you may need a Rust toolchain installed.
14+
615
### Quick Start
716

817
```python
9-
from src import mpe_tropical, read_model_file
18+
from tropical_in_new.src import mpe_tropical, read_model_file
1019

1120
model = read_model_file("tropical_in_new/examples/asia_network/model.uai")
1221
assignment, score, info = mpe_tropical(model)
@@ -16,7 +25,7 @@ print(assignment, score, info)
1625
### Evidence
1726

1827
```python
19-
from src import mpe_tropical, read_model_file
28+
from tropical_in_new.src import mpe_tropical, read_model_file
2029

2130
model = read_model_file("tropical_in_new/examples/asia_network/model.uai")
2231
evidence = {1: 0} # variable index is 1-based

tropical_in_new/examples/asia_network/main.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
"""Run tropical MPE on a small UAI model."""
22

3-
from pathlib import Path
4-
import sys
5-
6-
ROOT = Path(__file__).resolve().parents[2]
7-
sys.path.insert(0, str(ROOT))
8-
9-
from src import mpe_tropical, read_model_file # noqa: E402
3+
from tropical_in_new.src import mpe_tropical, read_model_file
104

115

126
def main() -> None:

tropical_in_new/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch>=2.0.0
2-
numpy>=1.24.0
2+
omeco

tropical_in_new/src/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
from .mpe import mpe_tropical, recover_mpe_assignment
55
from .network import TensorNode, build_network
66
from .primitives import argmax_trace, safe_log, tropical_einsum
7-
from .utils import Factor, UAIModel, build_tropical_factors, read_model_file, read_model_from_string
7+
from .utils import (
8+
Factor,
9+
UAIModel,
10+
build_tropical_factors,
11+
read_evidence_file,
12+
read_model_file,
13+
read_model_from_string,
14+
)
815

916
__all__ = [
1017
"Factor",
@@ -17,6 +24,7 @@
1724
"choose_order",
1825
"contract_tree",
1926
"mpe_tropical",
27+
"read_evidence_file",
2028
"read_model_file",
2129
"read_model_from_string",
2230
"recover_mpe_assignment",

tropical_in_new/src/contraction.py

Lines changed: 79 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77

88
import torch
99

10-
try: # Optional heuristic provider
11-
import omeco # type: ignore
12-
except Exception: # pragma: no cover - optional dependency
13-
omeco = None
10+
import omeco
1411

1512
from .network import TensorNode
1613
from .primitives import Backpointer, tropical_reduce_max
@@ -45,59 +42,86 @@ class ContractionTree:
4542
nodes: Tuple[TensorNode, ...]
4643

4744

48-
def _build_var_graph(nodes: Iterable[TensorNode]) -> dict[int, set[int]]:
49-
graph: dict[int, set[int]] = {}
45+
def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]:
46+
sizes: dict[int, int] = {}
5047
for node in nodes:
51-
vars = list(node.vars)
48+
for var, dim in zip(node.vars, node.values.shape):
49+
if var in sizes and sizes[var] != dim:
50+
raise ValueError(
51+
f"Variable {var} has inconsistent sizes: {sizes[var]} vs {dim}."
52+
)
53+
sizes[var] = int(dim)
54+
return sizes
55+
56+
57+
def _extract_leaf_index(node_dict: dict) -> int | None:
58+
for key in ("leaf", "leaf_index", "index", "tensor"):
59+
if key in node_dict:
60+
value = node_dict[key]
61+
if isinstance(value, int):
62+
return value
63+
return None
64+
65+
66+
def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[int]:
67+
total_counts: dict[int, int] = {}
68+
for vars in ixs:
5269
for var in vars:
53-
graph.setdefault(var, set()).update(v for v in vars if v != var)
54-
return graph
55-
56-
57-
def _min_fill_order(graph: dict[int, set[int]]) -> list[int]:
58-
order: list[int] = []
59-
graph = {k: set(v) for k, v in graph.items()}
60-
while graph:
61-
best_var = None
62-
best_fill = None
63-
best_degree = None
64-
for var, neighbors in graph.items():
65-
fill = 0
66-
neighbor_list = list(neighbors)
67-
for i in range(len(neighbor_list)):
68-
for j in range(i + 1, len(neighbor_list)):
69-
if neighbor_list[j] not in graph[neighbor_list[i]]:
70-
fill += 1
71-
degree = len(neighbors)
72-
if best_fill is None or (fill, degree) < (best_fill, best_degree):
73-
best_var = var
74-
best_fill = fill
75-
best_degree = degree
76-
if best_var is None:
77-
break
78-
neighbors = list(graph[best_var])
79-
for i in range(len(neighbors)):
80-
for j in range(i + 1, len(neighbors)):
81-
graph[neighbors[i]].add(neighbors[j])
82-
graph[neighbors[j]].add(neighbors[i])
83-
for neighbor in neighbors:
84-
graph[neighbor].discard(best_var)
85-
graph.pop(best_var, None)
86-
order.append(best_var)
87-
return order
88-
89-
90-
def choose_order(nodes: list[TensorNode], heuristic: str = "min_fill") -> list[int]:
91-
"""Select elimination order over variable indices."""
92-
if heuristic == "omeco" and omeco is not None:
93-
if hasattr(omeco, "min_fill_order"):
94-
return list(omeco.min_fill_order([node.vars for node in nodes]))
95-
graph = _build_var_graph(nodes)
96-
if heuristic in ("min_fill", "omeco"):
97-
return _min_fill_order(graph)
98-
if heuristic == "min_degree":
99-
return sorted(graph, key=lambda v: len(graph[v]))
100-
raise ValueError(f"Unknown heuristic: {heuristic!r}")
70+
total_counts[var] = total_counts.get(var, 0) + 1
71+
72+
eliminated: set[int] = set()
73+
74+
def visit(node: dict) -> tuple[dict[int, int], list[int]]:
75+
leaf_index = _extract_leaf_index(node)
76+
if leaf_index is not None:
77+
counts: dict[int, int] = {}
78+
for var in ixs[leaf_index]:
79+
counts[var] = counts.get(var, 0) + 1
80+
return counts, []
81+
82+
children = node.get("children", [])
83+
if not isinstance(children, list) or not children:
84+
return {}, []
85+
86+
counts: dict[int, int] = {}
87+
order: list[int] = []
88+
for child in children:
89+
child_counts, child_order = visit(child)
90+
order.extend(child_order)
91+
for var, count in child_counts.items():
92+
counts[var] = counts.get(var, 0) + count
93+
94+
newly_eliminated = [
95+
var
96+
for var, count in counts.items()
97+
if count == total_counts.get(var, 0) and var not in eliminated
98+
]
99+
for var in sorted(newly_eliminated):
100+
eliminated.add(var)
101+
order.append(var)
102+
return counts, order
103+
104+
_, order = visit(tree_dict)
105+
remaining = sorted([var for var in total_counts if var not in eliminated])
106+
return order + remaining
107+
108+
109+
def choose_order(nodes: list[TensorNode], heuristic: str = "omeco") -> list[int]:
110+
"""Select elimination order over variable indices using omeco."""
111+
if heuristic != "omeco":
112+
raise ValueError("Only the 'omeco' heuristic is supported.")
113+
ixs = [list(node.vars) for node in nodes]
114+
sizes = _infer_var_sizes(nodes)
115+
method = omeco.GreedyMethod() if hasattr(omeco, "GreedyMethod") else None
116+
tree = (
117+
omeco.optimize_code(ixs, [], sizes, method)
118+
if method is not None
119+
else omeco.optimize_code(ixs, [], sizes)
120+
)
121+
tree_dict = tree.to_dict() if hasattr(tree, "to_dict") else tree
122+
if not isinstance(tree_dict, dict):
123+
raise ValueError("omeco.optimize_code did not return a usable tree.")
124+
return _elim_order_from_tree_dict(tree_dict, ixs)
101125

102126

103127
def build_contraction_tree(order: Iterable[int], nodes: list[TensorNode]) -> ContractionTree:

tropical_in_new/src/mpe.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ def recover_mpe_assignment(root) -> Dict[int, int]:
2929
"""Recover MPE assignment from a contraction tree with backpointers."""
3030
assignment: Dict[int, int] = {}
3131

32+
def require_vars(required: Iterable[int], available: Dict[int, int]) -> None:
33+
missing = [v for v in required if v not in available]
34+
if missing:
35+
raise KeyError(
36+
"Missing assignment values for variables: "
37+
f"{missing}. Provided assignment keys: {sorted(available.keys())}"
38+
)
39+
3240
def traverse(node, out_assignment: Dict[int, int]) -> None:
3341
assignment.update(out_assignment)
3442
if isinstance(node, TensorNode):
@@ -38,6 +46,7 @@ def traverse(node, out_assignment: Dict[int, int]) -> None:
3846
argmax_trace(node.backpointer, out_assignment) if node.backpointer else {}
3947
)
4048
combined = {**out_assignment, **elim_assignment}
49+
require_vars(node.child.vars, combined)
4150
child_assignment = {v: combined[v] for v in node.child.vars}
4251
traverse(node.child, child_assignment)
4352
return
@@ -46,7 +55,9 @@ def traverse(node, out_assignment: Dict[int, int]) -> None:
4655
argmax_trace(node.backpointer, out_assignment) if node.backpointer else {}
4756
)
4857
combined = {**out_assignment, **elim_assignment}
58+
require_vars(node.left.vars, combined)
4959
left_assignment = {v: combined[v] for v in node.left.vars}
60+
require_vars(node.right.vars, combined)
5061
right_assignment = {v: combined[v] for v in node.right.vars}
5162
traverse(node.left, left_assignment)
5263
traverse(node.right, right_assignment)
@@ -66,7 +77,7 @@ def mpe_tropical(
6677
factors = build_tropical_factors(model, evidence)
6778
nodes = build_network(factors)
6879
if order is None:
69-
order = choose_order(nodes, heuristic="min_fill")
80+
order = choose_order(nodes, heuristic="omeco")
7081
tree = build_contraction_tree(order, nodes)
7182
root = _contract_tree(tree, einsum_fn=tropical_einsum)
7283
if root.vars:

tropical_in_new/src/primitives.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,6 @@ class Backpointer:
3333
argmax_flat: torch.Tensor
3434

3535

36-
@dataclass(frozen=True)
37-
class TropicalTensor:
38-
"""Lightweight wrapper for tropical (max-plus) tensors."""
39-
40-
vars: Tuple[int, ...]
41-
values: torch.Tensor
42-
43-
def __add__(self, other: "TropicalTensor") -> "TropicalTensor":
44-
if self.vars != other.vars:
45-
raise ValueError("TropicalTensor.__add__ requires identical variable order.")
46-
return TropicalTensor(self.vars, torch.maximum(self.values, other.values))
47-
48-
4936
def safe_log(tensor: torch.Tensor) -> torch.Tensor:
5037
"""Convert potentials to log domain; zeros map to -inf."""
5138
neg_inf = torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device)
@@ -83,6 +70,12 @@ def tropical_reduce_max(
8370
if not elim_vars:
8471
return tensor, None
8572
target_vars = tuple(vars)
73+
missing_elim_vars = [v for v in elim_vars if v not in target_vars]
74+
if missing_elim_vars:
75+
raise ValueError(
76+
"tropical_reduce_max: elim_vars "
77+
f"{missing_elim_vars} are not present in vars {target_vars}."
78+
)
8679
elim_axes = [target_vars.index(v) for v in elim_vars]
8780
keep_axes = [i for i in range(len(target_vars)) if i not in elim_axes]
8881
perm = keep_axes + elim_axes
@@ -138,6 +131,12 @@ def argmax_trace(backpointer: Backpointer, assignment: Dict[int, int]) -> Dict[i
138131
if not backpointer.elim_vars:
139132
return {}
140133
if backpointer.out_vars:
134+
missing = [v for v in backpointer.out_vars if v not in assignment]
135+
if missing:
136+
raise KeyError(
137+
"Missing assignment values for output variables: "
138+
f"{missing}. Provided assignment keys: {sorted(assignment.keys())}"
139+
)
141140
idx = tuple(assignment[v] for v in backpointer.out_vars)
142141
flat = int(backpointer.argmax_flat[idx].item())
143142
else:

0 commit comments

Comments
 (0)