From a8f2751fd471f96a662413580870debda747929a Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 5 Jun 2026 11:06:21 +0100 Subject: [PATCH 01/49] rename ParameterNode to DataNode --- examples/two_normal.ipynb | 12 +++++------ src/causalprog/graph/__init__.py | 2 +- src/causalprog/graph/graph.py | 10 ++++----- src/causalprog/graph/node/__init__.py | 2 +- .../graph/node/{parameter.py => data.py} | 21 +++++++------------ tests/fixtures/graph.py | 20 ++++++++---------- tests/test_algorithms/test_do.py | 10 ++++----- .../test_associate_models_to_components.py | 4 ++-- tests/test_graph/test_model.py | 10 ++++----- tests/test_graph/test_ordering.py | 14 ++++++------- tests/test_graph/test_parameters.py | 8 +++---- 11 files changed, 52 insertions(+), 61 deletions(-) rename src/causalprog/graph/node/{parameter.py => data.py} (62%) diff --git a/examples/two_normal.ipynb b/examples/two_normal.ipynb index 73fe43d..0123037 100644 --- a/examples/two_normal.ipynb +++ b/examples/two_normal.ipynb @@ -102,9 +102,9 @@ "\n", "For each parameter and random variable (RV) in our causal problem, we need a node to represent it.\n", "\n", - "1. **Model Parameters** are represented with `ParameterNode`s. The model parameters are the set of variables that fully parametrise the (RVs that appear in the) DAG / causal problem.\n", + "1. **Model Parameters** are represented with `DataNode`s. The model parameters are the set of variables that fully parametrise the (RVs that appear in the) DAG / causal problem.\n", " - In our example, these are the values $\\mu_X$ and $\\nu_Y$.\n", - " - For each of these, we add a `ParameterNode`. \n", + " - For each of these, we add a `DataNode`. \n", " - Model parameters that are referenced in the `parameters` dictionary of a `DistributionNode` will be used when constructing the `DistributionNode`'s RV.\n", "\n", "2. **Derived (RV) Parameters** are parameters of RVs that are the result of sampling from a previous distribution, or take the value of a model parameter.\n", @@ -126,9 +126,9 @@ "source": [ "from numpyro.distributions import Normal\n", "\n", - "from causalprog.graph import DistributionNode, ParameterNode\n", + "from causalprog.graph import DistributionNode, DataNode\n", "\n", - "graph.add_node(ParameterNode(label=\"mu_X\"))\n", + "graph.add_node(DataNode(label=\"mu_X\"))\n", "graph.add_node(\n", " DistributionNode(\n", " distribution=Normal,\n", @@ -138,7 +138,7 @@ " )\n", ")\n", "\n", - "graph.add_node(ParameterNode(label=\"nu_Y\"))\n", + "graph.add_node(DataNode(label=\"nu_Y\"))\n", "graph.add_node(\n", " DistributionNode(\n", " distribution=Normal,\n", @@ -162,7 +162,7 @@ "\n", "Note that edges should be directed **into** the dependent RV.\n", "That is, `DistributionNode`s should have edges directed into them from the nodes referenced by their derived parameters.\n", - "These nodes may be either `ParameterNode`s; for example our RV $X$ will need an edge from the `ParameterNode` for $\\mu_X$ into the `DistributionNode` representing $X$.\n", + "These nodes may be either `DataNode`s; for example our RV $X$ will need an edge from the `DataNode` for $\\mu_X$ into the `DistributionNode` representing $X$.\n", "They may also be other `DistributionNode`s; for example, $Y$ has $X$ as a derived parameter since we know that $Y \\vert X \\sim \\mathcal{N}(X, \\nu_Y)$." ] }, diff --git a/src/causalprog/graph/__init__.py b/src/causalprog/graph/__init__.py index f1358ab..1fa4396 100644 --- a/src/causalprog/graph/__init__.py +++ b/src/causalprog/graph/__init__.py @@ -1,4 +1,4 @@ """Creation and storage of graphs.""" from .graph import Graph -from .node import ComponentNode, DistributionNode, Node, ParameterNode +from .node import ComponentNode, DistributionNode, Node, DataNode diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index de565fd..e252bd0 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -209,7 +209,7 @@ def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]: The model created takes values of the nodes that are parameter as keyword arguments. Names of the keyword arguments should match the labels of the - `ParameterNode`s, and their values should be the values of those parameters. + `DataNode`s, and their values should be the values of those parameters. The method returns a dictionary recording the mode sites that are created. This means that the model can be 'extended' further by defining additional @@ -217,18 +217,18 @@ def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]: Args: parameter_values: Names of the keyword arguments should match the labels - of the `ParameterNode`s, and their values should be the values of those + of the `DataNode`s, and their values should be the values of those parameters. Returns: - Mapping of non-`ParameterNode` `Node` labels to the site objects created + Mapping of non-`DataNode` `Node` labels to the site objects created for these nodes. """ - # Confirm that all `ParameterNode`s have been assigned a value. + # Confirm that all `DataNode`s have been assigned a value. for node in self.parameter_nodes: if node.label not in parameter_values: - msg = f"ParameterNode '{node.label}' not assigned" + msg = f"DataNode '{node.label}' not assigned" raise KeyError(msg) # Build model sequentially, using the node_order to inform the diff --git a/src/causalprog/graph/node/__init__.py b/src/causalprog/graph/node/__init__.py index 13e66b5..f2a1866 100644 --- a/src/causalprog/graph/node/__init__.py +++ b/src/causalprog/graph/node/__init__.py @@ -3,4 +3,4 @@ from .base import Node from .component import ComponentNode from .distribution import DistributionNode -from .parameter import ParameterNode +from .data import DataNode diff --git a/src/causalprog/graph/node/parameter.py b/src/causalprog/graph/node/data.py similarity index 62% rename from src/causalprog/graph/node/parameter.py rename to src/causalprog/graph/node/data.py index 1bcceb7..6b701d3 100644 --- a/src/causalprog/graph/node/parameter.py +++ b/src/causalprog/graph/node/data.py @@ -1,4 +1,4 @@ -"""Graph nodes representing parameters.""" +"""Graph nodes representing known of unknown data.""" import jax import jax.numpy as jnp @@ -8,18 +8,11 @@ from .base import Node -class ParameterNode(Node): +class DataNode(Node): """ - A node containing a parameter. + A node containing non-stochastic data. - `ParameterNode`s differ from `DistributionNode`s in that they do not have an - attached distribution, but rather represent a parameter that contributes - to the shape of one (or more) `DistributionNode`s. - - The collection of parameters described by `ParameterNode`s forms the set of - variables that will be optimised over in the corresponding `CausalProblem`. - - `ParameterNode`s should not be used to encode constant values used by + `DataNode`s should not be used to encode constant values used by `DistributionNode`s. Such constant values should be given to the necessary `DistributionNode`s directly as `constant_parameters`. """ @@ -44,17 +37,17 @@ def sample( rng_key: jax.Array, ) -> npt.ArrayLike: if self.label not in parameter_values: - msg = f"Missing input for parameter node: {self.label}." + msg = f"Missing input for node: {self.label}." raise ValueError(msg) return jnp.full(samples, parameter_values[self.label]) @override def copy(self) -> Node: - return ParameterNode(label=self.label) + return DataNode(label=self.label) @override def __repr__(self) -> str: - return f'ParameterNode(label="{self.label}")' + return f'DataNode(label="{self.label}")' @override @property diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 62143d6..f7e93ae 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -8,19 +8,17 @@ import pytest from numpyro.distributions import Normal -from causalprog.graph import DistributionNode, Graph, ParameterNode +from causalprog.graph import DistributionNode, Graph, DataNode NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "outcome"] -NormalGraphNodes: TypeAlias = dict[ - NormalGraphNodeNames, DistributionNode | ParameterNode -] +NormalGraphNodes: TypeAlias = dict[NormalGraphNodeNames, DistributionNode | DataNode] @pytest.fixture def normal_graph() -> Callable[[float, float], Graph]: """Creates a graph with one normal distribution X. - Parameter nodes are included if no values are given for the mean and covariance. + Data nodes are included if no values are given for the mean and covariance. """ def _inner(mean: float | None = None, cov: float | None = None): @@ -28,12 +26,12 @@ def _inner(mean: float | None = None, cov: float | None = None): parameters = {} constant_parameters = {} if mean is None: - graph.add_node(ParameterNode(label="mean")) + graph.add_node(DataNode(label="mean")) parameters["loc"] = "mean" else: constant_parameters["loc"] = mean if cov is None: - graph.add_node(ParameterNode(label="cov")) + graph.add_node(DataNode(label="cov")) parameters["scale"] = "cov" else: constant_parameters["scale"] = cov @@ -61,7 +59,7 @@ def two_normal_graph() -> Callable[[float, float, float], Graph]: where UX is a normal distribution with mean `mean` and covariance `cov`, and X is a normal distrubution with mean UX and covariance `cov2`. - Parameter nodes are included if no values are given for the mean and covariances. + Data nodes are included if no values are given for the mean and covariances. """ @@ -75,17 +73,17 @@ def _inner( ux_parameters = {} ux_constant_parameters = {} if mean is None: - graph.add_node(ParameterNode(label="mean")) + graph.add_node(DataNode(label="mean")) ux_parameters["loc"] = "mean" else: ux_constant_parameters["loc"] = mean if cov is None: - graph.add_node(ParameterNode(label="cov")) + graph.add_node(DataNode(label="cov")) ux_parameters["scale"] = "cov" else: ux_constant_parameters["scale"] = cov if cov2 is None: - graph.add_node(ParameterNode(label="cov2")) + graph.add_node(DataNode(label="cov2")) x_parameters["scale"] = "cov2" else: x_constant_parameters["scale"] = cov2 diff --git a/tests/test_algorithms/test_do.py b/tests/test_algorithms/test_do.py index a13f496..1599ffe 100644 --- a/tests/test_algorithms/test_do.py +++ b/tests/test_algorithms/test_do.py @@ -1,7 +1,7 @@ """Tests for the do algorithm.""" from causalprog import algorithms -from causalprog.graph import Graph, ParameterNode +from causalprog.graph import Graph, DataNode max_samples = 10**5 @@ -56,10 +56,10 @@ def test_do_edges(two_normal_graph): def test_do_error(raises_context): graph = Graph(label="ABC") - graph.add_node(ParameterNode(label="A")) - graph.add_node(ParameterNode(label="B1")) - graph.add_node(ParameterNode(label="B2")) - graph.add_node(ParameterNode(label="C")) + graph.add_node(DataNode(label="A")) + graph.add_node(DataNode(label="B1")) + graph.add_node(DataNode(label="B2")) + graph.add_node(DataNode(label="C")) graph.add_edge("A", "B1") graph.add_edge("A", "B2") graph.add_edge("B1", "C") diff --git a/tests/test_causal_problem/test_associate_models_to_components.py b/tests/test_causal_problem/test_associate_models_to_components.py index 631f44b..caac467 100644 --- a/tests/test_causal_problem/test_associate_models_to_components.py +++ b/tests/test_causal_problem/test_associate_models_to_components.py @@ -21,7 +21,7 @@ Constraint, HandlerToApply, ) -from causalprog.graph import Graph, ParameterNode +from causalprog.graph import Graph, DataNode @pytest.fixture @@ -30,7 +30,7 @@ def underlying_graph() -> Graph: so we just return a single node graph. """ g = Graph(label="Placeholder") - g.add_node(ParameterNode(label="p")) + g.add_node(DataNode(label="p")) return g diff --git a/tests/test_graph/test_model.py b/tests/test_graph/test_model.py index 26e0b04..db44ec0 100644 --- a/tests/test_graph/test_model.py +++ b/tests/test_graph/test_model.py @@ -2,7 +2,7 @@ import numpyro import pytest -from causalprog.graph import DistributionNode, Graph, ParameterNode +from causalprog.graph import DistributionNode, Graph, DataNode @pytest.mark.parametrize( @@ -21,7 +21,7 @@ def test_model( ) -> None: """Test the `Graph.model` method. - `Graph.model` takes values for the `ParameterNode`s (parameters of the model) + `Graph.model` takes values for the `DataNode`s (parameters of the model) as its arguments. It is designed to be able to be used just like any other function defining a model, namely that `Graph.model(**parameter_values)` is a function that creates the appropriate model sites, given values for the @@ -52,14 +52,14 @@ def test_model_missing_parameter( seed: int, ) -> None: """`Graph.model` will raise a `KeyError` when a value is not passed for - a `ParameterNode`. + a `DataNode`. """ graph = two_normal_graph(cov=1.0) # Deliberately leave out the "cov2" variable. parameter_values = {"mean": 0.0} # Which should result in the error below. - expected_exception = KeyError("ParameterNode 'cov2' not assigned") + expected_exception = KeyError("DataNode 'cov2' not assigned") # Not passing enough parameters should be picked up by the model. with raises_context(expected_exception), numpyro.handlers.seed(rng_seed=seed): @@ -77,7 +77,7 @@ def test_model_extension( parameter_values = {"mean": 0.0, "cov2": 1.0} # Build the graph, but without the X-node. - mean = ParameterNode(label="mean") + mean = DataNode(label="mean") x = DistributionNode( numpyro.distributions.Normal, label="UX", diff --git a/tests/test_graph/test_ordering.py b/tests/test_graph/test_ordering.py index 57ba0a1..2ef804d 100644 --- a/tests/test_graph/test_ordering.py +++ b/tests/test_graph/test_ordering.py @@ -1,17 +1,17 @@ """Tests for ordering of nodes in a graph.""" -from causalprog.graph import Graph, ParameterNode +from causalprog.graph import Graph, DataNode def test_roots_down_to_outcome() -> None: graph = Graph(label="G0") - graph.add_node(ParameterNode(label="U")) - graph.add_node(ParameterNode(label="V")) - graph.add_node(ParameterNode(label="W")) - graph.add_node(ParameterNode(label="X")) - graph.add_node(ParameterNode(label="Y")) - graph.add_node(ParameterNode(label="Z")) + graph.add_node(DataNode(label="U")) + graph.add_node(DataNode(label="V")) + graph.add_node(DataNode(label="W")) + graph.add_node(DataNode(label="X")) + graph.add_node(DataNode(label="Y")) + graph.add_node(DataNode(label="Z")) edges = [ ["V", "W"], diff --git a/tests/test_graph/test_parameters.py b/tests/test_graph/test_parameters.py index 589ddbe..8712ee9 100644 --- a/tests/test_graph/test_parameters.py +++ b/tests/test_graph/test_parameters.py @@ -2,13 +2,13 @@ import jax.numpy as jnp -from causalprog.graph import ParameterNode +from causalprog.graph import DataNode -def test_parameter_node(rng_key, raises_context): - node = ParameterNode(label="mu") +def test_data_node(rng_key, raises_context): + node = DataNode(label="mu") - with raises_context(ValueError("Missing input for parameter")): + with raises_context(ValueError("Missing input for node")): node.sample({}, {}, 1, rng_key=rng_key) assert jnp.allclose( From df6bf524073e348093ced4d825041693863955d0 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 5 Jun 2026 11:31:47 +0100 Subject: [PATCH 02/49] add evaluate method --- src/causalprog/graph/node/base.py | 16 ++++++++++++++++ src/causalprog/graph/node/component.py | 7 +++++++ src/causalprog/graph/node/data.py | 18 ++++++++++++++++++ src/causalprog/graph/node/distribution.py | 8 ++++++++ tests/test_algorithms/test_evaluate.py | 14 ++++++++++++++ 5 files changed, 63 insertions(+) create mode 100644 tests/test_algorithms/test_evaluate.py diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index f0d74b8..ce5318d 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -127,6 +127,22 @@ def sample( """ + @abstractmethod + def evaluate( + self, + **given_values: dict[str, float | npt.NDArray[float]], + ) -> float | npt.NDArray[float]: + """ + Evaluate the node. + + Args: + given_values: Values for data nodes and values of parents + + Returns: + Value of this node given the given values + + """ + @abstractmethod def copy(self) -> Node: """ diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index f5d7ca1..ea8aa07 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -48,6 +48,13 @@ def sample( ) -> npt.NDArray[float]: return sampled_dependencies[self._parent_node_label][:, *self._component] + @override + def evaluate( + self, + **given_values: dict[str, float | npt.NDArray[float]], + ) -> float | npt.NDArray[float]: + return given_values[self._parent_node_label][*self.component] + @override def copy(self) -> Node: return ComponentNode( diff --git a/src/causalprog/graph/node/data.py b/src/causalprog/graph/node/data.py index 6b701d3..ac92887 100644 --- a/src/causalprog/graph/node/data.py +++ b/src/causalprog/graph/node/data.py @@ -41,6 +41,24 @@ def sample( raise ValueError(msg) return jnp.full(samples, parameter_values[self.label]) + @override + def evaluate( + self, + **given_values: dict[str, float | npt.NDArray[float]], + ) -> float | npt.NDArray[float]: + if self.label not in given_values: + msg = f"Missing input for node: {self.label}." + raise ValueError(msg) + value = given_values[self.label] + if self.shape == (): + if not isinstance(value, float): + msg = f"Invalid value for note: {self.label}" + raise ValueError(msg) + elif isinstance(value, float) or self.shape != value.shape: + msg = f"Invalid value for nose: {self.label}" + raise ValueError(msg) + return value + @override def copy(self) -> Node: return DataNode(label=self.label) diff --git a/src/causalprog/graph/node/distribution.py b/src/causalprog/graph/node/distribution.py index 555672f..18f5abf 100644 --- a/src/causalprog/graph/node/distribution.py +++ b/src/causalprog/graph/node/distribution.py @@ -70,6 +70,14 @@ def sample( else self.shape, ) + @override + def evaluate( + self, + **given_values: dict[str, float | npt.NDArray[float]], + ) -> float | npt.NDArray[float]: + msg = "Cannot evaluate a DistributionNode" + raise RuntimeError(msg) + @override def copy(self) -> Node: return DistributionNode( diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py new file mode 100644 index 0000000..105eb3b --- /dev/null +++ b/tests/test_algorithms/test_evaluate.py @@ -0,0 +1,14 @@ +"""Tests for evaluate algorithms.""" + +import numpy as np +import jax.numpy as jnp +import pytest +from causalprog.graph import DataNode + + +def test_evaluate_node(raises_context): + node = DataNode(label="A") + assert np.isclose(node.evaluate(A=2.0), 2.0) + + with raises_context(ValueError("Missing input for node: A")): + node.evaluate() From 1d0738df1b518507183ea23c7dc513846ad38f49 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 5 Jun 2026 11:45:12 +0100 Subject: [PATCH 03/49] mypy? --- src/causalprog/graph/node/component.py | 5 ++++- src/causalprog/graph/node/data.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index ea8aa07..71120ef 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -2,6 +2,7 @@ from __future__ import annotations +import numpy as np import typing from typing_extensions import override @@ -53,7 +54,9 @@ def evaluate( self, **given_values: dict[str, float | npt.NDArray[float]], ) -> float | npt.NDArray[float]: - return given_values[self._parent_node_label][*self.component] + parent_value = given_values[self._parent_node_label] + assert isinstance(parent_value, np.ndarray) + return parent_value[*self.component] @override def copy(self) -> Node: diff --git a/src/causalprog/graph/node/data.py b/src/causalprog/graph/node/data.py index ac92887..dbe0d43 100644 --- a/src/causalprog/graph/node/data.py +++ b/src/causalprog/graph/node/data.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +import numpy as np import numpy.typing as npt from typing_extensions import override @@ -54,7 +55,7 @@ def evaluate( if not isinstance(value, float): msg = f"Invalid value for note: {self.label}" raise ValueError(msg) - elif isinstance(value, float) or self.shape != value.shape: + elif not isinstance(value, np.ndarray) or self.shape != value.shape: msg = f"Invalid value for nose: {self.label}" raise ValueError(msg) return value From cbb10ea64ddc081e5bb825bd46f7fd0e2a247d40 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 5 Jun 2026 14:45:21 +0100 Subject: [PATCH 04/49] ruff --- src/causalprog/graph/node/component.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index 71120ef..e433bdc 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -55,7 +55,9 @@ def evaluate( **given_values: dict[str, float | npt.NDArray[float]], ) -> float | npt.NDArray[float]: parent_value = given_values[self._parent_node_label] - assert isinstance(parent_value, np.ndarray) + if not isinstance(parent_value, np.ndarray): + msg = f"Invalid data in node: {self._parent_node_label}" + raise ValueError(msg) return parent_value[*self.component] @override From a956b83bcabeaf1715eb192fe58ef56983f3706f Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 5 Jun 2026 14:47:39 +0100 Subject: [PATCH 05/49] typeerror --- src/causalprog/graph/node/component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index e433bdc..c32dc56 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -57,7 +57,7 @@ def evaluate( parent_value = given_values[self._parent_node_label] if not isinstance(parent_value, np.ndarray): msg = f"Invalid data in node: {self._parent_node_label}" - raise ValueError(msg) + raise TypeError(msg) return parent_value[*self.component] @override From b7b3332c6011a284a09ffbd659778b59c7582596 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 5 Jun 2026 14:49:20 +0100 Subject: [PATCH 06/49] rufff --- examples/two_normal.ipynb | 2 +- src/causalprog/graph/__init__.py | 2 +- src/causalprog/graph/node/__init__.py | 2 +- src/causalprog/graph/node/component.py | 2 +- tests/fixtures/graph.py | 2 +- tests/test_algorithms/test_do.py | 2 +- tests/test_algorithms/test_evaluate.py | 3 +-- .../test_causal_problem/test_associate_models_to_components.py | 2 +- tests/test_graph/test_model.py | 2 +- tests/test_graph/test_ordering.py | 2 +- 10 files changed, 10 insertions(+), 11 deletions(-) diff --git a/examples/two_normal.ipynb b/examples/two_normal.ipynb index 0123037..8697c8d 100644 --- a/examples/two_normal.ipynb +++ b/examples/two_normal.ipynb @@ -126,7 +126,7 @@ "source": [ "from numpyro.distributions import Normal\n", "\n", - "from causalprog.graph import DistributionNode, DataNode\n", + "from causalprog.graph import DataNode, DistributionNode\n", "\n", "graph.add_node(DataNode(label=\"mu_X\"))\n", "graph.add_node(\n", diff --git a/src/causalprog/graph/__init__.py b/src/causalprog/graph/__init__.py index 1fa4396..20832aa 100644 --- a/src/causalprog/graph/__init__.py +++ b/src/causalprog/graph/__init__.py @@ -1,4 +1,4 @@ """Creation and storage of graphs.""" from .graph import Graph -from .node import ComponentNode, DistributionNode, Node, DataNode +from .node import ComponentNode, DataNode, DistributionNode, Node diff --git a/src/causalprog/graph/node/__init__.py b/src/causalprog/graph/node/__init__.py index f2a1866..a58bfc4 100644 --- a/src/causalprog/graph/node/__init__.py +++ b/src/causalprog/graph/node/__init__.py @@ -2,5 +2,5 @@ from .base import Node from .component import ComponentNode -from .distribution import DistributionNode from .data import DataNode +from .distribution import DistributionNode diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index c32dc56..39aaccd 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -2,9 +2,9 @@ from __future__ import annotations -import numpy as np import typing +import numpy as np from typing_extensions import override from .base import Node diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index f7e93ae..d707f5b 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -8,7 +8,7 @@ import pytest from numpyro.distributions import Normal -from causalprog.graph import DistributionNode, Graph, DataNode +from causalprog.graph import DataNode, DistributionNode, Graph NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "outcome"] NormalGraphNodes: TypeAlias = dict[NormalGraphNodeNames, DistributionNode | DataNode] diff --git a/tests/test_algorithms/test_do.py b/tests/test_algorithms/test_do.py index 1599ffe..6c89338 100644 --- a/tests/test_algorithms/test_do.py +++ b/tests/test_algorithms/test_do.py @@ -1,7 +1,7 @@ """Tests for the do algorithm.""" from causalprog import algorithms -from causalprog.graph import Graph, DataNode +from causalprog.graph import DataNode, Graph max_samples = 10**5 diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 105eb3b..a25fe24 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -1,8 +1,7 @@ """Tests for evaluate algorithms.""" import numpy as np -import jax.numpy as jnp -import pytest + from causalprog.graph import DataNode diff --git a/tests/test_causal_problem/test_associate_models_to_components.py b/tests/test_causal_problem/test_associate_models_to_components.py index caac467..d444671 100644 --- a/tests/test_causal_problem/test_associate_models_to_components.py +++ b/tests/test_causal_problem/test_associate_models_to_components.py @@ -21,7 +21,7 @@ Constraint, HandlerToApply, ) -from causalprog.graph import Graph, DataNode +from causalprog.graph import DataNode, Graph @pytest.fixture diff --git a/tests/test_graph/test_model.py b/tests/test_graph/test_model.py index db44ec0..7a18b0d 100644 --- a/tests/test_graph/test_model.py +++ b/tests/test_graph/test_model.py @@ -2,7 +2,7 @@ import numpyro import pytest -from causalprog.graph import DistributionNode, Graph, DataNode +from causalprog.graph import DataNode, DistributionNode, Graph @pytest.mark.parametrize( diff --git a/tests/test_graph/test_ordering.py b/tests/test_graph/test_ordering.py index 2ef804d..04407b1 100644 --- a/tests/test_graph/test_ordering.py +++ b/tests/test_graph/test_ordering.py @@ -1,6 +1,6 @@ """Tests for ordering of nodes in a graph.""" -from causalprog.graph import Graph, DataNode +from causalprog.graph import DataNode, Graph def test_roots_down_to_outcome() -> None: From abbff6bdcf0076ea5fd79a65aceab2717f89b803 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 24 Jun 2026 14:58:20 +0100 Subject: [PATCH 07/49] simplify some graph code --- src/causalprog/algorithms/do.py | 17 ++---- src/causalprog/graph/__init__.py | 2 +- src/causalprog/graph/graph.py | 47 +++++++---------- src/causalprog/graph/node/__init__.py | 1 + src/causalprog/graph/node/base.py | 51 ++---------------- src/causalprog/graph/node/component.py | 16 ++---- src/causalprog/graph/node/constant.py | 63 +++++++++++++++++++++++ src/causalprog/graph/node/data.py | 11 ++-- src/causalprog/graph/node/distribution.py | 19 +++---- tests/test_algorithms/test_do.py | 19 ++++--- 10 files changed, 113 insertions(+), 133 deletions(-) create mode 100644 src/causalprog/graph/node/constant.py diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 6e0688c..709ea31 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -2,7 +2,7 @@ from copy import deepcopy -from causalprog.graph import Graph, Node +from causalprog.graph import Graph, Node, ConstantNode def get_included_excluded_successors( @@ -72,19 +72,6 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph nodes = {n.label: deepcopy(n) for n in graph.nodes if n.label != node} - # Search through the old graph, identifying nodes that had parameters which were - # defined by the node being fixed in the DO operation. - # We recreate these nodes, but replace each such parameter we encounter with - # a constant parameter equal that takes the fixed value given as an input. - for n in nodes.values(): - params = tuple(n.parameters.keys()) - for parameter_name in params: - if n.parameters[parameter_name] == node: - # Swap the parameter to a constant parameter, giving it the fixed value - n.constant_parameters[parameter_name] = value - # Remove the parameter from the node's record of non-constant parameters - n.parameters.pop(parameter_name) - # Recursively remove nodes that are predecessors of removed nodes nodes_to_remove: tuple[str, ...] = (node,) while len(nodes_to_remove) > 0: @@ -103,6 +90,8 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph ) raise ValueError(msg) + nodes[node] = ConstantNode(label=node, value=value) + g = Graph(label=f"{label}|do[{node}={value}]") for n in nodes.values(): g.add_node(n) diff --git a/src/causalprog/graph/__init__.py b/src/causalprog/graph/__init__.py index 20832aa..d4166bb 100644 --- a/src/causalprog/graph/__init__.py +++ b/src/causalprog/graph/__init__.py @@ -1,4 +1,4 @@ """Creation and storage of graphs.""" from .graph import Graph -from .node import ComponentNode, DataNode, DistributionNode, Node +from .node import ComponentNode, DataNode, DistributionNode, Node, ConstantNode diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index e252bd0..41ecaa7 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -4,7 +4,7 @@ import numpy.typing as npt from causalprog._abc.labelled import Labelled -from causalprog.graph.node import ComponentNode, Node +from causalprog.graph.node import ComponentNode, Node, DistributionNode class Graph(Labelled): @@ -61,7 +61,10 @@ def add_node(self, node: Node) -> None: self._nodes_by_label[node.label] = node self._graph.add_node(node) if isinstance(node, ComponentNode): - self.add_edge(node.parent_node, node.label) + if len(node.parents) != 1: + msg = "ComponentNode should have exactly one parent." + raise ValueError(msg) + self.add_edge(node.parents[0], node.label) def add_edge(self, start_node: Node | str, end_node: Node | str) -> None: """ @@ -91,22 +94,22 @@ def add_edge(self, start_node: Node | str, end_node: Node | str) -> None: self._graph.add_edge(start_node, end_node) @property - def parameter_nodes(self) -> tuple[Node, ...]: + def leaf_nodes(self) -> tuple[Node, ...]: """ - Returns all parameter nodes in the graph. + Returns all leaf nodes in the graph. - The returned tuple uses the `ordered_nodes` property to obtain the parameter - nodes so that a natural "fixed order" is given to the parameters. When parameter + The returned tuple uses the `ordered_nodes` property to obtain the leaf + nodes so that a natural "fixed order" is given to the leaves. When leaf values are given as inputs to the causal estimand and / or constraint functions, - they will ideally be given as a single vector of parameter values, in which case - a fixed ordering for the parameters is necessary to make an association to the + they will ideally be given as a single vector of leaf values, in which case + a fixed ordering for the leaves is necessary to make an association to the components of the given input vector. Returns: - Parameter nodes + Leaf nodes """ - return tuple(node for node in self.ordered_nodes if node.is_parameter) + return tuple(node for node in self.ordered_nodes if len(node.parents) == 0) @property def predecessors(self) -> dict[Node, tuple[Node, ...]]: @@ -170,17 +173,6 @@ def ordered_nodes(self) -> tuple[Node, ...]: raise RuntimeError(msg) return tuple(nx.topological_sort(self._graph)) - @property - def ordered_dist_nodes(self) -> tuple[Node, ...]: - """ - `DistributionNode`s in dependency order. - - Each `DistributionNode` in the returned list appears after all its - dependencies. Order is derived from `self.ordered_nodes`, selecting - only those nodes where `is_distribution` is `True`. - """ - return tuple(node for node in self.ordered_nodes if node.is_distribution) - def roots_down_to_outcome( self, outcome_node_label: str, @@ -226,7 +218,7 @@ def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]: """ # Confirm that all `DataNode`s have been assigned a value. - for node in self.parameter_nodes: + for node in self.leaf_nodes: if node.label not in parameter_values: msg = f"DataNode '{node.label}' not assigned" raise KeyError(msg) @@ -234,10 +226,11 @@ def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]: # Build model sequentially, using the node_order to inform the # construction process. node_record: dict[str, npt.ArrayLike] = {} - for node in self.ordered_dist_nodes: - node_record[node.label] = node.create_model_site( - **parameter_values, # All nodes require knowledge of the parameters - **node_record, # and any dependent nodes we have already visited - ) + for node in self.ordered_nodes: + if isinstance(node, DistributionNode): + node_record[node.label] = node.create_model_site( + **parameter_values, # All nodes require knowledge of the parameters + **node_record, # and any dependent nodes we have already visited + ) return node_record diff --git a/src/causalprog/graph/node/__init__.py b/src/causalprog/graph/node/__init__.py index a58bfc4..1698224 100644 --- a/src/causalprog/graph/node/__init__.py +++ b/src/causalprog/graph/node/__init__.py @@ -2,5 +2,6 @@ from .base import Node from .component import ComponentNode +from .constant import ConstantNode from .data import DataNode from .distribution import DistributionNode diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index ce5318d..e3e6d1a 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -40,8 +40,6 @@ def __init__( *, label: str, shape: tuple[int, ...] = (), - is_parameter: bool = False, - is_distribution: bool = False, ) -> None: """ Initialise. @@ -55,23 +53,15 @@ def __init__( constraints), and as such the value that a "parameter node" passes to its dependent nodes will vary as the optimiser runs and explores the solution space. - Note that a "constant parameter" is distinct from a "parameter" in the sense - that a constant parameter is _not_ added to the collection of parameters over - which we will want to optimise (it is a hard-coded, fixed value). - Distributions (equivalently `DistributionNode`s) are Nodes that represent random variables described by probability distributions. Args: label: A unique label to identify the node shape: The shape of the node's value for each sample - is_parameter: Is the node a parameter? - is_distribution: Is the node a distribution? """ super().__init__(label=label) - self._is_parameter = is_parameter - self._is_distribution = is_distribution self._shape = shape def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node: @@ -168,48 +158,13 @@ def shape(self) -> tuple[int, ...]: """ return self._shape - @property - def is_parameter(self) -> bool: - """ - Identify if the node is an parameter. - - Returns: - True if the node is an parameter - - """ - return self._is_parameter - - @property - def is_distribution(self) -> bool: - """ - Identify if the node is an distribution. - - Returns: - True if the node is an distribution - - """ - return self._is_distribution - - @property - @abstractmethod - def constant_parameters(self) -> dict[str, float]: - """ - Named constants that this node depends on. - - Returns: - A dictionary of the constant parameter names (keys) and their corresponding - values - - """ - @property @abstractmethod - def parameters(self) -> dict[str, str]: + def parents(self) -> list[str]: """ - Mapping of distribution parameter names to the nodes they are represented by. + Nodes that this node depends on the value of. Returns: - Mapping of distribution parameters (keys) to the corresponding label of the - node that represents this parameter (value). + List of labels of parent nodes """ diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index 39aaccd..8ec1c7a 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -36,7 +36,7 @@ def __init__( """ self._component = component self._parent_node_label = parent_node_label - super().__init__(shape=shape, label=label, is_distribution=True) + super().__init__(shape=shape, label=label) @override def sample( @@ -82,15 +82,5 @@ def __repr__(self) -> str: @override @property - def constant_parameters(self) -> dict[str, float]: - return {} - - @override - @property - def parameters(self) -> dict[str, str]: - return {} - - @property - def parent_node(self) -> str: - """The label of the parent node.""" - return self._parent_node_label + def parents(self) -> list[str]: + return [self._parent_node_label] diff --git a/src/causalprog/graph/node/constant.py b/src/causalprog/graph/node/constant.py new file mode 100644 index 0000000..df8dc79 --- /dev/null +++ b/src/causalprog/graph/node/constant.py @@ -0,0 +1,63 @@ +"""Graph nodes representing distributions.""" + +from __future__ import annotations + +import typing + +import numpy as np +from typing_extensions import override + +from .base import Node + +if typing.TYPE_CHECKING: + import jax + import numpy.typing as npt + + +class ConstantNode(Node): + """A node representing a constant.""" + + def __init__(self, *, label: str, value: flat | npt.NDArray[float]) -> None: + """ + Initialise. + + Args: + label: A unique label to identify the node + value: The value of this constant + + """ + self._value = value + super().__init__( + shape=() if isinstance(value, float) else value.shape, label=label + ) + + @override + def sample( + self, + parameter_values: dict[str, float], + sampled_dependencies: dict[str, npt.NDArray[float]], + samples: int, + *, + rng_key: jax.Array, + ) -> npt.NDArray[float]: + return jnp.full(samples, self._value) + + @override + def evaluate( + self, + **given_values: dict[str, float | npt.NDArray[float]], + ) -> float | npt.NDArray[float]: + return self._value + + @override + def copy(self) -> Node: + return ConstantNode(label=self.label, value=self._value) + + @override + def __repr__(self) -> str: + r = f"ConstantNode({self._value})" + + @override + @property + def parents(self) -> list[str]: + return [] diff --git a/src/causalprog/graph/node/data.py b/src/causalprog/graph/node/data.py index dbe0d43..c9eb729 100644 --- a/src/causalprog/graph/node/data.py +++ b/src/causalprog/graph/node/data.py @@ -26,7 +26,7 @@ def __init__(self, *, shape: tuple[int, ...] = (), label: str) -> None: label: A unique label to identify the node """ - super().__init__(label=label, shape=shape, is_parameter=True) + super().__init__(label=label, shape=shape) @override def sample( @@ -70,10 +70,5 @@ def __repr__(self) -> str: @override @property - def constant_parameters(self) -> dict[str, float]: - return {} - - @override - @property - def parameters(self) -> dict[str, str]: - return {} + def parents(self) -> list[str]: + return [] diff --git a/src/causalprog/graph/node/distribution.py b/src/causalprog/graph/node/distribution.py index 18f5abf..30495f9 100644 --- a/src/causalprog/graph/node/distribution.py +++ b/src/causalprog/graph/node/distribution.py @@ -40,7 +40,7 @@ def __init__( self._dist = distribution self._constant_parameters = constant_parameters if constant_parameters else {} self._parameters = parameters if parameters else {} - super().__init__(label=label, shape=shape, is_distribution=True) + super().__init__(label=label, shape=shape) @override def sample( @@ -55,10 +55,10 @@ def sample( # Pass in node values derived from construction so far **{ native_name: sampled_dependencies[node_name] - for native_name, node_name in self.parameters.items() + for native_name, node_name in self._parameters.items() }, # Pass in any constant parameters this node sets - **self.constant_parameters, + **self._constant_parameters, ) return numpyro.sample( @@ -101,13 +101,8 @@ def __repr__(self) -> str: @override @property - def constant_parameters(self) -> dict[str, float]: - return self._constant_parameters - - @override - @property - def parameters(self) -> dict[str, str]: - return self._parameters + def parents(self) -> list[str]: + return list(self._parameters.keys()) + list(self._constant_parameters.keys()) def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike: """ @@ -125,9 +120,9 @@ def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike: # Pass in node values derived from construction so far **{ native_name: dependent_nodes[node_name] - for native_name, node_name in self.parameters.items() + for native_name, node_name in self._parameters.items() }, # Pass in any constant parameters this node sets - **self.constant_parameters, + **self._constant_parameters, ), ) diff --git a/tests/test_algorithms/test_do.py b/tests/test_algorithms/test_do.py index 6c89338..4977c70 100644 --- a/tests/test_algorithms/test_do.py +++ b/tests/test_algorithms/test_do.py @@ -1,30 +1,29 @@ """Tests for the do algorithm.""" from causalprog import algorithms -from causalprog.graph import DataNode, Graph +from causalprog.graph import DataNode, Graph, ConstantNode max_samples = 10**5 -def test_do(two_normal_graph, raises_context): +def test_do(two_normal_graph): graph = two_normal_graph(5.0, 1.2, 0.8) graph2 = algorithms.do(graph, "UX", 4.0) - assert "loc" in graph.get_node("X").parameters - assert "loc" not in graph.get_node("X").constant_parameters - assert "loc" not in graph2.get_node("X").parameters - assert "loc" in graph2.get_node("X").constant_parameters + assert "loc" in graph.get_node("X").parents + assert "loc" in graph2.get_node("X").parents graph.get_node("UX") - with raises_context(KeyError('Node not found with label "UX"')): - graph2.get_node("UX") + assert isinstance(graph2.get_node("UX"), ConstantNode) def test_do_removes_dependencies(two_normal_graph, raises_context): graph = two_normal_graph() graph2 = algorithms.do(graph, "UX", 4.0) - for node in ["UX", "mean", "cov"]: + graph.get_node("UX") + assert isinstance(graph2.get_node("UX"), ConstantNode) + for node in ["mean", "cov"]: graph.get_node(node) with raises_context(KeyError(f'Node not found with label "{node}"')): graph2.get_node(node) @@ -39,7 +38,6 @@ def test_do_edges(two_normal_graph): # Check that correct edges are removed for e in [ - ("UX", "X"), ("mean", "UX"), ("cov", "UX"), ]: @@ -48,6 +46,7 @@ def test_do_edges(two_normal_graph): # Check that correct edges remain for e in [ + ("UX", "X"), ("cov2", "X"), ]: assert e in edges From 43149c55455eec62fe0e342858fce686c7a78640 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 08:59:03 +0100 Subject: [PATCH 08/49] Update src/causalprog/graph/node/base.py Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/graph/node/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index ce5318d..0f81ffe 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -139,7 +139,7 @@ def evaluate( given_values: Values for data nodes and values of parents Returns: - Value of this node given the given values + Value of this node given `given_values`. """ From 6e6826423baeefa469f4a02fc829765ba66a2e92 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 08:59:21 +0100 Subject: [PATCH 09/49] Update src/causalprog/graph/node/base.py Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/graph/node/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index 0f81ffe..204fe6a 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -133,7 +133,7 @@ def evaluate( **given_values: dict[str, float | npt.NDArray[float]], ) -> float | npt.NDArray[float]: """ - Evaluate the node. + Evaluate the node, given evaluations of its precursor nodes. Args: given_values: Values for data nodes and values of parents From 8f64aa42487234569aed2304213dc5cf49798cd5 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:00:20 +0100 Subject: [PATCH 10/49] Update tests/test_algorithms/test_evaluate.py Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- tests/test_algorithms/test_evaluate.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index a25fe24..2047578 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -5,9 +5,24 @@ from causalprog.graph import DataNode -def test_evaluate_node(raises_context): - node = DataNode(label="A") - assert np.isclose(node.evaluate(A=2.0), 2.0) +@pytest.mark.parametrize( + ("node", "kwargs_to_evaluate", "expected_result"), + [ + pytest.param(DataNode(label="A"), {"A": 2.0}, 2.0, id="Evaluate DataNode itself"), + # Likely want a test here for a ComponentNode? + pytest.param(ComponentNode("Parent", 1, label="Child"), {"Parent": np.arange(4)}, 1.0, id="Evaluate ComponentNode, given parent"), + ] +) +def test_evaluate_node(node, kwargs_to_evaluate, expected_result): + assert np.allclose(node.evaluate(**kwargs_to_evaluate), expected_result) - with raises_context(ValueError("Missing input for node: A")): - node.evaluate() +@pytest.mark.parametrize( + ("node", "kwargs_for_evaluate", "expected_error"), + [ + pytest.param(DataNode(label="A"), {}, ValueError("Missing input for node: A"), id="DataNode missing input value"), + pytest.param(DistributionNode(dist=None, label="A"), {}, ValueError("Cannot evaluate a DistributionNode"), id="Attempt to evaluate DistributionNode"), + ] +) +def test_evaluate_node_fail_on_missing_data(node, kwargs_for_evaluate, expected_error, raises_context): + with raises_context(expected_error): + node.evaluate(**kwargs_for_evaluate) From 752005ace0d1133d6500b55aebea81feab8caf62 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:00:50 +0100 Subject: [PATCH 11/49] Apply suggestion from @mscroggs --- tests/test_algorithms/test_evaluate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 2047578..b6f5b64 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -9,7 +9,6 @@ ("node", "kwargs_to_evaluate", "expected_result"), [ pytest.param(DataNode(label="A"), {"A": 2.0}, 2.0, id="Evaluate DataNode itself"), - # Likely want a test here for a ComponentNode? pytest.param(ComponentNode("Parent", 1, label="Child"), {"Parent": np.arange(4)}, 1.0, id="Evaluate ComponentNode, given parent"), ] ) From 7f1db29691ffaa87bd5fd3f443b7264370fe19da Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:12:32 +0100 Subject: [PATCH 12/49] fix tests --- src/causalprog/graph/node/component.py | 8 +++-- tests/test_algorithms/test_evaluate.py | 37 +++++++++++++++++----- tests/test_causal_problem/test_handlers.py | 2 +- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index 39aaccd..d9b8c67 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -20,7 +20,7 @@ class ComponentNode(Node): def __init__( self, parent_node_label: str, - component: tuple[int, ...], + component: int | tuple[int, ...], *, shape: tuple[int, ...] = (), label: str, @@ -34,7 +34,9 @@ def __init__( label: A unique label to identify the node """ - self._component = component + self._component = ( + (component,) if isinstance(component, int) else tuple(component) + ) self._parent_node_label = parent_node_label super().__init__(shape=shape, label=label, is_distribution=True) @@ -58,7 +60,7 @@ def evaluate( if not isinstance(parent_value, np.ndarray): msg = f"Invalid data in node: {self._parent_node_label}" raise TypeError(msg) - return parent_value[*self.component] + return parent_value[*self._component] @override def copy(self) -> Node: diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index b6f5b64..0a1dcbc 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -1,27 +1,48 @@ """Tests for evaluate algorithms.""" import numpy as np +import pytest -from causalprog.graph import DataNode +from causalprog.graph import DataNode, ComponentNode, DistributionNode @pytest.mark.parametrize( ("node", "kwargs_to_evaluate", "expected_result"), [ - pytest.param(DataNode(label="A"), {"A": 2.0}, 2.0, id="Evaluate DataNode itself"), - pytest.param(ComponentNode("Parent", 1, label="Child"), {"Parent": np.arange(4)}, 1.0, id="Evaluate ComponentNode, given parent"), - ] + pytest.param( + DataNode(label="A"), {"A": 2.0}, 2.0, id="Evaluate DataNode itself" + ), + pytest.param( + ComponentNode("Parent", 1, label="Child"), + {"Parent": np.arange(4)}, + 1.0, + id="Evaluate ComponentNode, given parent", + ), + ], ) def test_evaluate_node(node, kwargs_to_evaluate, expected_result): assert np.allclose(node.evaluate(**kwargs_to_evaluate), expected_result) + @pytest.mark.parametrize( ("node", "kwargs_for_evaluate", "expected_error"), [ - pytest.param(DataNode(label="A"), {}, ValueError("Missing input for node: A"), id="DataNode missing input value"), - pytest.param(DistributionNode(dist=None, label="A"), {}, ValueError("Cannot evaluate a DistributionNode"), id="Attempt to evaluate DistributionNode"), - ] + pytest.param( + DataNode(label="A"), + {}, + ValueError("Missing input for node: A"), + id="DataNode missing input value", + ), + pytest.param( + DistributionNode(distribution=None, label="A"), + {}, + RuntimeError("Cannot evaluate a DistributionNode"), + id="Attempt to evaluate DistributionNode", + ), + ], ) -def test_evaluate_node_fail_on_missing_data(node, kwargs_for_evaluate, expected_error, raises_context): +def test_evaluate_node_fail_on_missing_data( + node, kwargs_for_evaluate, expected_error, raises_context +): with raises_context(expected_error): node.evaluate(**kwargs_for_evaluate) diff --git a/tests/test_causal_problem/test_handlers.py b/tests/test_causal_problem/test_handlers.py index 3801201..494e234 100644 --- a/tests/test_causal_problem/test_handlers.py +++ b/tests/test_causal_problem/test_handlers.py @@ -6,7 +6,7 @@ def placeholder_callable() -> EffectHandler: """Stand-in for an effect handler.""" - return lambda model, **kwargs: (lambda **pv: model(**kwargs)) + return lambda model, **kwargs: lambda **pv: model(**kwargs) @pytest.mark.parametrize( From 5ed79b5911ea68d7322b495c0ad11567210d88da Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:19:12 +0100 Subject: [PATCH 13/49] ruff --- src/causalprog/algorithms/moments.py | 2 +- src/causalprog/graph/node/distribution.py | 4 ++-- src/causalprog/utils/translator.py | 4 +--- tests/test_algorithms/test_evaluate.py | 2 +- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/causalprog/algorithms/moments.py b/src/causalprog/algorithms/moments.py index d6f9801..6ce14ef 100644 --- a/src/causalprog/algorithms/moments.py +++ b/src/causalprog/algorithms/moments.py @@ -22,7 +22,7 @@ def sample( for node, key in zip(nodes, keys, strict=False): values[node.label] = node.sample( - parameter_values if parameter_values else {}, + parameter_values or {}, values, samples, rng_key=key, diff --git a/src/causalprog/graph/node/distribution.py b/src/causalprog/graph/node/distribution.py index 18f5abf..9e9f2a3 100644 --- a/src/causalprog/graph/node/distribution.py +++ b/src/causalprog/graph/node/distribution.py @@ -38,8 +38,8 @@ def __init__( """ self._dist = distribution - self._constant_parameters = constant_parameters if constant_parameters else {} - self._parameters = parameters if parameters else {} + self._constant_parameters = constant_parameters or {} + self._parameters = parameters or {} super().__init__(label=label, shape=shape, is_distribution=True) @override diff --git a/src/causalprog/utils/translator.py b/src/causalprog/utils/translator.py index 0e26eba..8e6e0bd 100644 --- a/src/causalprog/utils/translator.py +++ b/src/causalprog/utils/translator.py @@ -72,9 +72,7 @@ def __init__( """ # Assume backend name is identical to frontend name if not provided explicitly - self.backend_method = ( - backend_method if backend_method else self._frontend_method - ) + self.backend_method = backend_method or self._frontend_method # This should really be immutable after we fill defaults! self.corresponding_backend_arg = dict(front_args_to_back_args) diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 0a1dcbc..6475c8f 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from causalprog.graph import DataNode, ComponentNode, DistributionNode +from causalprog.graph import ComponentNode, DataNode, DistributionNode @pytest.mark.parametrize( From bb1b0f3aea1412348fbb77c163a157598ce7d0f9 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:32:48 +0100 Subject: [PATCH 14/49] hasattr(..., "shape") instead of type check --- src/causalprog/graph/node/data.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/causalprog/graph/node/data.py b/src/causalprog/graph/node/data.py index dbe0d43..9b93df9 100644 --- a/src/causalprog/graph/node/data.py +++ b/src/causalprog/graph/node/data.py @@ -2,7 +2,6 @@ import jax import jax.numpy as jnp -import numpy as np import numpy.typing as npt from typing_extensions import override @@ -51,12 +50,8 @@ def evaluate( msg = f"Missing input for node: {self.label}." raise ValueError(msg) value = given_values[self.label] - if self.shape == (): - if not isinstance(value, float): - msg = f"Invalid value for note: {self.label}" - raise ValueError(msg) - elif not isinstance(value, np.ndarray) or self.shape != value.shape: - msg = f"Invalid value for nose: {self.label}" + if self.shape != (value.shape if hasattr(value, "shape") else ()): + msg = f"Invalid value for node: {self.label}" raise ValueError(msg) return value From 2c23e019bdc433fd61f8ffb956928ab7d92d046e Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:34:29 +0100 Subject: [PATCH 15/49] let KeyError happen rather than asserting types --- src/causalprog/graph/node/component.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index d9b8c67..42fba2c 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -56,11 +56,7 @@ def evaluate( self, **given_values: dict[str, float | npt.NDArray[float]], ) -> float | npt.NDArray[float]: - parent_value = given_values[self._parent_node_label] - if not isinstance(parent_value, np.ndarray): - msg = f"Invalid data in node: {self._parent_node_label}" - raise TypeError(msg) - return parent_value[*self._component] + return given_values[self._parent_node_label][*self._component] @override def copy(self) -> Node: From 2815ff876a2c6b3cde1e49dae232dd8c641bdba9 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:40:13 +0100 Subject: [PATCH 16/49] mypy? --- src/causalprog/graph/node/component.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index 42fba2c..be7634e 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -56,7 +56,8 @@ def evaluate( self, **given_values: dict[str, float | npt.NDArray[float]], ) -> float | npt.NDArray[float]: - return given_values[self._parent_node_label][*self._component] + parent_value = given_values[self._parent_node_label] + return parent_value[*self._component] @override def copy(self) -> Node: From 42a28449648133cc50d0858944373669e0a64835 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:42:28 +0100 Subject: [PATCH 17/49] currect type --- src/causalprog/graph/node/base.py | 2 +- src/causalprog/graph/node/component.py | 5 ++--- src/causalprog/graph/node/data.py | 2 +- src/causalprog/graph/node/distribution.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index 204fe6a..ffbfb46 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -130,7 +130,7 @@ def sample( @abstractmethod def evaluate( self, - **given_values: dict[str, float | npt.NDArray[float]], + **given_values: float | npt.NDArray[float], ) -> float | npt.NDArray[float]: """ Evaluate the node, given evaluations of its precursor nodes. diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index be7634e..607b474 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -54,10 +54,9 @@ def sample( @override def evaluate( self, - **given_values: dict[str, float | npt.NDArray[float]], + **given_values: float | npt.NDArray[float], ) -> float | npt.NDArray[float]: - parent_value = given_values[self._parent_node_label] - return parent_value[*self._component] + return given_values[self._parent_node_label][*self._component] @override def copy(self) -> Node: diff --git a/src/causalprog/graph/node/data.py b/src/causalprog/graph/node/data.py index 9b93df9..29a5c69 100644 --- a/src/causalprog/graph/node/data.py +++ b/src/causalprog/graph/node/data.py @@ -44,7 +44,7 @@ def sample( @override def evaluate( self, - **given_values: dict[str, float | npt.NDArray[float]], + **given_values: float | npt.NDArray[float], ) -> float | npt.NDArray[float]: if self.label not in given_values: msg = f"Missing input for node: {self.label}." diff --git a/src/causalprog/graph/node/distribution.py b/src/causalprog/graph/node/distribution.py index 9e9f2a3..811a801 100644 --- a/src/causalprog/graph/node/distribution.py +++ b/src/causalprog/graph/node/distribution.py @@ -73,7 +73,7 @@ def sample( @override def evaluate( self, - **given_values: dict[str, float | npt.NDArray[float]], + **given_values: float | npt.NDArray[float], ) -> float | npt.NDArray[float]: msg = "Cannot evaluate a DistributionNode" raise RuntimeError(msg) From 2e1eb1b64e5ee17e7572c67d3eeb103bf3633d64 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:44:01 +0100 Subject: [PATCH 18/49] type: ignore --- src/causalprog/graph/node/component.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index 607b474..b401d64 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -4,7 +4,6 @@ import typing -import numpy as np from typing_extensions import override from .base import Node @@ -56,7 +55,8 @@ def evaluate( self, **given_values: float | npt.NDArray[float], ) -> float | npt.NDArray[float]: - return given_values[self._parent_node_label][*self._component] + parent_value = given_values[self._parent_node_label] + return parent_value[*self._component] # type: ignore @override def copy(self) -> Node: From dae7e6bdd77d08f1c4cef946c2aae8d025321338 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:46:56 +0100 Subject: [PATCH 19/49] PGH003 --- src/causalprog/graph/node/component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index b401d64..b774949 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -56,7 +56,7 @@ def evaluate( **given_values: float | npt.NDArray[float], ) -> float | npt.NDArray[float]: parent_value = given_values[self._parent_node_label] - return parent_value[*self._component] # type: ignore + return parent_value[*self._component] # type: ignore[index] @override def copy(self) -> Node: From 5717c16b119a530d22ac6998cf0ba0989b968805 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:50:07 +0100 Subject: [PATCH 20/49] typing --- src/causalprog/graph/node/constant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/constant.py b/src/causalprog/graph/node/constant.py index 6856142..3b5bb52 100644 --- a/src/causalprog/graph/node/constant.py +++ b/src/causalprog/graph/node/constant.py @@ -44,7 +44,7 @@ def sample( @override def evaluate( self, - **given_values: dict[str, float | npt.NDArray[float]], + **given_values: float | npt.NDArray[float], ) -> float | npt.NDArray[float]: return self._value From 209b86e238fb837c4bd35dad32bf29deba5f7119 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 09:55:08 +0100 Subject: [PATCH 21/49] corrections --- src/causalprog/graph/node/constant.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node/constant.py b/src/causalprog/graph/node/constant.py index 3b5bb52..f49a2d3 100644 --- a/src/causalprog/graph/node/constant.py +++ b/src/causalprog/graph/node/constant.py @@ -4,6 +4,7 @@ import typing +import jax.numpy as jnp from typing_extensions import override from .base import Node @@ -16,7 +17,7 @@ class ConstantNode(Node): """A node representing a constant.""" - def __init__(self, *, label: str, value: flat | npt.NDArray[float]) -> None: + def __init__(self, *, label: str, value: float | npt.NDArray[float]) -> None: """ Initialise. @@ -54,7 +55,7 @@ def copy(self) -> Node: @override def __repr__(self) -> str: - r = f"ConstantNode({self._value})" + return f"ConstantNode({self._value})" @override @property From 06d0f69cf7a55e94b8f8360ca473216198efaa7c Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 10:37:59 +0100 Subject: [PATCH 22/49] add check that labels are valid variable names --- src/causalprog/_abc/labelled.py | 3 +++ tests/test__abc/test_labelled.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 tests/test__abc/test_labelled.py diff --git a/src/causalprog/_abc/labelled.py b/src/causalprog/_abc/labelled.py index 92250bc..b62770e 100644 --- a/src/causalprog/_abc/labelled.py +++ b/src/causalprog/_abc/labelled.py @@ -20,4 +20,7 @@ def label(self) -> str: return self._label def __init__(self, *, label: str) -> None: + if not str.isidentifier(label): + msg = f"Label is not valid Python variable name: {label}" + raise ValueError(msg) self._label = str(label) diff --git a/tests/test__abc/test_labelled.py b/tests/test__abc/test_labelled.py new file mode 100644 index 0000000..10c469d --- /dev/null +++ b/tests/test__abc/test_labelled.py @@ -0,0 +1,29 @@ +import pytest +from causalprog._abc.labelled import Labelled + + +@pytest.mark.parametrize("label", [ + "1", + " ", + "a b", + "0a", + "a.b", + "a-b", + "a+b", + "a*b", + "a/b", +]) +def test_invalid_label(label, raises_context): + with raises_context(ValueError("Label is not valid Python variable name")): + Labelled(label=label) + + +@pytest.mark.parametrize("label", [ + "a", + "A", + "a0", + "a_b", + "_a", +]) +def test_valid_label(label, raises_context): + Labelled(label=label) From 7700ea5ba21a24edd2849dce3be412ae95fee16f Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 10:38:31 +0100 Subject: [PATCH 23/49] ruff --- tests/test__abc/test_labelled.py | 43 +++++++++++++++++++------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/test__abc/test_labelled.py b/tests/test__abc/test_labelled.py index 10c469d..6d53372 100644 --- a/tests/test__abc/test_labelled.py +++ b/tests/test__abc/test_labelled.py @@ -1,29 +1,36 @@ import pytest + from causalprog._abc.labelled import Labelled -@pytest.mark.parametrize("label", [ - "1", - " ", - "a b", - "0a", - "a.b", - "a-b", - "a+b", - "a*b", - "a/b", -]) +@pytest.mark.parametrize( + "label", + [ + "1", + " ", + "a b", + "0a", + "a.b", + "a-b", + "a+b", + "a*b", + "a/b", + ], +) def test_invalid_label(label, raises_context): with raises_context(ValueError("Label is not valid Python variable name")): Labelled(label=label) -@pytest.mark.parametrize("label", [ - "a", - "A", - "a0", - "a_b", - "_a", -]) +@pytest.mark.parametrize( + "label", + [ + "a", + "A", + "a0", + "a_b", + "_a", + ], +) def test_valid_label(label, raises_context): Labelled(label=label) From 29f7679d8a2b12b0b850b334455d4449754881d0 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 10:50:28 +0100 Subject: [PATCH 24/49] update tests and labels --- src/causalprog/graph/node/base.py | 8 ++++---- tests/test_graph/test_component.py | 6 +++--- tests/test_graph/test_model.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index ee6a89a..e7d4f54 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -15,18 +15,18 @@ def _to_string(indices: int | slice | tuple[int | slice, ...]) -> str: """Convert getitem indices to a string.""" if isinstance(indices, tuple): - return ", ".join(_to_string(i) for i in indices) + return "_".join(_to_string(i) for i in indices) if isinstance(indices, int): return f"{indices}" if isinstance(indices, slice): s = "" if indices.start is not None: s += f"{indices.start}" - s += ":" + s += "__" if indices.stop is not None: s += f"{indices.stop}" if indices.step is not None: - s += f":{indices.step}" + s += f"__{indices.step}" return s e = f"Invalid indices: {indices}" raise TypeError(e) @@ -91,7 +91,7 @@ def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node: self.label, indices, shape=shape, - label=f"{self.label}[{_to_string(indices)}]", + label=f"{self.label}_{_to_string(indices)}", ) @abstractmethod diff --git a/tests/test_graph/test_component.py b/tests/test_graph/test_component.py index 9a417b8..3c6b356 100644 --- a/tests/test_graph/test_component.py +++ b/tests/test_graph/test_component.py @@ -26,6 +26,6 @@ def test_component_node(param_values): graph.add_node(graph.get_node("X")[1, :, 2]) assert graph.get_node("X").shape == (3, 1, 4) - assert graph.get_node("X[0]").shape == (1, 4) - assert graph.get_node("X[1, 0, 2]").shape == () - assert graph.get_node("X[1, :, 2]").shape == (1,) + assert graph.get_node("X_0").shape == (1, 4) + assert graph.get_node("X_1_0_2").shape == () + assert graph.get_node("X_1____2").shape == (1,) diff --git a/tests/test_graph/test_model.py b/tests/test_graph/test_model.py index 7a18b0d..1f3e38d 100644 --- a/tests/test_graph/test_model.py +++ b/tests/test_graph/test_model.py @@ -84,7 +84,7 @@ def test_model_extension( parameters={"loc": "mean"}, constant_parameters={"scale": 1.0}, ) - one_normal_graph = Graph(label="One normal") + one_normal_graph = Graph(label="One_normal") one_normal_graph.add_edge(mean, x) def extended_model(*, cov2, **parameter_values): From ac099fea157563a5dcf6f642190a7c51c1287073 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 14:07:38 +0100 Subject: [PATCH 25/49] Add random variable nodes --- src/causalprog/graph/__init__.py | 10 +- src/causalprog/graph/node/__init__.py | 1 + src/causalprog/graph/node/random_variables.py | 135 ++++++++++++++++++ tests/test_graph/test_random_variables.py | 35 +++++ 4 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 src/causalprog/graph/node/random_variables.py create mode 100644 tests/test_graph/test_random_variables.py diff --git a/src/causalprog/graph/__init__.py b/src/causalprog/graph/__init__.py index 0825556..02eb604 100644 --- a/src/causalprog/graph/__init__.py +++ b/src/causalprog/graph/__init__.py @@ -1,4 +1,12 @@ """Creation and storage of graphs.""" from .graph import Graph -from .node import ComponentNode, ConstantNode, DataNode, DistributionNode, Node +from .node import ( + ComponentNode, + ConstantNode, + ContinuousRandomVariableNode, + DataNode, + DiscreteRandomVariableNode, + DistributionNode, + Node, +) diff --git a/src/causalprog/graph/node/__init__.py b/src/causalprog/graph/node/__init__.py index 1698224..a6acf66 100644 --- a/src/causalprog/graph/node/__init__.py +++ b/src/causalprog/graph/node/__init__.py @@ -5,3 +5,4 @@ from .constant import ConstantNode from .data import DataNode from .distribution import DistributionNode +from .random_variables import ContinuousRandomVariableNode, DiscreteRandomVariableNode diff --git a/src/causalprog/graph/node/random_variables.py b/src/causalprog/graph/node/random_variables.py new file mode 100644 index 0000000..9b9d580 --- /dev/null +++ b/src/causalprog/graph/node/random_variables.py @@ -0,0 +1,135 @@ +"""Graph nodes representing random variables.""" + +import inspect +import typing +from abc import abstractmethod + +import jax +import numpy as np +import numpy.typing as npt +from typing_extensions import override + +from .base import Node + + +class RandomVariableNode(Node): + """A node containing a random variable (RV).""" + + def __init__( + self, + *, + shape: tuple[int, ...] = (), + label: str, + compute: typing.Callable | None = None, + ) -> None: + """ + Initialise. + + Args: + shape: The shape of the output of the RV + label: A unique label to identify the node + compute: A function to compute this node's value from given values of its parent nodes + + """ + super().__init__(label=label, shape=shape) + if compute is None: + self._parents = [] + else: + self._parents = list(inspect.signature(compute).parameters.keys()) + self._compute = compute + + @override + def sample( + self, + parameter_values: dict[str, float], + sampled_dependencies: dict[str, npt.ArrayLike], + samples: int, + *, + rng_key: jax.Array, + ) -> npt.ArrayLike: + raise NotImplementedError + + @override + def evaluate( + self, + **given_values: float | npt.NDArray[float], + ) -> float | npt.NDArray[float]: + if self.label in given_values: + value = given_values[self.label] + if not self.is_valid_value(value): + msg = f"Invalid value for {self.__class__.__name__}: {self.label} cannot be {value}" + raise ValueError(msg) + if self.shape != (value.shape if hasattr(value, "shape") else ()): + msg = f"Invalid value for node: {self.label}" + raise ValueError(msg) + return value + + if self._compute is None: + msg = f"Missing input for node: {self.label}." + raise ValueError(msg) + return self._compute(**{p: given_values[p] for p in self._parents}) + + @override + def copy(self) -> Node: + return DataNode(label=self.label) + + @override + @property + def parents(self) -> list[str]: + return self._parents + + @abstractmethod + def is_valid_value(self, value: float | npt.NDArray[float]): + """Check if a value is valid for this node.""" + + +class ContinuousRandomVariableNode(RandomVariableNode): + """A node containing a continuous random variable (RV).""" + + @override + def __repr__(self) -> str: + return f'ContinuousRandomVariableNode(label="{self.label}")' + + @override + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + return True + + +class DiscreteRandomVariableNode(RandomVariableNode): + """A node containing a discrete random variable (RV).""" + + def __init__( + self, + *, + values: list[float] | list[npt.NDArray[float]], + shape: tuple[int, ...] = (), + label: str, + compute: typing.Callable | None = None, + ) -> None: + """ + Initialise. + + Args: + shape: The shape of the output of the RV + label: A unique label to identify the node + compute: A function to compute this node's value from given values of its parent nodes + + """ + super().__init__(label=label, shape=shape, compute=compute) + self._values = values + + @property + def possible_values(self) -> list[float] | list[npt.NDArray[float]]: + """The values that this RV can take.""" + return self._values + + @override + def __repr__(self) -> str: + return f'DiscreteRandomVariableNode(label="{self.label}")' + + @override + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + for v in self._values: + if np.allclose(v, value): + return True + return False diff --git a/tests/test_graph/test_random_variables.py b/tests/test_graph/test_random_variables.py new file mode 100644 index 0000000..e2fabd3 --- /dev/null +++ b/tests/test_graph/test_random_variables.py @@ -0,0 +1,35 @@ +import numpy as np + +from causalprog.graph import ( + ContinuousRandomVariableNode, + DiscreteRandomVariableNode, + Graph, +) + + +def test_random_variable_node(): + node = ContinuousRandomVariableNode(label="X") + assert np.isclose(node.evaluate(X=2.0), 2.0) + + +def test_missing_input(raises_context): + node = ContinuousRandomVariableNode(label="X") + + with raises_context(ValueError("Missing input for node")): + node.evaluate() + + +def test_invalid_discrete_node_value(raises_context): + node = DiscreteRandomVariableNode(label="Y", values=[-0.5, 0.0, 0.5]) + with raises_context(ValueError("Invalid value for DiscreteRandomVariableNode")): + node.evaluate(Y=-1.0) + + +def test_sum_parents(): + graph = Graph(label="G") + + graph.add_node(ContinuousRandomVariableNode(label="X")) + graph.add_node(DiscreteRandomVariableNode(label="Y", values=[-0.5, 0.0, 0.5])) + graph.add_node(ContinuousRandomVariableNode(label="Z", compute=lambda X, Y: X + Y)) + + assert np.isclose(graph.get_node("Z").evaluate(X=3.0, Y=-0.5), 2.5) From 30c0216ce5c2a82ebd075fddf62b801b6b1aabeb Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 14:30:48 +0100 Subject: [PATCH 26/49] ruff --- src/causalprog/algorithms/do.py | 4 +-- src/causalprog/graph/node/random_variables.py | 28 ++++++++++++++----- .../test_random_variables.py | 8 +++--- tests/test_node/test_vector_nodes.py | 2 +- 4 files changed, 28 insertions(+), 14 deletions(-) rename tests/{test_graph => test_node}/test_random_variables.py (73%) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 9b53b97..9c0d6f9 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -68,7 +68,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph """ if label is None: - label = f"{graph.label}|do({node}={value})" + label = f"{graph.label}_do_{node}__" + f"{value}".replace(".", "_") nodes = {n.label: deepcopy(n) for n in graph.nodes if n.label != node} @@ -92,7 +92,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph nodes[node] = ConstantNode(label=node, value=value) - g = Graph(label=f"{label}|do[{node}={value}]") + g = Graph(label=f"{label}_do_{node}__" + f"{value}".replace(".", "_")) for n in nodes.values(): g.add_node(n) diff --git a/src/causalprog/graph/node/random_variables.py b/src/causalprog/graph/node/random_variables.py index 9b9d580..79fd619 100644 --- a/src/causalprog/graph/node/random_variables.py +++ b/src/causalprog/graph/node/random_variables.py @@ -28,7 +28,7 @@ def __init__( Args: shape: The shape of the output of the RV label: A unique label to identify the node - compute: A function to compute this node's value from given values of its parent nodes + compute: A function to compute node's value from given values of parents """ super().__init__(label=label, shape=shape) @@ -57,7 +57,10 @@ def evaluate( if self.label in given_values: value = given_values[self.label] if not self.is_valid_value(value): - msg = f"Invalid value for {self.__class__.__name__}: {self.label} cannot be {value}" + msg = ( + f"Invalid value for {self.__class__.__name__}: " + f"{self.label} cannot be {value}" + ) raise ValueError(msg) if self.shape != (value.shape if hasattr(value, "shape") else ()): msg = f"Invalid value for node: {self.label}" @@ -69,10 +72,6 @@ def evaluate( raise ValueError(msg) return self._compute(**{p: given_values[p] for p in self._parents}) - @override - def copy(self) -> Node: - return DataNode(label=self.label) - @override @property def parents(self) -> list[str]: @@ -94,6 +93,12 @@ def __repr__(self) -> str: def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: return True + @override + def copy(self) -> Node: + return ContinuousRandomVariableNode( + shape=self.shape, label=self.label, compute=self._compute + ) + class DiscreteRandomVariableNode(RandomVariableNode): """A node containing a discrete random variable (RV).""" @@ -112,7 +117,7 @@ def __init__( Args: shape: The shape of the output of the RV label: A unique label to identify the node - compute: A function to compute this node's value from given values of its parent nodes + compute: A function to compute node's value from given values of parents """ super().__init__(label=label, shape=shape, compute=compute) @@ -133,3 +138,12 @@ def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: if np.allclose(v, value): return True return False + + @override + def copy(self) -> Node: + return DiscreteRandomVariableNode( + values=self._values, + shape=self.shape, + label=self.label, + compute=self._compute, + ) diff --git a/tests/test_graph/test_random_variables.py b/tests/test_node/test_random_variables.py similarity index 73% rename from tests/test_graph/test_random_variables.py rename to tests/test_node/test_random_variables.py index e2fabd3..0f87d96 100644 --- a/tests/test_graph/test_random_variables.py +++ b/tests/test_node/test_random_variables.py @@ -28,8 +28,8 @@ def test_invalid_discrete_node_value(raises_context): def test_sum_parents(): graph = Graph(label="G") - graph.add_node(ContinuousRandomVariableNode(label="X")) - graph.add_node(DiscreteRandomVariableNode(label="Y", values=[-0.5, 0.0, 0.5])) - graph.add_node(ContinuousRandomVariableNode(label="Z", compute=lambda X, Y: X + Y)) + graph.add_node(ContinuousRandomVariableNode(label="x")) + graph.add_node(DiscreteRandomVariableNode(label="y", values=[-0.5, 0.0, 0.5])) + graph.add_node(ContinuousRandomVariableNode(label="z", compute=lambda x, y: x + y)) - assert np.isclose(graph.get_node("Z").evaluate(X=3.0, Y=-0.5), 2.5) + assert np.isclose(graph.get_node("z").evaluate(x=3.0, y=-0.5), 2.5) diff --git a/tests/test_node/test_vector_nodes.py b/tests/test_node/test_vector_nodes.py index 39230ef..44e848e 100644 --- a/tests/test_node/test_vector_nodes.py +++ b/tests/test_node/test_vector_nodes.py @@ -45,6 +45,6 @@ def test_get_component_node(rng_key): assert s0.shape == (100, 5) assert node[:, 1].sample({}, {"X": s}, 100, rng_key=rng_key).shape == (100, 4) assert node[0, 1].sample({}, {"X": s}, 100, rng_key=rng_key).shape == (100,) - assert node[0][1].sample({}, {"X": s, "X[0]": s0}, 100, rng_key=rng_key).shape == ( + assert node[0][1].sample({}, {"X": s, "X_0": s0}, 100, rng_key=rng_key).shape == ( 100, ) From 13ad54756635e40d6d5ebd6085ac08fa5f8518e0 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 25 Jun 2026 14:46:25 +0100 Subject: [PATCH 27/49] ruff --- src/causalprog/graph/node/random_variables.py | 7 ++----- tests/test__abc/test_labelled.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/causalprog/graph/node/random_variables.py b/src/causalprog/graph/node/random_variables.py index 79fd619..81e8b06 100644 --- a/src/causalprog/graph/node/random_variables.py +++ b/src/causalprog/graph/node/random_variables.py @@ -78,7 +78,7 @@ def parents(self) -> list[str]: return self._parents @abstractmethod - def is_valid_value(self, value: float | npt.NDArray[float]): + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: """Check if a value is valid for this node.""" @@ -134,10 +134,7 @@ def __repr__(self) -> str: @override def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - for v in self._values: - if np.allclose(v, value): - return True - return False + return any(np.allclose(v, value) for v in self._values) @override def copy(self) -> Node: diff --git a/tests/test__abc/test_labelled.py b/tests/test__abc/test_labelled.py index 6d53372..3f17704 100644 --- a/tests/test__abc/test_labelled.py +++ b/tests/test__abc/test_labelled.py @@ -32,5 +32,5 @@ def test_invalid_label(label, raises_context): "_a", ], ) -def test_valid_label(label, raises_context): +def test_valid_label(label): Labelled(label=label) From 76f7fb1d738bca0222145c929e7b5019509e3dc0 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 08:37:57 +0100 Subject: [PATCH 28/49] add evaluate algorithm --- src/causalprog/algorithms/__init__.py | 1 + src/causalprog/algorithms/evaluate.py | 26 ++++++++++++++++++++ src/causalprog/graph/graph.py | 9 +++---- tests/test_algorithms/test_evaluate.py | 33 +++++++++++++++++++++++++- 4 files changed, 62 insertions(+), 7 deletions(-) create mode 100644 src/causalprog/algorithms/evaluate.py diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index da37e9a..d7b26db 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -1,4 +1,5 @@ """Algorithms.""" from .do import do +from .evaluate import evaluate from .moments import expectation, moment, standard_deviation diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py new file mode 100644 index 0000000..d3ae780 --- /dev/null +++ b/src/causalprog/algorithms/evaluate.py @@ -0,0 +1,26 @@ +"""Algorithms for evaluating a graph node.""" + +import numpy.typing as npt + +from causalprog.graph import Graph + + +def evaluate( + graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float] +) -> float | npt.NDArray[float]: + """ + Evaluate a node. + + Args: + graph: The graph that the node is contained in. + outcome_node_label: The label of the node to evaluate. + values: Values taken by nodes whose value is given + + Returns: + The evaluation of the node + + """ + for node in graph.roots_down_to_outcome(outcome_node_label): + if node.label not in values: + values[node.label] = node.evaluate(**values) + return values[outcome_node_label] diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index 60c2f42..5b2e179 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -4,7 +4,7 @@ import numpy.typing as npt from causalprog._abc.labelled import Labelled -from causalprog.graph.node import ComponentNode, DistributionNode, Node +from causalprog.graph.node import DistributionNode, Node class Graph(Labelled): @@ -60,11 +60,8 @@ def add_node(self, node: Node) -> None: raise ValueError(msg) self._nodes_by_label[node.label] = node self._graph.add_node(node) - if isinstance(node, ComponentNode): - if len(node.parents) != 1: - msg = "ComponentNode should have exactly one parent." - raise ValueError(msg) - self.add_edge(node.parents[0], node.label) + for p in node.parents: + self.add_edge(p, node.label) def add_edge(self, start_node: Node | str, end_node: Node | str) -> None: """ diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 6475c8f..bbcc394 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -3,7 +3,14 @@ import numpy as np import pytest -from causalprog.graph import ComponentNode, DataNode, DistributionNode +from causalprog.algorithms import evaluate +from causalprog.graph import ( + ComponentNode, + ContinuousRandomVariableNode, + DataNode, + DistributionNode, + Graph, +) @pytest.mark.parametrize( @@ -46,3 +53,27 @@ def test_evaluate_node_fail_on_missing_data( ): with raises_context(expected_error): node.evaluate(**kwargs_for_evaluate) + + +def test_evaluate_algorithm_three_node(): + graph = Graph(label="g") + graph.add_node(DataNode(label="a")) + graph.add_node(DataNode(label="b")) + graph.add_node( + ContinuousRandomVariableNode(label="x", compute=lambda a, b: a + 2.0 * b) + ) + + assert np.isclose(evaluate(graph, "x", a=2.0, b=1.5), 5.0) + + +def test_evaluate_algorithm_four_node(): + graph = Graph(label="g") + graph.add_node(DataNode(label="a")) + graph.add_node(DataNode(label="b")) + graph.add_node(ContinuousRandomVariableNode(label="c", compute=lambda a: a - 0.5)) + graph.add_node( + ContinuousRandomVariableNode(label="x", compute=lambda b, c: c + 2.0 * b) + ) + + assert np.isclose(evaluate(graph, "c", a=2.0, b=1.5), 1.5) + assert np.isclose(evaluate(graph, "x", a=2.0, b=1.5), 4.5) From ad4ccb771dadf145282465762f39f1806b7f6a5c Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 09:19:43 +0100 Subject: [PATCH 29/49] start adding example graph --- src/causalprog/utils/graphs.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 src/causalprog/utils/graphs.py diff --git a/src/causalprog/utils/graphs.py b/src/causalprog/utils/graphs.py new file mode 100644 index 0000000..bff95af --- /dev/null +++ b/src/causalprog/utils/graphs.py @@ -0,0 +1,27 @@ +from causalprog.graph import Graph, DataNode, DiscreteRandomVariableNode + + +def example_model( + label: str = "G", + l_len: int = 1, + z_len: int = 1, + k: int = 10, +) -> Graph: + """Create a graph representing the example model. + + Args: + label: The label of the graph. + l_len: The number of entries in the vector data node L. + z_len: The number of entries in the vector data node Z. + k: The maximum value that could be taken by the mixture indicator C + + Returns: + A graph + """ + graph = Graph(label=label) + + graph.add_node(DataNode(label="L", shape=(l_len, ))) + graph.add_node(DataNode(label="Z", shape=(z_len, ))) + graph.add_node(DiscreteRandomVariableNode(label="C", values=list(range(1, k + 1))) + + return graph From dc7b6432f7aa0f432691962b9908aa48ca786af7 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 10:01:05 +0100 Subject: [PATCH 30/49] add ricardo's graph (!) --- src/causalprog/utils/graphs.py | 65 ++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/src/causalprog/utils/graphs.py b/src/causalprog/utils/graphs.py index bff95af..b8b29df 100644 --- a/src/causalprog/utils/graphs.py +++ b/src/causalprog/utils/graphs.py @@ -1,27 +1,80 @@ -from causalprog.graph import Graph, DataNode, DiscreteRandomVariableNode +"""Functions to create example graphs.""" + +import inspect +from collections.abc import Callable + +from causalprog.graph import ( + ContinuousRandomVariableNode, + DataNode, + DiscreteRandomVariableNode, + Graph, +) def example_model( + *, label: str = "G", l_len: int = 1, z_len: int = 1, k: int = 10, + compute_u_x: Callable, + compute_u_y: Callable, + compute_phi_x: Callable, + compute_x: Callable, + compute_y: Callable, ) -> Graph: - """Create a graph representing the example model. + """ + Create a graph representing the example model. Args: label: The label of the graph. l_len: The number of entries in the vector data node L. z_len: The number of entries in the vector data node Z. - k: The maximum value that could be taken by the mixture indicator C + k: The maximum value that could be taken by the mixture indicator C. + compute_u_x: Compute UX given the value of C. + compute_u_y: Compute UY given the value of C. + compute_phi_x: Compute PhiX given the value of L. + compute_x: Compute X given the values of Z, PhiX and UX. + compute_y: Compute Y given the values of X and UY. Returns: A graph + """ + for p in inspect.signature(compute_u_x).parameters: + if p != "C": + msg = f"Invalid input to UX: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_u_y).parameters: + if p != "C": + msg = f"Invalid input to UX: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_phi_x).parameters: + if p != "L": + msg = f"Invalid input to PhiX: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_x).parameters: + if p not in ["Z", "PhiX", "UX"]: + msg = f"Invalid input to X: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_y).parameters: + if p not in ["X", "UY"]: + msg = f"Invalid input to X: {p}" + raise ValueError(msg) + graph = Graph(label=label) - graph.add_node(DataNode(label="L", shape=(l_len, ))) - graph.add_node(DataNode(label="Z", shape=(z_len, ))) - graph.add_node(DiscreteRandomVariableNode(label="C", values=list(range(1, k + 1))) + graph.add_node(DataNode(label="L", shape=(l_len,))) + graph.add_node(DataNode(label="Z", shape=(z_len,))) + graph.add_node( + DiscreteRandomVariableNode( + label="C", values=[float(i) for i in range(1, k + 1)] + ) + ) + graph.add_node(ContinuousRandomVariableNode(label="UX", compute=compute_u_x)) + graph.add_node(ContinuousRandomVariableNode(label="UY", compute=compute_u_y)) + graph.add_node(ContinuousRandomVariableNode(label="PhiX", compute=compute_phi_x)) + graph.add_node(ContinuousRandomVariableNode(label="X", compute=compute_x)) + graph.add_node(ContinuousRandomVariableNode(label="Y", compute=compute_y)) return graph From ccb5f0d39fcb725dc8407aad8a6659c8dbe08283 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 10:01:16 +0100 Subject: [PATCH 31/49] add test --- tests/test_utils/test_graphs.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_utils/test_graphs.py diff --git a/tests/test_utils/test_graphs.py b/tests/test_utils/test_graphs.py new file mode 100644 index 0000000..f4773b1 --- /dev/null +++ b/tests/test_utils/test_graphs.py @@ -0,0 +1,30 @@ +import numpy as np + +from causalprog.algorithms import evaluate +from causalprog.utils.graphs import example_model + + +def test_example_model(): + graph = example_model( + compute_u_x=lambda C: C, + compute_u_y=lambda C: C + 1, + compute_phi_x=lambda L: L[0], + compute_x=lambda Z, PhiX, UX: Z[0] + UX - PhiX, + compute_y=lambda X, UY: X * UY, + ) + assert len(graph.nodes) == 8 + + data = { + "L": np.array([5.5]), + "Z": np.array([2.0]), + "C": 4.0, + } + + assert np.allclose(evaluate(graph, "L", **data), np.array([5.5])) + assert np.allclose(evaluate(graph, "Z", **data), np.array([2.0])) + assert np.allclose(evaluate(graph, "C", **data), 4.0) + assert np.allclose(evaluate(graph, "UX", **data), 4.0) + assert np.allclose(evaluate(graph, "UY", **data), 5.0) + assert np.allclose(evaluate(graph, "PhiX", **data), 5.5) + assert np.allclose(evaluate(graph, "X", **data), 0.5) + assert np.isclose(evaluate(graph, "Y", **data), 2.5) From 01017bc79c5e830197ae735b57f18a8cd54f769b Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 10:37:24 +0100 Subject: [PATCH 32/49] ruff --- pyproject.toml | 3 +++ src/causalprog/causal_problem/handlers.py | 4 ++++ src/causalprog/graph/node/base.py | 4 ++-- tests/test_graph/test_ordering.py | 2 +- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 039cc1d..e68d9e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,9 @@ lint.per-file-ignores = {"__init__.py" = [ "INP001", # File is part of an implicit namespace package. "PLR0913", # Too many arguments in function definition "S101", # Use of `assert` detected + "N803", # Argument name should be lowercase + "PLR2004", # Magic value used in comparison + "PT028", # Test function parameters have default arguments ]} lint.select = ["ALL"] lint.flake8-unused-arguments.ignore-variadic-names = true diff --git a/src/causalprog/causal_problem/handlers.py b/src/causalprog/causal_problem/handlers.py index f9958ff..0b80d51 100644 --- a/src/causalprog/causal_problem/handlers.py +++ b/src/causalprog/causal_problem/handlers.py @@ -85,3 +85,7 @@ def __eq__(self, other: object) -> bool: and self.handler is other.handler and self.options == other.options ) + + def __hash__(self) -> int: + """Hash.""" + return hash((self.handler, self.options)) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index e7d4f54..5719c43 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -66,6 +66,8 @@ def __init__( def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node: """Get a component of this node.""" + from causalprog.graph import ComponentNode # noqa: PLC0415 + if isinstance(indices, int | slice): indices = (indices,) if not isinstance(indices, tuple): @@ -79,8 +81,6 @@ def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node: e = "list index out of range" raise IndexError(e) - from causalprog.graph import ComponentNode - shape: tuple[int, ...] = () for i, s in zip(indices, self._shape, strict=False): if isinstance(i, slice): diff --git a/tests/test_graph/test_ordering.py b/tests/test_graph/test_ordering.py index 04407b1..494e632 100644 --- a/tests/test_graph/test_ordering.py +++ b/tests/test_graph/test_ordering.py @@ -30,7 +30,7 @@ def test_roots_down_to_outcome() -> None: graph.get_node("W"), ) nodes = graph.roots_down_to_outcome("Z") - assert len(nodes) == 5 # noqa: PLR2004 + assert len(nodes) == 5 for e in edges: if "W" not in e: assert nodes.index(graph.get_node(e[0])) < nodes.index(graph.get_node(e[1])) From 9fc538380ad02c4f9fc1e63c1c1073edb1f85951 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 15:11:22 +0100 Subject: [PATCH 33/49] values not keys for parents, update some tests --- src/causalprog/algorithms/do.py | 3 +-- src/causalprog/graph/graph.py | 6 +++--- src/causalprog/graph/node/distribution.py | 2 +- tests/test_algorithms/test_do.py | 6 +++--- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 9c0d6f9..b21b0f0 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -90,9 +90,8 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph ) raise ValueError(msg) - nodes[node] = ConstantNode(label=node, value=value) - g = Graph(label=f"{label}_do_{node}__" + f"{value}".replace(".", "_")) + g.add_node(ConstantNode(label=node, value=value)) for n in nodes.values(): g.add_node(n) diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index 1767b5b..e956e2e 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -4,7 +4,7 @@ import numpy.typing as npt from causalprog._abc.labelled import Labelled -from causalprog.graph.node import DistributionNode, Node +from causalprog.graph.node import DataNode, DistributionNode, Node class Graph(Labelled): @@ -238,8 +238,8 @@ def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]: """ # Confirm that all `DataNode`s have been assigned a value. - for node in self.root_nodes: - if node.label not in parameter_values: + for node in self.nodes: + if isinstance(node, DataNode) and node.label not in parameter_values: msg = f"DataNode '{node.label}' not assigned" raise KeyError(msg) diff --git a/src/causalprog/graph/node/distribution.py b/src/causalprog/graph/node/distribution.py index 9c80a51..ddde0b6 100644 --- a/src/causalprog/graph/node/distribution.py +++ b/src/causalprog/graph/node/distribution.py @@ -102,7 +102,7 @@ def __repr__(self) -> str: @override @property def parents(self) -> list[str]: - return [*self._parameters.keys(), *self._constant_parameters.keys()] + return list(self._parameters.values()) def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike: """ diff --git a/tests/test_algorithms/test_do.py b/tests/test_algorithms/test_do.py index f47a1b5..a2b3ec8 100644 --- a/tests/test_algorithms/test_do.py +++ b/tests/test_algorithms/test_do.py @@ -10,10 +10,10 @@ def test_do(two_normal_graph): graph = two_normal_graph(5.0, 1.2, 0.8) graph2 = algorithms.do(graph, "UX", 4.0) - assert "loc" in graph.get_node("X").parents - assert "loc" in graph2.get_node("X").parents + assert "UX" in graph.get_node("X").parents + assert "UX" in graph2.get_node("X").parents - graph.get_node("UX") + assert not isinstance(graph.get_node("UX"), ConstantNode) assert isinstance(graph2.get_node("UX"), ConstantNode) From c7e5e8881f3f0441c85e34af6d6fc8d86d990044 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 15:13:05 +0100 Subject: [PATCH 34/49] sort toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0b90495..0128b5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,11 +94,11 @@ lint.per-file-ignores = {"__init__.py" = [ "ANN", "D", "INP001", # File is part of an implicit namespace package. - "PLR0913", # Too many arguments in function definition - "S101", # Use of `assert` detected "N803", # Argument name should be lowercase + "PLR0913", # Too many arguments in function definition "PLR2004", # Magic value used in comparison "PT028", # Test function parameters have default arguments + "S101", # Use of `assert` detected ]} lint.select = ["ALL"] lint.flake8-unused-arguments.ignore-variadic-names = true From 63d0b7c995dba886084c3a0ac1900ab1d8a9d878 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 26 Jun 2026 15:15:57 +0100 Subject: [PATCH 35/49] ruff --- src/causalprog/graph/node/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index 5719c43..9cec3a6 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -66,7 +66,7 @@ def __init__( def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node: """Get a component of this node.""" - from causalprog.graph import ComponentNode # noqa: PLC0415 + from causalprog.graph import ComponentNode if isinstance(indices, int | slice): indices = (indices,) From 83a1b5161f958360433f3462dd08c635aad75d4a Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 10:40:17 +0100 Subject: [PATCH 36/49] Update src/causalprog/algorithms/evaluate.py Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/algorithms/evaluate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index d3ae780..b2c2b3a 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -20,7 +20,10 @@ def evaluate( The evaluation of the node """ - for node in graph.roots_down_to_outcome(outcome_node_label): - if node.label not in values: + nodes_to_evaluate = [ + n for n in graph.roots_down_to_outcome(outcome_node_label) + if n not in values + ] + for node in nodes_to_evaluate: values[node.label] = node.evaluate(**values) return values[outcome_node_label] From 6b4ec2e86f6ad620d0d0c5657c2d9886fa3fd100 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 10:41:04 +0100 Subject: [PATCH 37/49] correct --- src/causalprog/algorithms/evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index b2c2b3a..93edd7a 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -22,8 +22,8 @@ def evaluate( """ nodes_to_evaluate = [ n for n in graph.roots_down_to_outcome(outcome_node_label) - if n not in values + if n.label not in values ] for node in nodes_to_evaluate: - values[node.label] = node.evaluate(**values) + values[node.label] = node.evaluate(**values) return values[outcome_node_label] From 9d65fcc3a03c783d5ba4db0e3e8ee36d95375304 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 10:49:20 +0100 Subject: [PATCH 38/49] move graph to graph.special and rearrange tests --- src/causalprog/algorithms/evaluate.py | 3 +- src/causalprog/utils/graphs.py | 80 -------------------------- tests/test_algorithms/test_evaluate.py | 26 +++++++++ tests/test_graph/test_special.py | 24 ++++++++ tests/test_utils/test_graphs.py | 30 ---------- 5 files changed, 52 insertions(+), 111 deletions(-) delete mode 100644 src/causalprog/utils/graphs.py create mode 100644 tests/test_graph/test_special.py delete mode 100644 tests/test_utils/test_graphs.py diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 93edd7a..b512e62 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -21,7 +21,8 @@ def evaluate( """ nodes_to_evaluate = [ - n for n in graph.roots_down_to_outcome(outcome_node_label) + n + for n in graph.roots_down_to_outcome(outcome_node_label) if n.label not in values ] for node in nodes_to_evaluate: diff --git a/src/causalprog/utils/graphs.py b/src/causalprog/utils/graphs.py deleted file mode 100644 index b8b29df..0000000 --- a/src/causalprog/utils/graphs.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Functions to create example graphs.""" - -import inspect -from collections.abc import Callable - -from causalprog.graph import ( - ContinuousRandomVariableNode, - DataNode, - DiscreteRandomVariableNode, - Graph, -) - - -def example_model( - *, - label: str = "G", - l_len: int = 1, - z_len: int = 1, - k: int = 10, - compute_u_x: Callable, - compute_u_y: Callable, - compute_phi_x: Callable, - compute_x: Callable, - compute_y: Callable, -) -> Graph: - """ - Create a graph representing the example model. - - Args: - label: The label of the graph. - l_len: The number of entries in the vector data node L. - z_len: The number of entries in the vector data node Z. - k: The maximum value that could be taken by the mixture indicator C. - compute_u_x: Compute UX given the value of C. - compute_u_y: Compute UY given the value of C. - compute_phi_x: Compute PhiX given the value of L. - compute_x: Compute X given the values of Z, PhiX and UX. - compute_y: Compute Y given the values of X and UY. - - Returns: - A graph - - """ - for p in inspect.signature(compute_u_x).parameters: - if p != "C": - msg = f"Invalid input to UX: {p}" - raise ValueError(msg) - for p in inspect.signature(compute_u_y).parameters: - if p != "C": - msg = f"Invalid input to UX: {p}" - raise ValueError(msg) - for p in inspect.signature(compute_phi_x).parameters: - if p != "L": - msg = f"Invalid input to PhiX: {p}" - raise ValueError(msg) - for p in inspect.signature(compute_x).parameters: - if p not in ["Z", "PhiX", "UX"]: - msg = f"Invalid input to X: {p}" - raise ValueError(msg) - for p in inspect.signature(compute_y).parameters: - if p not in ["X", "UY"]: - msg = f"Invalid input to X: {p}" - raise ValueError(msg) - - graph = Graph(label=label) - - graph.add_node(DataNode(label="L", shape=(l_len,))) - graph.add_node(DataNode(label="Z", shape=(z_len,))) - graph.add_node( - DiscreteRandomVariableNode( - label="C", values=[float(i) for i in range(1, k + 1)] - ) - ) - graph.add_node(ContinuousRandomVariableNode(label="UX", compute=compute_u_x)) - graph.add_node(ContinuousRandomVariableNode(label="UY", compute=compute_u_y)) - graph.add_node(ContinuousRandomVariableNode(label="PhiX", compute=compute_phi_x)) - graph.add_node(ContinuousRandomVariableNode(label="X", compute=compute_x)) - graph.add_node(ContinuousRandomVariableNode(label="Y", compute=compute_y)) - - return graph diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index bbcc394..893ea4f 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -11,6 +11,7 @@ DistributionNode, Graph, ) +from causalprog.graph.special import example_model @pytest.mark.parametrize( @@ -77,3 +78,28 @@ def test_evaluate_algorithm_four_node(): assert np.isclose(evaluate(graph, "c", a=2.0, b=1.5), 1.5) assert np.isclose(evaluate(graph, "x", a=2.0, b=1.5), 4.5) + + +def test_example_model(): + graph = example_model( + compute_u_x=lambda C: C, + compute_u_y=lambda C: C + 1, + compute_phi_x=lambda L: L[0], + compute_x=lambda Z, PhiX, UX: Z[0] + UX - PhiX, + compute_y=lambda X, UY: X * UY, + ) + + data = { + "L": np.array([5.5]), + "Z": np.array([2.0]), + "C": 4.0, + } + + assert np.allclose(evaluate(graph, "L", **data), np.array([5.5])) + assert np.allclose(evaluate(graph, "Z", **data), np.array([2.0])) + assert np.allclose(evaluate(graph, "C", **data), 4.0) + assert np.allclose(evaluate(graph, "UX", **data), 4.0) + assert np.allclose(evaluate(graph, "UY", **data), 5.0) + assert np.allclose(evaluate(graph, "PhiX", **data), 5.5) + assert np.allclose(evaluate(graph, "X", **data), 0.5) + assert np.isclose(evaluate(graph, "Y", **data), 2.5) diff --git a/tests/test_graph/test_special.py b/tests/test_graph/test_special.py new file mode 100644 index 0000000..85a94c9 --- /dev/null +++ b/tests/test_graph/test_special.py @@ -0,0 +1,24 @@ +from causalprog.graph.special import example_model + + +def test_example_model(): + graph = example_model( + compute_u_x=lambda C: C, + compute_u_y=lambda C: C, + compute_phi_x=lambda L: L[0], + compute_x=lambda Z, PhiX, UX: Z + PhiX + UX, + compute_y=lambda X, UY: X + UY, + ) + assert len(graph.nodes) == 8 + assert len(graph.edges) == 8 + edges = {(e[0].label, e[1].label) for e in graph.edges} + assert edges == { + ("L", "PhiX"), + ("C", "UY"), + ("C", "UX"), + ("UX", "X"), + ("PhiX", "X"), + ("Z", "X"), + ("UY", "Y"), + ("X", "Y"), + } diff --git a/tests/test_utils/test_graphs.py b/tests/test_utils/test_graphs.py deleted file mode 100644 index f4773b1..0000000 --- a/tests/test_utils/test_graphs.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np - -from causalprog.algorithms import evaluate -from causalprog.utils.graphs import example_model - - -def test_example_model(): - graph = example_model( - compute_u_x=lambda C: C, - compute_u_y=lambda C: C + 1, - compute_phi_x=lambda L: L[0], - compute_x=lambda Z, PhiX, UX: Z[0] + UX - PhiX, - compute_y=lambda X, UY: X * UY, - ) - assert len(graph.nodes) == 8 - - data = { - "L": np.array([5.5]), - "Z": np.array([2.0]), - "C": 4.0, - } - - assert np.allclose(evaluate(graph, "L", **data), np.array([5.5])) - assert np.allclose(evaluate(graph, "Z", **data), np.array([2.0])) - assert np.allclose(evaluate(graph, "C", **data), 4.0) - assert np.allclose(evaluate(graph, "UX", **data), 4.0) - assert np.allclose(evaluate(graph, "UY", **data), 5.0) - assert np.allclose(evaluate(graph, "PhiX", **data), 5.5) - assert np.allclose(evaluate(graph, "X", **data), 0.5) - assert np.isclose(evaluate(graph, "Y", **data), 2.5) From b6395842e327e59bcd2fd2a9ca4b77da5f8ce684 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 10:49:42 +0100 Subject: [PATCH 39/49] add new file --- src/causalprog/graph/special.py | 80 +++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 src/causalprog/graph/special.py diff --git a/src/causalprog/graph/special.py b/src/causalprog/graph/special.py new file mode 100644 index 0000000..b8b29df --- /dev/null +++ b/src/causalprog/graph/special.py @@ -0,0 +1,80 @@ +"""Functions to create example graphs.""" + +import inspect +from collections.abc import Callable + +from causalprog.graph import ( + ContinuousRandomVariableNode, + DataNode, + DiscreteRandomVariableNode, + Graph, +) + + +def example_model( + *, + label: str = "G", + l_len: int = 1, + z_len: int = 1, + k: int = 10, + compute_u_x: Callable, + compute_u_y: Callable, + compute_phi_x: Callable, + compute_x: Callable, + compute_y: Callable, +) -> Graph: + """ + Create a graph representing the example model. + + Args: + label: The label of the graph. + l_len: The number of entries in the vector data node L. + z_len: The number of entries in the vector data node Z. + k: The maximum value that could be taken by the mixture indicator C. + compute_u_x: Compute UX given the value of C. + compute_u_y: Compute UY given the value of C. + compute_phi_x: Compute PhiX given the value of L. + compute_x: Compute X given the values of Z, PhiX and UX. + compute_y: Compute Y given the values of X and UY. + + Returns: + A graph + + """ + for p in inspect.signature(compute_u_x).parameters: + if p != "C": + msg = f"Invalid input to UX: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_u_y).parameters: + if p != "C": + msg = f"Invalid input to UX: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_phi_x).parameters: + if p != "L": + msg = f"Invalid input to PhiX: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_x).parameters: + if p not in ["Z", "PhiX", "UX"]: + msg = f"Invalid input to X: {p}" + raise ValueError(msg) + for p in inspect.signature(compute_y).parameters: + if p not in ["X", "UY"]: + msg = f"Invalid input to X: {p}" + raise ValueError(msg) + + graph = Graph(label=label) + + graph.add_node(DataNode(label="L", shape=(l_len,))) + graph.add_node(DataNode(label="Z", shape=(z_len,))) + graph.add_node( + DiscreteRandomVariableNode( + label="C", values=[float(i) for i in range(1, k + 1)] + ) + ) + graph.add_node(ContinuousRandomVariableNode(label="UX", compute=compute_u_x)) + graph.add_node(ContinuousRandomVariableNode(label="UY", compute=compute_u_y)) + graph.add_node(ContinuousRandomVariableNode(label="PhiX", compute=compute_phi_x)) + graph.add_node(ContinuousRandomVariableNode(label="X", compute=compute_x)) + graph.add_node(ContinuousRandomVariableNode(label="Y", compute=compute_y)) + + return graph From b06b6f5b921d2832c63b2a4eca2db894122d0bbb Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 10:56:00 +0100 Subject: [PATCH 40/49] add evaluate_down_to variant --- src/causalprog/algorithms/evaluate.py | 34 ++++++++++++++++++++------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index b512e62..50f1ad2 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -5,26 +5,44 @@ from causalprog.graph import Graph -def evaluate( +def evaluate_down_to( graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float] -) -> float | npt.NDArray[float]: +) -> dict[str, float | npt.NDArray[float]]: """ - Evaluate a node. + Evaluate all nodes down to a particular node. Args: graph: The graph that the node is contained in. - outcome_node_label: The label of the node to evaluate. + outcome_node_label: The label of the node to evaluate down to. values: Values taken by nodes whose value is given Returns: - The evaluation of the node - + A dictionary of the values of all the nodes that are ancestors of the input node """ + computed_values = {key: value for key, value in values.items()} nodes_to_evaluate = [ n for n in graph.roots_down_to_outcome(outcome_node_label) if n.label not in values ] for node in nodes_to_evaluate: - values[node.label] = node.evaluate(**values) - return values[outcome_node_label] + computed_values[node.label] = node.evaluate(**computed_values) + return computed_values + + +def evaluate( + graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float] +) -> float | npt.NDArray[float]: + """ + Evaluate a node. + + Args: + graph: The graph that the node is contained in. + outcome_node_label: The label of the node to evaluate. + values: Values taken by nodes whose value is given + + Returns: + The evaluation of the node + + """ + return evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label] From 89fd2b2da40f635b4d26f3c6c4d1b3423dcb14e0 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 10:57:27 +0100 Subject: [PATCH 41/49] Update src/causalprog/algorithms/__init__.py Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/algorithms/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index d7b26db..bedcb50 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -3,3 +3,6 @@ from .do import do from .evaluate import evaluate from .moments import expectation, moment, standard_deviation + +__all__ = ("do", "evaluate", "expectation", "moment", "standard_deviation") + From 18432f736e2a28f4cbb23308baba1025cbb7d35c Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 11:00:45 +0100 Subject: [PATCH 42/49] RUFF --- src/causalprog/algorithms/evaluate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 50f1ad2..87c0cec 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -18,8 +18,9 @@ def evaluate_down_to( Returns: A dictionary of the values of all the nodes that are ancestors of the input node + """ - computed_values = {key: value for key, value in values.items()} + computed_values = dict(values) nodes_to_evaluate = [ n for n in graph.roots_down_to_outcome(outcome_node_label) From a512b2d25de591cbbfcaede72c149ae75cc307a7 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 11:10:02 +0100 Subject: [PATCH 43/49] Parametrize evaluate tests Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/algorithms/__init__.py | 1 - tests/test_algorithms/test_evaluate.py | 153 +++++++++++-------------- 2 files changed, 70 insertions(+), 84 deletions(-) diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index bedcb50..a46568c 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -5,4 +5,3 @@ from .moments import expectation, moment, standard_deviation __all__ = ("do", "evaluate", "expectation", "moment", "standard_deviation") - diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 893ea4f..35f423a 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -1,105 +1,92 @@ """Tests for evaluate algorithms.""" -import numpy as np +import jax.numpy as jnp import pytest +from jax import Array from causalprog.algorithms import evaluate -from causalprog.graph import ( - ComponentNode, - ContinuousRandomVariableNode, - DataNode, - DistributionNode, - Graph, -) +from causalprog.graph import Graph from causalprog.graph.special import example_model +@pytest.fixture +def evaluate_test_graph() -> Graph: + return example_model( + z_len=2, + compute_u_x=lambda C: C, + compute_u_y=lambda C: C + 1, + compute_phi_x=lambda L: L[0], + compute_x=lambda Z, PhiX, UX: Z[0] + UX - PhiX, + compute_y=lambda X, UY: X * UY, + ) + + @pytest.mark.parametrize( - ("node", "kwargs_to_evaluate", "expected_result"), + ("outcome_node_label", "initial_values", "expected_result"), [ pytest.param( - DataNode(label="A"), {"A": 2.0}, 2.0, id="Evaluate DataNode itself" + "L", + {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, + jnp.array([5.5]), + id="DataNode evaluation w/ excess information provided", + ), + pytest.param( + "Z", + {"Z": jnp.array([2.0, 0.0])}, + jnp.array([2.0, 0.0]), + id="DataNode evaluation", ), pytest.param( - ComponentNode("Parent", 1, label="Child"), - {"Parent": np.arange(4)}, + "C", + {"C": 4.0}, + 4.0, + id="DiscreteRVNode evaluation", + ), + pytest.param( + "UX", + {"C": 4.0}, + 4.0, + id="CtsRVNode evaluation", + ), + pytest.param( + "UX", + {"C": 4.0, "UX": 1.0}, 1.0, - id="Evaluate ComponentNode, given parent", + id="CtsRVNode evaluation, 'given that' overrides computed value", ), - ], -) -def test_evaluate_node(node, kwargs_to_evaluate, expected_result): - assert np.allclose(node.evaluate(**kwargs_to_evaluate), expected_result) - - -@pytest.mark.parametrize( - ("node", "kwargs_for_evaluate", "expected_error"), - [ pytest.param( - DataNode(label="A"), - {}, - ValueError("Missing input for node: A"), - id="DataNode missing input value", + "UY", + {"C": 4.0}, + 5.0, + id="CtsRVNode evaluation, with parents that need evaluating", ), pytest.param( - DistributionNode(distribution=None, label="A"), - {}, - RuntimeError("Cannot evaluate a DistributionNode"), - id="Attempt to evaluate DistributionNode", + "X", + {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, + 0.5, + id="Multiple paths from different root nodes", + ), + pytest.param( + "X", + {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0, "PhiX": 0.0}, + 6.0, + id="Multiple paths from different root nodes, with some given values", + ), + pytest.param( + "Y", + {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, + 2.5, + id="Evaluating the 'outcome' node.", ), ], ) -def test_evaluate_node_fail_on_missing_data( - node, kwargs_for_evaluate, expected_error, raises_context -): - with raises_context(expected_error): - node.evaluate(**kwargs_for_evaluate) - - -def test_evaluate_algorithm_three_node(): - graph = Graph(label="g") - graph.add_node(DataNode(label="a")) - graph.add_node(DataNode(label="b")) - graph.add_node( - ContinuousRandomVariableNode(label="x", compute=lambda a, b: a + 2.0 * b) +def test_evaluate( + evaluate_test_graph: Graph, + outcome_node_label: str, + initial_values: dict[str, Array], + expected_result: Array, +) -> None: + computed_result = evaluate( + evaluate_test_graph, outcome_node_label, **initial_values ) - - assert np.isclose(evaluate(graph, "x", a=2.0, b=1.5), 5.0) - - -def test_evaluate_algorithm_four_node(): - graph = Graph(label="g") - graph.add_node(DataNode(label="a")) - graph.add_node(DataNode(label="b")) - graph.add_node(ContinuousRandomVariableNode(label="c", compute=lambda a: a - 0.5)) - graph.add_node( - ContinuousRandomVariableNode(label="x", compute=lambda b, c: c + 2.0 * b) - ) - - assert np.isclose(evaluate(graph, "c", a=2.0, b=1.5), 1.5) - assert np.isclose(evaluate(graph, "x", a=2.0, b=1.5), 4.5) - - -def test_example_model(): - graph = example_model( - compute_u_x=lambda C: C, - compute_u_y=lambda C: C + 1, - compute_phi_x=lambda L: L[0], - compute_x=lambda Z, PhiX, UX: Z[0] + UX - PhiX, - compute_y=lambda X, UY: X * UY, - ) - - data = { - "L": np.array([5.5]), - "Z": np.array([2.0]), - "C": 4.0, - } - - assert np.allclose(evaluate(graph, "L", **data), np.array([5.5])) - assert np.allclose(evaluate(graph, "Z", **data), np.array([2.0])) - assert np.allclose(evaluate(graph, "C", **data), 4.0) - assert np.allclose(evaluate(graph, "UX", **data), 4.0) - assert np.allclose(evaluate(graph, "UY", **data), 5.0) - assert np.allclose(evaluate(graph, "PhiX", **data), 5.5) - assert np.allclose(evaluate(graph, "X", **data), 0.5) - assert np.isclose(evaluate(graph, "Y", **data), 2.5) + assert jnp.allclose(expected_result, computed_result) From 1b92a35093dca66823a6701cf16ad7974d4d2bbe Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 14:02:25 +0100 Subject: [PATCH 44/49] don't copy big arrays --- src/causalprog/algorithms/evaluate.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 87c0cec..0fc325d 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -20,14 +20,14 @@ def evaluate_down_to( A dictionary of the values of all the nodes that are ancestors of the input node """ - computed_values = dict(values) + computed_values = {} nodes_to_evaluate = [ n for n in graph.roots_down_to_outcome(outcome_node_label) if n.label not in values ] for node in nodes_to_evaluate: - computed_values[node.label] = node.evaluate(**computed_values) + computed_values[node.label] = node.evaluate(**values, **computed_values) return computed_values @@ -46,4 +46,7 @@ def evaluate( The evaluation of the node """ - return evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label] + return values.get( + outcome_node_label, + evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label], + ) From 20da68532afa3773e30267541bc89c9b816f8556 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 14:04:29 +0100 Subject: [PATCH 45/49] mypy --- src/causalprog/algorithms/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 0fc325d..49eec91 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -20,7 +20,7 @@ def evaluate_down_to( A dictionary of the values of all the nodes that are ancestors of the input node """ - computed_values = {} + computed_values: dict[str, float | npt.NDArray[float]] = {} nodes_to_evaluate = [ n for n in graph.roots_down_to_outcome(outcome_node_label) From c1ed8588a312a18a84425ecda2fc1137de258eba Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 14:13:07 +0100 Subject: [PATCH 46/49] get doesn't work here as it evaluates arguments BEFORE checking if key in array --- src/causalprog/algorithms/evaluate.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 49eec91..2445745 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -46,7 +46,6 @@ def evaluate( The evaluation of the node """ - return values.get( - outcome_node_label, - evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label], - ) + if outcome_node_label in values: + return values[outcome_node_label] + return evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label] From c0c54f8698ec2ec4168b50c2be89c5135a24c78e Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 14:21:42 +0100 Subject: [PATCH 47/49] Add test of evaluate_down_to Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/algorithms/__init__.py | 11 +++++-- tests/test_algorithms/test_evaluate.py | 41 ++++++++++++++++++-------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index a46568c..5f81e7c 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -1,7 +1,14 @@ """Algorithms.""" from .do import do -from .evaluate import evaluate +from .evaluate import evaluate, evaluate_down_to from .moments import expectation, moment, standard_deviation -__all__ = ("do", "evaluate", "expectation", "moment", "standard_deviation") +__all__ = ( + "do", + "evaluate", + "expectation", + "moment", + "standard_deviation", + "evaluate_down_to", +) diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 35f423a..e955b47 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -4,7 +4,7 @@ import pytest from jax import Array -from causalprog.algorithms import evaluate +from causalprog.algorithms import evaluate, evaluate_down_to from causalprog.graph import Graph from causalprog.graph.special import example_model @@ -27,55 +27,55 @@ def evaluate_test_graph() -> Graph: pytest.param( "L", {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, - jnp.array([5.5]), + {}, id="DataNode evaluation w/ excess information provided", ), pytest.param( "Z", {"Z": jnp.array([2.0, 0.0])}, - jnp.array([2.0, 0.0]), + {}, id="DataNode evaluation", ), pytest.param( "C", {"C": 4.0}, - 4.0, + {}, id="DiscreteRVNode evaluation", ), pytest.param( "UX", {"C": 4.0}, - 4.0, + {"UX": 4.0}, id="CtsRVNode evaluation", ), pytest.param( "UX", {"C": 4.0, "UX": 1.0}, - 1.0, + {}, id="CtsRVNode evaluation, 'given that' overrides computed value", ), pytest.param( "UY", {"C": 4.0}, - 5.0, + {"UY": 5.0}, id="CtsRVNode evaluation, with parents that need evaluating", ), pytest.param( "X", {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, - 0.5, + {"UX": 4.0, "PhiX": 5.5, "X": 0.5}, id="Multiple paths from different root nodes", ), pytest.param( "X", {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0, "PhiX": 0.0}, - 6.0, + {"UX": 4.0, "X": 6.0}, id="Multiple paths from different root nodes, with some given values", ), pytest.param( "Y", {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, - 2.5, + {"UX": 4.0, "UY": 5.0, "PhiX": 5.5, "X": 0.5, "Y": 2.5}, id="Evaluating the 'outcome' node.", ), ], @@ -86,7 +86,24 @@ def test_evaluate( initial_values: dict[str, Array], expected_result: Array, ) -> None: - computed_result = evaluate( + computed_result = evaluate_down_to( evaluate_test_graph, outcome_node_label, **initial_values ) - assert jnp.allclose(expected_result, computed_result) + + # Same number of entries + assert len(expected_result) == len(computed_result) + # Keys are correct + assert set(expected_result.keys()) == set(computed_result.keys()) + + # All entries match to acceptable precision for floats + for node_label, computed_value in computed_result.items(): + assert jnp.allclose(computed_value, expected_result[node_label]) + + # Just asking for one value did indeed extract the correct node value + computed_result_single = evaluate( + evaluate_test_graph, outcome_node_label, **initial_values + ) + if outcome_node_label in initial_values: + assert jnp.allclose(computed_result_single, initial_values[outcome_node_label]) + else: + assert jnp.allclose(computed_result_single, computed_result[outcome_node_label]) From 5cf59ecbf1c99cd65bb8c33e12d4c223e6063b13 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 14:22:43 +0100 Subject: [PATCH 48/49] RUFF --- src/causalprog/algorithms/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index 5f81e7c..6bb0efb 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -7,8 +7,8 @@ __all__ = ( "do", "evaluate", + "evaluate_down_to", "expectation", "moment", "standard_deviation", - "evaluate_down_to", ) From 5f4a080754f171996c2d79f70f48b7a661acc151 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 14:34:58 +0100 Subject: [PATCH 49/49] type --- tests/test_algorithms/test_evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index e955b47..dedc8c7 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -84,7 +84,7 @@ def test_evaluate( evaluate_test_graph: Graph, outcome_node_label: str, initial_values: dict[str, Array], - expected_result: Array, + expected_result: dict[str, Array], ) -> None: computed_result = evaluate_down_to( evaluate_test_graph, outcome_node_label, **initial_values