diff --git a/pyproject.toml b/pyproject.toml index aa82ff5..0128b5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,10 @@ lint.per-file-ignores = {"__init__.py" = [ "ANN", "D", "INP001", # File is part of an implicit namespace package. + "N803", # Argument name should be lowercase "PLR0913", # Too many arguments in function definition + "PLR2004", # Magic value used in comparison + "PT028", # Test function parameters have default arguments "S101", # Use of `assert` detected ]} lint.select = ["ALL"] diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index da37e9a..6bb0efb 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -1,4 +1,14 @@ """Algorithms.""" from .do import do +from .evaluate import evaluate, evaluate_down_to from .moments import expectation, moment, standard_deviation + +__all__ = ( + "do", + "evaluate", + "evaluate_down_to", + "expectation", + "moment", + "standard_deviation", +) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 9c0d6f9..b21b0f0 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -90,9 +90,8 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph ) raise ValueError(msg) - nodes[node] = ConstantNode(label=node, value=value) - g = Graph(label=f"{label}_do_{node}__" + f"{value}".replace(".", "_")) + g.add_node(ConstantNode(label=node, value=value)) for n in nodes.values(): g.add_node(n) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py new file mode 100644 index 0000000..2445745 --- /dev/null +++ b/src/causalprog/algorithms/evaluate.py @@ -0,0 +1,51 @@ +"""Algorithms for evaluating a graph node.""" + +import numpy.typing as npt + +from causalprog.graph import Graph + + +def evaluate_down_to( + graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float] +) -> dict[str, float | 
npt.NDArray[float]]: + """ + Evaluate all nodes down to a particular node. + + Args: + graph: The graph that the node is contained in. + outcome_node_label: The label of the node to evaluate down to. + values: Values taken by nodes whose value is given + + Returns: + A dictionary of the values of all the nodes that are ancestors of the input node + + """ + computed_values: dict[str, float | npt.NDArray[float]] = {} + nodes_to_evaluate = [ + n + for n in graph.roots_down_to_outcome(outcome_node_label) + if n.label not in values + ] + for node in nodes_to_evaluate: + computed_values[node.label] = node.evaluate(**values, **computed_values) + return computed_values + + +def evaluate( + graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float] +) -> float | npt.NDArray[float]: + """ + Evaluate a node. + + Args: + graph: The graph that the node is contained in. + outcome_node_label: The label of the node to evaluate. + values: Values taken by nodes whose value is given + + Returns: + The evaluation of the node + + """ + if outcome_node_label in values: + return values[outcome_node_label] + return evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label] diff --git a/src/causalprog/causal_problem/handlers.py b/src/causalprog/causal_problem/handlers.py index f9958ff..0b80d51 100644 --- a/src/causalprog/causal_problem/handlers.py +++ b/src/causalprog/causal_problem/handlers.py @@ -85,3 +85,7 @@ def __eq__(self, other: object) -> bool: and self.handler is other.handler and self.options == other.options ) + + def __hash__(self) -> int: + """Hash.""" + return hash((self.handler, self.options)) diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index 1295e27..e956e2e 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -4,7 +4,7 @@ import numpy.typing as npt from causalprog._abc.labelled import Labelled -from causalprog.graph.node import ComponentNode, DistributionNode, Node +from 
causalprog.graph.node import DataNode, DistributionNode, Node class Graph(Labelled): @@ -60,11 +60,8 @@ def add_node(self, node: Node) -> None: raise ValueError(msg) self._nodes_by_label[node.label] = node self._graph.add_node(node) - if isinstance(node, ComponentNode): - if len(node.parents) != 1: - msg = "ComponentNode should have exactly one parent." - raise ValueError(msg) - self.add_edge(node.parents[0], node.label) + for p in node.parents: + self.add_edge(p, node.label) def add_edge(self, start_node: Node | str, end_node: Node | str) -> None: """ @@ -241,8 +238,8 @@ def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]: """ # Confirm that all `DataNode`s have been assigned a value. - for node in self.root_nodes: - if node.label not in parameter_values: + for node in self.nodes: + if isinstance(node, DataNode) and node.label not in parameter_values: msg = f"DataNode '{node.label}' not assigned" raise KeyError(msg) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index e7d4f54..9cec3a6 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -66,6 +66,8 @@ def __init__( def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node: """Get a component of this node.""" + from causalprog.graph import ComponentNode + if isinstance(indices, int | slice): indices = (indices,) if not isinstance(indices, tuple): @@ -79,8 +81,6 @@ def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node: e = "list index out of range" raise IndexError(e) - from causalprog.graph import ComponentNode - shape: tuple[int, ...] 
def example_model(
    *,
    label: str = "G",
    l_len: int = 1,
    z_len: int = 1,
    k: int = 10,
    compute_u_x: Callable,
    compute_u_y: Callable,
    compute_phi_x: Callable,
    compute_x: Callable,
    compute_y: Callable,
) -> Graph:
    """
    Create a graph representing the example model.

    The graph contains the data nodes ``L`` and ``Z``, the discrete mixture
    indicator ``C`` and the continuous nodes ``UX``, ``UY``, ``PhiX``, ``X``
    and ``Y``. The argument names of each ``compute_*`` callable are validated
    before any node is created.

    Args:
        label: The label of the graph.
        l_len: The number of entries in the vector data node L.
        z_len: The number of entries in the vector data node Z.
        k: The maximum value that could be taken by the mixture indicator C.
        compute_u_x: Compute UX given the value of C.
        compute_u_y: Compute UY given the value of C.
        compute_phi_x: Compute PhiX given the value of L.
        compute_x: Compute X given the values of Z, PhiX and UX.
        compute_y: Compute Y given the values of X and UY.

    Returns:
        A graph

    Raises:
        ValueError: If a ``compute_*`` callable has an argument whose name is
            not one of the node labels permitted as its input.

    """
    # (node being computed, callable, argument names it may use). The error
    # message names the node the offending callable computes.
    permitted_inputs = (
        ("UX", compute_u_x, ("C",)),
        ("UY", compute_u_y, ("C",)),
        ("PhiX", compute_phi_x, ("L",)),
        ("X", compute_x, ("Z", "PhiX", "UX")),
        ("Y", compute_y, ("X", "UY")),
    )
    for node_label, compute, allowed in permitted_inputs:
        for p in inspect.signature(compute).parameters:
            if p not in allowed:
                msg = f"Invalid input to {node_label}: {p}"
                raise ValueError(msg)

    graph = Graph(label=label)

    graph.add_node(DataNode(label="L", shape=(l_len,)))
    graph.add_node(DataNode(label="Z", shape=(z_len,)))
    graph.add_node(
        DiscreteRandomVariableNode(
            label="C", values=[float(i) for i in range(1, k + 1)]
        )
    )
    # NOTE(review): nodes are added after the nodes they take as input,
    # presumably so that edges to existing parents can be created on insertion
    # - confirm against `Graph.add_node`.
    graph.add_node(ContinuousRandomVariableNode(label="UX", compute=compute_u_x))
    graph.add_node(ContinuousRandomVariableNode(label="UY", compute=compute_u_y))
    graph.add_node(ContinuousRandomVariableNode(label="PhiX", compute=compute_phi_x))
    graph.add_node(ContinuousRandomVariableNode(label="X", compute=compute_x))
    graph.add_node(ContinuousRandomVariableNode(label="Y", compute=compute_y))

    return graph
