Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
a8f2751
rename ParameterNode to DataNode
mscroggs Jun 5, 2026
df6bf52
add evaluate method
mscroggs Jun 5, 2026
1d0738d
mypy?
mscroggs Jun 5, 2026
cbb10ea
ruff
mscroggs Jun 5, 2026
a956b83
typeerror
mscroggs Jun 5, 2026
b7b3332
rufff
mscroggs Jun 5, 2026
abbff6b
simplify some graph code
mscroggs Jun 24, 2026
43149c5
Update src/causalprog/graph/node/base.py
mscroggs Jun 25, 2026
6e68264
Update src/causalprog/graph/node/base.py
mscroggs Jun 25, 2026
8f64aa4
Update tests/test_algorithms/test_evaluate.py
mscroggs Jun 25, 2026
752005a
Apply suggestion from @mscroggs
mscroggs Jun 25, 2026
7f1db29
fix tests
mscroggs Jun 25, 2026
5ed79b5
ruff
mscroggs Jun 25, 2026
bb1b0f3
hasattr(..., "shape") instead of type check
mscroggs Jun 25, 2026
2c23e01
let KeyError happen rather than asserting types
mscroggs Jun 25, 2026
2815ff8
mypy?
mscroggs Jun 25, 2026
42a2844
currect type
mscroggs Jun 25, 2026
2e1eb1b
type: ignore
mscroggs Jun 25, 2026
dae7e6b
PGH003
mscroggs Jun 25, 2026
16c04df
Merge branch 'mscroggs/data-node' into mscroggs/rvnodes
mscroggs Jun 25, 2026
5717c16
typing
mscroggs Jun 25, 2026
209b86e
corrections
mscroggs Jun 25, 2026
06d0f69
add check that labels are valid variable names
mscroggs Jun 25, 2026
7700ea5
ruff
mscroggs Jun 25, 2026
29f7679
update tests and labels
mscroggs Jun 25, 2026
ac099fe
Add random variable nodes
mscroggs Jun 25, 2026
30c0216
ruff
mscroggs Jun 25, 2026
13ad547
ruff
mscroggs Jun 25, 2026
76f7fb1
add evaluate algorithm
mscroggs Jun 26, 2026
ad4ccb7
start adding example graph
mscroggs Jun 26, 2026
dc7b643
add ricardo's graph (!)
mscroggs Jun 26, 2026
ccb5f0d
add test
mscroggs Jun 26, 2026
01017bc
ruff
mscroggs Jun 26, 2026
f19ae94
Merge branch 'main' into mscroggs/evaluate
mscroggs Jun 26, 2026
9fc5383
values not keys for parents, update some tests
mscroggs Jun 26, 2026
c7e5e88
sort toml
mscroggs Jun 26, 2026
63d0b7c
ruff
mscroggs Jun 26, 2026
83a1b51
Update src/causalprog/algorithms/evaluate.py
mscroggs Jun 29, 2026
6b4ec2e
correct
mscroggs Jun 29, 2026
9d65fcc
move graph to graph.special and rearrange tests
mscroggs Jun 29, 2026
b639584
add new file
mscroggs Jun 29, 2026
b06b6f5
add evaluate_down_to variant
mscroggs Jun 29, 2026
89fd2b2
Update src/causalprog/algorithms/__init__.py
mscroggs Jun 29, 2026
18432f7
RUFF
mscroggs Jun 29, 2026
6b7ba9a
Merge branch 'mscroggs/evaluate' of github.com:UCL/causalprog into ms…
mscroggs Jun 29, 2026
a512b2d
Parametrize evaluate tests
mscroggs Jun 29, 2026
1b92a35
don't copy big arrays
mscroggs Jun 29, 2026
20da685
mypy
mscroggs Jun 29, 2026
c1ed858
get doesn't work here as it evaluates arguments BEFORE checking if ke…
mscroggs Jun 29, 2026
c0c54f8
Add test of evaluate_down_to
mscroggs Jun 29, 2026
5cf59ec
RUFF
mscroggs Jun 29, 2026
5f4a080
type
mscroggs Jun 29, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ lint.per-file-ignores = {"__init__.py" = [
"ANN",
"D",
"INP001", # File is part of an implicit namespace package.
"N803", # Argument name should be lowercase
"PLR0913", # Too many arguments in function definition
"PLR2004", # Magic value used in comparison
"PT028", # Test function parameters have default arguments
"S101", # Use of `assert` detected
]}
lint.select = ["ALL"]
Expand Down
10 changes: 10 additions & 0 deletions src/causalprog/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
"""Algorithms."""

