From b49420ba3b3cbfbbcec2917ae24870d42548af66 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 14:45:12 +0100 Subject: [PATCH 1/5] Add check that values input into evaluate are valid --- src/causalprog/algorithms/evaluate.py | 4 ++++ src/causalprog/graph/node/base.py | 16 ++++++++++++++ src/causalprog/graph/node/component.py | 4 ++++ src/causalprog/graph/node/constant.py | 4 ++++ src/causalprog/graph/node/data.py | 4 ++++ src/causalprog/graph/node/distribution.py | 4 ++++ src/causalprog/graph/node/random_variables.py | 17 -------------- tests/test_algorithms/test_evaluate.py | 22 +++++++++++++++++++ 8 files changed, 58 insertions(+), 17 deletions(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 2445745..c398a53 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -20,6 +20,8 @@ def evaluate_down_to( A dictionary of the values of all the nodes that are ancestors of the input node """ + for label, value in values.items(): + graph.get_node(label).assert_is_valid_value(value) computed_values: dict[str, float | npt.NDArray[float]] = {} nodes_to_evaluate = [ n @@ -46,6 +48,8 @@ def evaluate( The evaluation of the node """ + for label, value in values.items(): + graph.get_node(label).assert_is_valid_value(value) if outcome_node_label in values: return values[outcome_node_label] return evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label] diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index 9cec3a6..d8a1157 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -168,3 +168,19 @@ def parents(self) -> list[str]: List of labels of parent nodes """ + + @abstractmethod + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + """Check if a value is valid for this node.""" + + def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None: + """Check if a value is valid for this node.""" + if not self.is_valid_value(value): + msg = ( + f"Invalid value for {self.__class__.__name__}: " + f"{self.label} cannot be {value}" + ) + raise ValueError(msg) + if self.shape != (value.shape if hasattr(value, "shape") else ()): + msg = f"Invalid value for node: {self.label}" + raise ValueError(msg) diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index 8e27d53..2f51eb0 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -82,3 +82,7 @@ def __repr__(self) -> str: @property def parents(self) -> list[str]: return [self._parent_node_label] + + @override + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + return True diff --git a/src/causalprog/graph/node/constant.py b/src/causalprog/graph/node/constant.py index f49a2d3..daee513 100644 --- a/src/causalprog/graph/node/constant.py +++ b/src/causalprog/graph/node/constant.py @@ -61,3 +61,7 @@ def __repr__(self) -> str: @property def parents(self) -> list[str]: return [] + + @override + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + return True diff --git a/src/causalprog/graph/node/data.py b/src/causalprog/graph/node/data.py index a319025..6b87720 100644 --- a/src/causalprog/graph/node/data.py +++ b/src/causalprog/graph/node/data.py @@ -67,3 +67,7 @@ def __repr__(self) -> str: @property def parents(self) -> list[str]: return [] + + @override + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + return True diff --git a/src/causalprog/graph/node/distribution.py b/src/causalprog/graph/node/distribution.py index ddde0b6..1147a1d 100644 --- a/src/causalprog/graph/node/distribution.py +++ b/src/causalprog/graph/node/distribution.py @@ -126,3 +126,7 @@ def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike: **self._constant_parameters, ), ) + + @override + def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + return True diff --git a/src/causalprog/graph/node/random_variables.py b/src/causalprog/graph/node/random_variables.py index 9f5dd38..bcf9584 100644 --- a/src/causalprog/graph/node/random_variables.py +++ b/src/causalprog/graph/node/random_variables.py @@ -2,7 +2,6 @@ import inspect import typing -from abc import abstractmethod import jax import numpy as np @@ -69,22 +68,6 @@ def evaluate( def parents(self) -> list[str]: return self._parents - @abstractmethod - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - """Check if a value is valid for this node.""" - - def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None: - """Check if a value is valid for this node.""" - if not self.is_valid_value(value): - msg = ( - f"Invalid value for {self.__class__.__name__}: " - f"{self.label} cannot be {value}" - ) - raise ValueError(msg) - if self.shape != (value.shape if hasattr(value, "shape") else ()): - msg = f"Invalid value for node: {self.label}" - raise ValueError(msg) - class ContinuousRandomVariableNode(RandomVariableNode): """A node containing a continuous random variable (RV).""" diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index dedc8c7..5d376f3 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -107,3 +107,25 @@ def test_evaluate( assert jnp.allclose(computed_result_single, initial_values[outcome_node_label]) else: assert jnp.allclose(computed_result_single, computed_result[outcome_node_label]) + + +@pytest.mark.parametrize( + ("outcome_node_label", "initial_values", "expected_error"), + [ + pytest.param( + "C", + {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.5}, + ValueError("Invalid value for "), + id="Invalid value for C", + ) + ], +) +def test_evaluate_error( + evaluate_test_graph: Graph, + outcome_node_label: str, + initial_values: dict[str, Array], + expected_error: BaseException, + raises_context, +) -> None: + with raises_context(expected_error): + evaluate(evaluate_test_graph, outcome_node_label, **initial_values) From 6149e0b96439927ac1e5454576ebbb74c3c1d0d3 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 29 Jun 2026 15:25:51 +0100 Subject: [PATCH 2/5] add test that evaluate fails on missing input --- tests/test_algorithms/test_evaluate.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index 5d376f3..4b20905 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -117,7 +117,13 @@ def test_evaluate( {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.5}, ValueError("Invalid value for "), id="Invalid value for C", - ) + ), + pytest.param( + "PhiX", + {"Z": jnp.array([2.0, 0.0]), "C": 4.0}, + ValueError("Missing input for node"), + id="Missing value for parent of PhiX", + ), ], ) def test_evaluate_error( From f5736c6065e0cc392e8a80190a006ba3adc79974 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 30 Jun 2026 08:25:17 +0100 Subject: [PATCH 3/5] update tests following merge --- tests/test_algorithms/test_evaluate.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index ecd9acc..f4831f8 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -26,7 +26,7 @@ def evaluate_test_graph() -> Graph: [ pytest.param( "l", - {"l": jnp.array([5.5]), "x": jnp.array([2.0, 0.0]), "c": 4.0}, + {"l": jnp.array([5.5]), "x": 2.0, "c": 4.0}, {}, id="DataNode evaluation w/ excess information provided", ), @@ -113,16 +113,16 @@ def test_evaluate( ("outcome_node_label", "initial_values", "expected_error"), [ pytest.param( - "C", - {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.5}, + "c", + {"l": jnp.array([5.5]), "z": jnp.array([2.0, 0.0]), "c": 4.5}, ValueError("Invalid value for "), - id="Invalid value for C", + id="Invalid value for discrete RV node", ), pytest.param( - "PhiX", - {"Z": jnp.array([2.0, 0.0]), "C": 4.0}, + "phi_x", + {"z": jnp.array([2.0, 0.0]), "x": 4.0}, ValueError("Missing input for node"), - id="Missing value for parent of PhiX", + id="Missing value for a parent", ), ], ) @@ -134,4 +134,4 @@ def test_evaluate_error( raises_context, ) -> None: with raises_context(expected_error): - evaluate(evaluate_test_graph, outcome_node_label, **initial_values) + evaluate(evaluate_test_graph, outcome_node_label, initial_values) From 7693525f5fa841be03d6f77513e74b57caf0c11b Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 1 Jul 2026 08:43:53 +0100 Subject: [PATCH 4/5] return True by default --- src/causalprog/graph/node/base.py | 4 ++-- src/causalprog/graph/node/component.py | 4 ---- src/causalprog/graph/node/constant.py | 4 ---- src/causalprog/graph/node/data.py | 4 ---- src/causalprog/graph/node/distribution.py | 4 ---- src/causalprog/graph/node/random_variables.py | 4 ---- 6 files changed, 2 insertions(+), 22 deletions(-) diff --git a/src/causalprog/graph/node/base.py b/src/causalprog/graph/node/base.py index 42ca210..1e00647 100644 --- a/src/causalprog/graph/node/base.py +++ b/src/causalprog/graph/node/base.py @@ -169,9 +169,9 @@ def parents(self) -> list[str]: """ - @abstractmethod - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: + def is_valid_value(self, _value: float | npt.NDArray[float]) -> bool: """Check if a value is valid for this node.""" + return True def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None: """Check if a value is valid for this node.""" diff --git a/src/causalprog/graph/node/component.py b/src/causalprog/graph/node/component.py index f781973..5cd7290 100644 --- a/src/causalprog/graph/node/component.py +++ b/src/causalprog/graph/node/component.py @@ -82,7 +82,3 @@ def __repr__(self) -> str: @property def parents(self) -> list[str]: return [self._parent_node_label] - - @override - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - return True diff --git a/src/causalprog/graph/node/constant.py b/src/causalprog/graph/node/constant.py index 1994d03..4e742a5 100644 --- a/src/causalprog/graph/node/constant.py +++ b/src/causalprog/graph/node/constant.py @@ -61,7 +61,3 @@ def __repr__(self) -> str: @property def parents(self) -> list[str]: return [] - - @override - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - return True diff --git a/src/causalprog/graph/node/data.py b/src/causalprog/graph/node/data.py index f474cd7..97d185a 100644 --- a/src/causalprog/graph/node/data.py +++ b/src/causalprog/graph/node/data.py @@ -67,7 +67,3 @@ def __repr__(self) -> str: @property def parents(self) -> list[str]: return [] - - @override - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - return True diff --git a/src/causalprog/graph/node/distribution.py b/src/causalprog/graph/node/distribution.py index 8d9997a..403db3d 100644 --- a/src/causalprog/graph/node/distribution.py +++ b/src/causalprog/graph/node/distribution.py @@ -126,7 +126,3 @@ def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike: **self._constant_parameters, ), ) - - @override - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - return True diff --git a/src/causalprog/graph/node/random_variables.py b/src/causalprog/graph/node/random_variables.py index 13c9d91..97669c9 100644 --- a/src/causalprog/graph/node/random_variables.py +++ b/src/causalprog/graph/node/random_variables.py @@ -76,10 +76,6 @@ class ContinuousRandomVariableNode(RandomVariableNode): def __repr__(self) -> str: return f'ContinuousRandomVariableNode(label="{self.label}")' - @override - def is_valid_value(self, value: float | npt.NDArray[float]) -> bool: - return True - @override def copy(self) -> Node: return ContinuousRandomVariableNode( From b4332601a00864cb62a2fd6c7b09be7066e4583b Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 1 Jul 2026 08:46:41 +0100 Subject: [PATCH 5/5] remove double check in evaluate Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/algorithms/evaluate.py | 8 ++++---- tests/test_algorithms/test_evaluate.py | 13 +++++-------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/causalprog/algorithms/evaluate.py b/src/causalprog/algorithms/evaluate.py index 063dcdc..72c7925 100644 --- a/src/causalprog/algorithms/evaluate.py +++ b/src/causalprog/algorithms/evaluate.py @@ -22,6 +22,10 @@ def evaluate_down_to( """ for label, value in values.items(): graph.get_node(label).assert_is_valid_value(value) + + if outcome_node_label in values: + return {outcome_node_label: values[outcome_node_label]} + computed_values: dict[str, float | npt.NDArray[float]] = {} nodes_to_evaluate = [ n @@ -48,8 +52,4 @@ def evaluate( The evaluation of the node """ - for label, value in values.items(): - graph.get_node(label).assert_is_valid_value(value) - if outcome_node_label in values: - return values[outcome_node_label] return evaluate_down_to(graph, outcome_node_label, values)[outcome_node_label] diff --git a/tests/test_algorithms/test_evaluate.py b/tests/test_algorithms/test_evaluate.py index f4831f8..bf9d144 100644 --- a/tests/test_algorithms/test_evaluate.py +++ b/tests/test_algorithms/test_evaluate.py @@ -27,19 +27,19 @@ def evaluate_test_graph() -> Graph: pytest.param( "l", {"l": jnp.array([5.5]), "x": 2.0, "c": 4.0}, - {}, + {"l": jnp.array([5.5])}, id="DataNode evaluation w/ excess information provided", ), pytest.param( "z", {"z": jnp.array([2.0, 0.0])}, - {}, + {"z": jnp.array([2.0, 0.0])}, id="DataNode evaluation", ), pytest.param( "c", {"c": 4.0}, - {}, + {"c": 4.0}, id="DiscreteRVNode evaluation", ), pytest.param( @@ -51,7 +51,7 @@ def evaluate_test_graph() -> Graph: pytest.param( "u_x", {"c": 4.0, "u_x": 1.0}, - {}, + {"u_x": 1.0}, id="CtsRVNode evaluation, 'given that' overrides computed value", ), pytest.param( @@ -103,10 +103,7 @@ def test_evaluate( computed_result_single = evaluate( evaluate_test_graph, outcome_node_label, initial_values ) - if outcome_node_label in initial_values: - assert jnp.allclose(computed_result_single, initial_values[outcome_node_label]) - else: - assert jnp.allclose(computed_result_single, computed_result[outcome_node_label]) + assert jnp.allclose(computed_result_single, computed_result[outcome_node_label]) @pytest.mark.parametrize(