diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 7972cc0..72c7925 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -20,6 +20,12 @@ def evaluate_down_to( A dictionary of the values of all the nodes that are ancestors of the input node """ + for label, value in values.items(): + graph.get_node(label).assert_is_valid_value(value) + + if outcome_node_label in values: + return {outcome_node_label: values[outcome_node_label]} + computed_values: dict[str, float | npt.NDArray[float]] = {} nodes_to_evaluate = [ n @@ -46,6 +52,4 @@ def evaluate( The evaluation of the node """ - if outcome_node_label in values: - return values[outcome_node_label] return evaluate_down_to(graph, outcome_node_label, values)[outcome_node_label] diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index 6a218e9..1e00647 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -168,3 +168,19 @@ def parents(self) -> list[str]: List of labels of parent nodes """ + + def is_valid_value(self, _value: float | npt.NDArray[float]) -> bool: + """Check if a value is valid for this node.""" + return True + + def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None: + """Check if a value is valid for this node.""" + if not self.is_valid_value(value): + msg = ( + f"Invalid value for {self.__class__.__name__}: " + f"{self.label} cannot be {value}" + ) + raise ValueError(msg) + if self.shape != (value.shape if hasattr(value, "shape") else ()): + msg = f"Invalid value for node: {self.label}" + raise ValueError(msg) diff --git a/src/causalprog/graph/node/random_variables.py b/src/causalprog/graph/node/random_variables.py index be80f12..97669c9 100644 --- a/src/causalprog/graph/node/random_variables.py +++ b/src/causalprog/graph/node/random_variables.py @@ -1,7 +1,6 @@ """Graph nodes representing random variables.""" import typing -from abc import abstractmethod import jax import numpy as np @@ -69,22 +68,6 @@ def evaluate( def parents(self) -> list[str]: return self._parents - @abstractmethod - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - """Check if a value is valid for this node.""" - - def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None: - """Check if a value is valid for this node.""" - if not self.is_valid_value(value): - msg = ( - f"Invalid value for {self.__class__.__name__}: " - f"{self.label} cannot be {value}" - ) - raise ValueError(msg) - if self.shape != (value.shape if hasattr(value, "shape") else ()): - msg = f"Invalid value for node: {self.label}" - raise ValueError(msg) - class ContinuousRandomVariableNode(RandomVariableNode): """A node containing a continuous random variable (RV).""" @@ -93,10 +76,6 @@ class ContinuousRandomVariableNode(RandomVariableNode): def __repr__(self) -> str: return f'ContinuousRandomVariableNode(label="{self.label}")' - @override - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - return True - @override def copy(self) -> Node: return ContinuousRandomVariableNode( diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 770897b..bf9d144 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -26,20 +26,20 @@ def evaluate_test_graph() -> Graph: [ pytest.param( "l", - {"l": jnp.array([5.5]), "x": jnp.array([2.0, 0.0]), "c": 4.0}, - {}, + {"l": jnp.array([5.5]), "x": 2.0, "c": 4.0}, + {"l": jnp.array([5.5])}, id="DataNode evaluation w/ excess information provided", ), pytest.param( "z", {"z": jnp.array([2.0, 0.0])}, - {}, + {"z": jnp.array([2.0, 0.0])}, id="DataNode evaluation", ), pytest.param( "c", {"c": 4.0}, - {}, + {"c": 4.0}, id="DiscreteRVNode evaluation", ), pytest.param( @@ -51,7 +51,7 @@ def evaluate_test_graph() -> Graph: pytest.param( "u_x", {"c": 4.0, "u_x": 1.0}, - {}, + {"u_x": 1.0}, id="CtsRVNode evaluation, 'given that' overrides computed value", ), pytest.param( @@ -103,7 +103,32 @@ def test_evaluate( computed_result_single = evaluate( evaluate_test_graph, outcome_node_label, initial_values ) - if outcome_node_label in initial_values: - assert jnp.allclose(computed_result_single, initial_values[outcome_node_label]) - else: - assert jnp.allclose(computed_result_single, computed_result[outcome_node_label]) + assert jnp.allclose(computed_result_single, computed_result[outcome_node_label]) + + +@pytest.mark.parametrize( + ("outcome_node_label", "initial_values", "expected_error"), + [ + pytest.param( + "c", + {"l": jnp.array([5.5]), "z": jnp.array([2.0, 0.0]), "c": 4.5}, + ValueError("Invalid value for "), + id="Invalid value for discrete RV node", + ), + pytest.param( + "phi_x", + {"z": jnp.array([2.0, 0.0]), "x": 4.0}, + ValueError("Missing input for node"), + id="Missing value for a parent", + ), + ], +) +def test_evaluate_error( + evaluate_test_graph: Graph, + outcome_node_label: str, + initial_values: dict[str, Array], + expected_error: BaseException, + raises_context, +) -> None: + with raises_context(expected_error): + evaluate(evaluate_test_graph, outcome_node_label, initial_values)