from .do import do
from .evaluate import evaluate, evaluate_down_to
from .moments import expectation, moment, standard_deviation
Comment thread
mscroggs marked this conversation as resolved.

__all__ = (
"do",
"evaluate",
"evaluate_down_to",
"expectation",
"moment",
"standard_deviation",
)
3 changes: 1 addition & 2 deletions src/causalprog/algorithms/do.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,8 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph
)
raise ValueError(msg)

nodes[node] = ConstantNode(label=node, value=value)

g = Graph(label=f"{label}_do_{node}__" + f"{value}".replace(".", "_"))
g.add_node(ConstantNode(label=node, value=value))
for n in nodes.values():
g.add_node(n)

Expand Down
51 changes: 51 additions & 0 deletions src/causalprog/algorithms/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Algorithms for evaluating a graph node."""

import numpy.typing as npt

from causalprog.graph import Graph


def evaluate_down_to(
graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float]
) -> dict[str, float | npt.NDArray[float]]:
"""
Evaluate all nodes down to a particular node.

Args:
graph: The graph that the node is contained in.
outcome_node_label: The label of the node to evaluate down to.
values: Values taken by nodes whose value is given

Returns:
A dictionary of the values of all the nodes that are ancestors of the input node

"""
computed_values: dict[str, float | npt.NDArray[float]] = {}
nodes_to_evaluate = [
n
for n in graph.roots_down_to_outcome(outcome_node_label)
if n.label not in values
]
for node in nodes_to_evaluate:
computed_values[node.label] = node.evaluate(**values, **computed_values)
return computed_values


def evaluate(
graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float]
Comment thread
mscroggs marked this conversation as resolved.
) -> float | npt.NDArray[float]:
"""
Evaluate a node.

Args:
graph: The graph that the node is contained in.
outcome_node_label: The label of the node to evaluate.
values: Values taken by nodes whose value is given

Returns:
The evaluation of the node

"""
if outcome_node_label in values:
return values[outcome_node_label]
return evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label]
4 changes: 4 additions & 0 deletions src/causalprog/causal_problem/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,7 @@ def __eq__(self, other: object) -> bool:
and self.handler is other.handler
and self.options == other.options
)

def __hash__(self) -> int:
"""Hash."""
return hash((self.handler, self.options))
13 changes: 5 additions & 8 deletions src/causalprog/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy.typing as npt

from causalprog._abc.labelled import Labelled
from causalprog.graph.node import ComponentNode, DistributionNode, Node
from causalprog.graph.node import DataNode, DistributionNode, Node


class Graph(Labelled):
Expand Down Expand Up @@ -60,11 +60,8 @@ def add_node(self, node: Node) -> None:
raise ValueError(msg)
self._nodes_by_label[node.label] = node
self._graph.add_node(node)
if isinstance(node, ComponentNode):
if len(node.parents) != 1:
msg = "ComponentNode should have exactly one parent."
raise ValueError(msg)
self.add_edge(node.parents[0], node.label)
for p in node.parents:
self.add_edge(p, node.label)

def add_edge(self, start_node: Node | str, end_node: Node | str) -> None:
"""
Expand Down Expand Up @@ -241,8 +238,8 @@ def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]:

"""
# Confirm that all `DataNode`s have been assigned a value.
for node in self.root_nodes:
if node.label not in parameter_values:
for node in self.nodes:
if isinstance(node, DataNode) and node.label not in parameter_values:
msg = f"DataNode '{node.label}' not assigned"
raise KeyError(msg)

Expand Down
4 changes: 2 additions & 2 deletions src/causalprog/graph/node/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(

def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node:
"""Get a component of this node."""
from causalprog.graph import ComponentNode

if isinstance(indices, int | slice):
indices = (indices,)
if not isinstance(indices, tuple):
Expand All @@ -79,8 +81,6 @@ def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node:
e = "list index out of range"
raise IndexError(e)

from causalprog.graph import ComponentNode

shape: tuple[int, ...] = ()
for i, s in zip(indices, self._shape, strict=False):
if isinstance(i, slice):
Expand Down
2 changes: 1 addition & 1 deletion src/causalprog/graph/node/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __repr__(self) -> str:
@override
@property
def parents(self) -> list[str]:
return [*self._parameters.keys(), *self._constant_parameters.keys()]
return list(self._parameters.values())

def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike:
"""
Expand Down
80 changes: 80 additions & 0 deletions src/causalprog/graph/special.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Functions to create example graphs."""

import inspect
from collections.abc import Callable

from causalprog.graph import (
ContinuousRandomVariableNode,
DataNode,
DiscreteRandomVariableNode,
Graph,
)


def example_model(
Comment thread
mscroggs marked this conversation as resolved.
*,
label: str = "G",
l_len: int = 1,
z_len: int = 1,
k: int = 10,
compute_u_x: Callable,
compute_u_y: Callable,
compute_phi_x: Callable,
compute_x: Callable,
compute_y: Callable,
Comment thread
willGraham01 marked this conversation as resolved.
) -> Graph:
"""
Create a graph representing the example model.

Args:
label: The label of the graph.
l_len: The number of entries in the vector data node L.
z_len: The number of entries in the vector data node Z.
k: The maximum value that could be taken by the mixture indicator C.
compute_u_x: Compute UX given the value of C.
compute_u_y: Compute UY given the value of C.
compute_phi_x: Compute PhiX given the value of L.
compute_x: Compute X given the values of Z, PhiX and UX.
compute_y: Compute Y given the values of X and UY.

Returns:
A graph