@pytest.fixture
def evaluate_test_graph() -> Graph:
    """Build the example model with simple arithmetic node functions."""
    # The argument names are significant: `example_model` validates them
    # against the labels of the nodes each function may take as input.
    node_functions = {
        "compute_u_x": lambda C: C,
        "compute_u_y": lambda C: C + 1,
        "compute_phi_x": lambda L: L[0],
        "compute_x": lambda Z, PhiX, UX: Z[0] + UX - PhiX,
        "compute_y": lambda X, UY: X * UY,
    }
    return example_model(z_len=2, **node_functions)


@pytest.mark.parametrize(
    ("outcome_node_label", "initial_values", "expected_result"),
    [
        pytest.param(
            "L",
            {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0},
            {},
            id="DataNode evaluation w/ excess information provided",
        ),
        pytest.param(
            "Z",
            {"Z": jnp.array([2.0, 0.0])},
            {},
            id="DataNode evaluation",
        ),
        pytest.param(
            "C",
            {"C": 4.0},
            {},
            id="DiscreteRVNode evaluation",
        ),
        pytest.param(
            "UX",
            {"C": 4.0},
            {"UX": 4.0},
            id="CtsRVNode evaluation",
        ),
        pytest.param(
            "UX",
            {"C": 4.0, "UX": 1.0},
            {},
            id="CtsRVNode evaluation, 'given that' overrides computed value",
        ),
        pytest.param(
            "UY",
            {"C": 4.0},
            {"UY": 5.0},
            id="CtsRVNode evaluation, with parents that need evaluating",
        ),
        pytest.param(
            "X",
            {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0},
            {"UX": 4.0, "PhiX": 5.5, "X": 0.5},
            id="Multiple paths from different root nodes",
        ),
        pytest.param(
            "X",
            {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0, "PhiX": 0.0},
            {"UX": 4.0, "X": 6.0},
            id="Multiple paths from different root nodes, with some given values",
        ),
        pytest.param(
            "Y",
            {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0},
            {"UX": 4.0, "UY": 5.0, "PhiX": 5.5, "X": 0.5, "Y": 2.5},
            id="Evaluating the 'outcome' node.",
        ),
    ],
)
def test_evaluate(
    evaluate_test_graph: Graph,
    outcome_node_label: str,
    initial_values: dict[str, Array],
    expected_result: dict[str, Array],
) -> None:
    """Check `evaluate_down_to` and `evaluate` against hand-computed values."""
    all_computed = evaluate_down_to(
        evaluate_test_graph, outcome_node_label, **initial_values
    )

    # Exactly the expected nodes were evaluated: same count, same labels.
    assert len(expected_result) == len(all_computed)
    assert set(expected_result.keys()) == set(all_computed.keys())

    # Each computed value agrees with its expectation to float precision.
    for label in all_computed:
        assert jnp.allclose(all_computed[label], expected_result[label])

    # Asking for the single node returns the supplied value when there is one,
    # and otherwise the value computed above.
    single_value = evaluate(evaluate_test_graph, outcome_node_label, **initial_values)
    reference = (
        initial_values if outcome_node_label in initial_values else all_computed
    )
    assert jnp.allclose(single_value, reference[outcome_node_label])


def test_example_model():
    """Check the node count and the edge set of the example model."""
    node_functions = {
        "compute_u_x": lambda C: C,
        "compute_u_y": lambda C: C,
        "compute_phi_x": lambda L: L[0],
        "compute_x": lambda Z, PhiX, UX: Z + PhiX + UX,
        "compute_y": lambda X, UY: X + UY,
    }
    graph = example_model(**node_functions)

    expected_edges = {
        ("L", "PhiX"),
        ("C", "UY"),
        ("C", "UX"),
        ("UX", "X"),
        ("PhiX", "X"),
        ("Z", "X"),
        ("UY", "Y"),
        ("X", "Y"),
    }
    assert len(graph.nodes) == 8
    assert len(graph.edges) == 8
    found_edges = {(edge[0].label, edge[1].label) for edge in graph.edges}
    assert found_edges == expected_edges