Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/causalprog/algorithms/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def evaluate_down_to(
A dictionary of the values of all the nodes that are ancestors of the input node

"""
for label, value in values.items():
graph.get_node(label).assert_is_valid_value(value)

if outcome_node_label in values:
return {outcome_node_label: values[outcome_node_label]}

computed_values: dict[str, float | npt.NDArray[float]] = {}
nodes_to_evaluate = [
n
Expand All @@ -46,6 +52,4 @@ def evaluate(
The evaluation of the node

"""
if outcome_node_label in values:
return values[outcome_node_label]
return evaluate_down_to(graph, outcome_node_label, values)[outcome_node_label]
16 changes: 16 additions & 0 deletions src/causalprog/graph/node/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,19 @@ def parents(self) -> list[str]:
List of labels of parent nodes

"""

def is_valid_value(self, _value: float | npt.NDArray[float]) -> bool:
"""Check if a value is valid for this node."""
return True

def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None:
"""Check if a value is valid for this node."""
if not self.is_valid_value(value):
msg = (
f"Invalid value for {self.__class__.__name__}: "
f"{self.label} cannot be {value}"
)
raise ValueError(msg)
if self.shape != (value.shape if hasattr(value, "shape") else ()):
msg = f"Invalid value for node: {self.label}"
raise ValueError(msg)
21 changes: 0 additions & 21 deletions src/causalprog/graph/node/random_variables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Graph nodes representing random variables."""

import typing
from abc import abstractmethod

import jax
import numpy as np
Expand Down Expand Up @@ -69,22 +68,6 @@ def evaluate(
def parents(self) -> list[str]:
return self._parents

@abstractmethod
def is_valid_value(self, value: float | npt.NDArray[float]) -> bool:
"""Check if a value is valid for this node."""

def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None:
"""Check if a value is valid for this node."""
if not self.is_valid_value(value):
msg = (
f"Invalid value for {self.__class__.__name__}: "
f"{self.label} cannot be {value}"
)
raise ValueError(msg)
if self.shape != (value.shape if hasattr(value, "shape") else ()):
msg = f"Invalid value for node: {self.label}"
raise ValueError(msg)


class ContinuousRandomVariableNode(RandomVariableNode):
"""A node containing a continuous random variable (RV)."""
Expand All @@ -93,10 +76,6 @@ class ContinuousRandomVariableNode(RandomVariableNode):
def __repr__(self) -> str:
return f'ContinuousRandomVariableNode(label="{self.label}")'

@override
def is_valid_value(self, value: float | npt.NDArray[float]) -> bool:
return True

@override
def copy(self) -> Node:
return ContinuousRandomVariableNode(
Expand Down
43 changes: 34 additions & 9 deletions tests/test_algorithms/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ def evaluate_test_graph() -> Graph:
[
pytest.param(
"l",
{"l": jnp.array([5.5]), "x": jnp.array([2.0, 0.0]), "c": 4.0},
{},
{"l": jnp.array([5.5]), "x": 2.0, "c": 4.0},
{"l": jnp.array([5.5])},
id="DataNode evaluation w/ excess information provided",
),
pytest.param(
"z",
{"z": jnp.array([2.0, 0.0])},
{},
{"z": jnp.array([2.0, 0.0])},
id="DataNode evaluation",
),
pytest.param(
"c",
{"c": 4.0},
{},
{"c": 4.0},
id="DiscreteRVNode evaluation",
),
pytest.param(
Expand All @@ -51,7 +51,7 @@ def evaluate_test_graph() -> Graph:
pytest.param(
"u_x",
{"c": 4.0, "u_x": 1.0},
{},
{"u_x": 1.0},
id="CtsRVNode evaluation, 'given that' overrides computed value",
),
pytest.param(
Expand Down Expand Up @@ -103,7 +103,32 @@ def test_evaluate(
computed_result_single = evaluate(
evaluate_test_graph, outcome_node_label, initial_values
)
if outcome_node_label in initial_values:
assert jnp.allclose(computed_result_single, initial_values[outcome_node_label])
else:
assert jnp.allclose(computed_result_single, computed_result[outcome_node_label])
assert jnp.allclose(computed_result_single, computed_result[outcome_node_label])


@pytest.mark.parametrize(
("outcome_node_label", "initial_values", "expected_error"),
[
pytest.param(
"c",
{"l": jnp.array([5.5]), "z": jnp.array([2.0, 0.0]), "c": 4.5},
ValueError("Invalid value for "),
id="Invalid value for discrete RV node",
),
pytest.param(
"phi_x",
{"z": jnp.array([2.0, 0.0]), "x": 4.0},
ValueError("Missing input for node"),
id="Missing value for a parent",
),
],
)
def test_evaluate_error(
evaluate_test_graph: Graph,
outcome_node_label: str,
initial_values: dict[str, Array],
expected_error: BaseException,
raises_context,
) -> None:
with raises_context(expected_error):
evaluate(evaluate_test_graph, outcome_node_label, initial_values)
Loading