"""
for p in inspect.signature(compute_u_x).parameters:
if p != "C":
msg = f"Invalid input to UX: {p}"
raise ValueError(msg)
for p in inspect.signature(compute_u_y).parameters:
if p != "C":
msg = f"Invalid input to UX: {p}"
raise ValueError(msg)
for p in inspect.signature(compute_phi_x).parameters:
if p != "L":
msg = f"Invalid input to PhiX: {p}"
raise ValueError(msg)
for p in inspect.signature(compute_x).parameters:
if p not in ["Z", "PhiX", "UX"]:
msg = f"Invalid input to X: {p}"
raise ValueError(msg)
for p in inspect.signature(compute_y).parameters:
if p not in ["X", "UY"]:
msg = f"Invalid input to X: {p}"
raise ValueError(msg)
Comment thread
mscroggs marked this conversation as resolved.

graph = Graph(label=label)

graph.add_node(DataNode(label="L", shape=(l_len,)))
graph.add_node(DataNode(label="Z", shape=(z_len,)))
graph.add_node(
DiscreteRandomVariableNode(
label="C", values=[float(i) for i in range(1, k + 1)]
)
)
graph.add_node(ContinuousRandomVariableNode(label="UX", compute=compute_u_x))
graph.add_node(ContinuousRandomVariableNode(label="UY", compute=compute_u_y))
graph.add_node(ContinuousRandomVariableNode(label="PhiX", compute=compute_phi_x))
graph.add_node(ContinuousRandomVariableNode(label="X", compute=compute_x))
graph.add_node(ContinuousRandomVariableNode(label="Y", compute=compute_y))

return graph
6 changes: 3 additions & 3 deletions tests/test_algorithms/test_do.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def test_do(two_normal_graph):
graph = two_normal_graph(5.0, 1.2, 0.8)
graph2 = algorithms.do(graph, "UX", 4.0)

assert "loc" in graph.get_node("X").parents
assert "loc" in graph2.get_node("X").parents
assert "UX" in graph.get_node("X").parents
assert "UX" in graph2.get_node("X").parents

graph.get_node("UX")
assert not isinstance(graph.get_node("UX"), ConstantNode)
assert isinstance(graph2.get_node("UX"), ConstantNode)


Expand Down
117 changes: 89 additions & 28 deletions tests/test_algorithms/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,109 @@
"""Tests for evaluate algorithms."""

import numpy as np
import jax.numpy as jnp
import pytest
from jax import Array

from causalprog.graph import ComponentNode, DataNode, DistributionNode
from causalprog.algorithms import evaluate, evaluate_down_to
from causalprog.graph import Graph
from causalprog.graph.special import example_model


@pytest.fixture
def evaluate_test_graph() -> Graph:
return example_model(
z_len=2,
compute_u_x=lambda C: C,
compute_u_y=lambda C: C + 1,
compute_phi_x=lambda L: L[0],
compute_x=lambda Z, PhiX, UX: Z[0] + UX - PhiX,
compute_y=lambda X, UY: X * UY,
)


@pytest.mark.parametrize(
("node", "kwargs_to_evaluate", "expected_result"),
("outcome_node_label", "initial_values", "expected_result"),
[
pytest.param(
DataNode(label="A"), {"A": 2.0}, 2.0, id="Evaluate DataNode itself"
"L",
{"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0},
{},
id="DataNode evaluation w/ excess information provided",
),
pytest.param(
ComponentNode("Parent", 1, label="Child"),
{"Parent": np.arange(4)},
1.0,
id="Evaluate ComponentNode, given parent",
"Z",
{"Z": jnp.array([2.0, 0.0])},
{},
id="DataNode evaluation",
),
],
)
def test_evaluate_node(node, kwargs_to_evaluate, expected_result):
assert np.allclose(node.evaluate(**kwargs_to_evaluate), expected_result)


@pytest.mark.parametrize(
("node", "kwargs_for_evaluate", "expected_error"),
[
pytest.param(
DataNode(label="A"),
"C",
{"C": 4.0},
{},
ValueError("Missing input for node: A"),
id="DataNode missing input value",
id="DiscreteRVNode evaluation",
),
pytest.param(
"UX",
{"C": 4.0},
{"UX": 4.0},
id="CtsRVNode evaluation",
),
pytest.param(
DistributionNode(distribution=None, label="A"),
"UX",
{"C": 4.0, "UX": 1.0},
{},
RuntimeError("Cannot evaluate a DistributionNode"),
id="Attempt to evaluate DistributionNode",
id="CtsRVNode evaluation, 'given that' overrides computed value",
),
pytest.param(
"UY",
{"C": 4.0},
{"UY": 5.0},
id="CtsRVNode evaluation, with parents that need evaluating",
),
pytest.param(
"X",
{"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0},
{"UX": 4.0, "PhiX": 5.5, "X": 0.5},
id="Multiple paths from different root nodes",
),
pytest.param(
"X",
{"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0, "PhiX": 0.0},
{"UX": 4.0, "X": 6.0},
id="Multiple paths from different root nodes, with some given values",
),
pytest.param(
"Y",
{"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0},
{"UX": 4.0, "UY": 5.0, "PhiX": 5.5, "X": 0.5, "Y": 2.5},
id="Evaluating the 'outcome' node.",
),
],
)
def test_evaluate_node_fail_on_missing_data(
node, kwargs_for_evaluate, expected_error, raises_context
):
with raises_context(expected_error):
node.evaluate(**kwargs_for_evaluate)
def test_evaluate(
evaluate_test_graph: Graph,
outcome_node_label: str,
initial_values: dict[str, Array],
expected_result: dict[str, Array],
) -> None:
computed_result = evaluate_down_to(
evaluate_test_graph, outcome_node_label, **initial_values
)

# Same number of entries
assert len(expected_result) == len(computed_result)
# Keys are correct
assert set(expected_result.keys()) == set(computed_result.keys())

# All entries match to acceptable precision for floats
for node_label, computed_value in computed_result.items():
assert jnp.allclose(computed_value, expected_result[node_label])

# Just asking for one value did indeed extract the correct node value
computed_result_single = evaluate(
evaluate_test_graph, outcome_node_label, **initial_values
)
if outcome_node_label in initial_values:
assert jnp.allclose(computed_result_single, initial_values[outcome_node_label])
else:
assert jnp.allclose(computed_result_single, computed_result[outcome_node_label])
2 changes: 1 addition & 1 deletion tests/test_graph/test_ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_roots_down_to_outcome() -> None:
graph.get_node("W"),
)
nodes = graph.roots_down_to_outcome("Z")
assert len(nodes) == 5 # noqa: PLR2004
assert len(nodes) == 5
for e in edges:
if "W" not in e:
assert nodes.index(graph.get_node(e[0])) < nodes.index(graph.get_node(e[1]))
Expand Down
Loading
Loading