-
Notifications
You must be signed in to change notification settings - Fork 0
Add evaluate algorithm #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
52 commits
Select commit
Hold shift + click to select a range
a8f2751
rename ParameterNode to DataNode
mscroggs df6bf52
add evaluate method
mscroggs 1d0738d
mypy?
mscroggs cbb10ea
ruff
mscroggs a956b83
typeerror
mscroggs b7b3332
rufff
mscroggs abbff6b
simplify some graph code
mscroggs 43149c5
Update src/causalprog/graph/node/base.py
mscroggs 6e68264
Update src/causalprog/graph/node/base.py
mscroggs 8f64aa4
Update tests/test_algorithms/test_evaluate.py
mscroggs 752005a
Apply suggestion from @mscroggs
mscroggs 7f1db29
fix tests
mscroggs 5ed79b5
ruff
mscroggs bb1b0f3
hasattr(..., "shape") instead of type check
mscroggs 2c23e01
let KeyError happen rather than asserting types
mscroggs 2815ff8
mypy?
mscroggs 42a2844
currect type
mscroggs 2e1eb1b
type: ignore
mscroggs dae7e6b
PGH003
mscroggs 16c04df
Merge branch 'mscroggs/data-node' into mscroggs/rvnodes
mscroggs 5717c16
typing
mscroggs 209b86e
corrections
mscroggs 06d0f69
add check that labels are valid variable names
mscroggs 7700ea5
ruff
mscroggs 29f7679
update tests and labels
mscroggs ac099fe
Add random variable nodes
mscroggs 30c0216
ruff
mscroggs 13ad547
ruff
mscroggs 76f7fb1
add evaluate algorithm
mscroggs ad4ccb7
start adding example graph
mscroggs dc7b643
add ricardo's graph (!)
mscroggs ccb5f0d
add test
mscroggs 01017bc
ruff
mscroggs f19ae94
Merge branch 'main' into mscroggs/evaluate
mscroggs 9fc5383
values not keys for parents, update some tests
mscroggs c7e5e88
sort toml
mscroggs 63d0b7c
ruff
mscroggs 83a1b51
Update src/causalprog/algorithms/evaluate.py
mscroggs 6b4ec2e
correct
mscroggs 9d65fcc
move graph to graph.special and rearrange tests
mscroggs b639584
add new file
mscroggs b06b6f5
add evaluate_down_to variant
mscroggs 89fd2b2
Update src/causalprog/algorithms/__init__.py
mscroggs 18432f7
RUFF
mscroggs 6b7ba9a
Merge branch 'mscroggs/evaluate' of github.com:UCL/causalprog into ms…
mscroggs a512b2d
Parametrize evaluate tests
mscroggs 1b92a35
don't copy big arrays
mscroggs 20da685
mypy
mscroggs c1ed858
get doesn't work here as it evaluates arguments BEFORE checking if ke…
mscroggs c0c54f8
Add test of evaluate_down_to
mscroggs 5cf59ec
RUFF
mscroggs 5f4a080
type
mscroggs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,14 @@ | ||
| """Algorithms.""" | ||
|
|
||
| from .do import do | ||
| from .evaluate import evaluate, evaluate_down_to | ||
| from .moments import expectation, moment, standard_deviation | ||
|
|
||
| __all__ = ( | ||
| "do", | ||
| "evaluate", | ||
| "evaluate_down_to", | ||
| "expectation", | ||
| "moment", | ||
| "standard_deviation", | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| """Algorithms for evaluating a graph node.""" | ||
|
|
||
| import numpy.typing as npt | ||
|
|
||
| from causalprog.graph import Graph | ||
|
|
||
|
|
||
| def evaluate_down_to( | ||
| graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float] | ||
| ) -> dict[str, float | npt.NDArray[float]]: | ||
| """ | ||
| Evaluate all nodes down to a particular node. | ||
|
|
||
| Args: | ||
| graph: The graph that the node is contained in. | ||
| outcome_node_label: The label of the node to evaluate down to. | ||
| values: Values taken by nodes whose value is given | ||
|
|
||
| Returns: | ||
| A dictionary of the values of all the nodes that are ancestors of the input node | ||
|
|
||
| """ | ||
| computed_values: dict[str, float | npt.NDArray[float]] = {} | ||
| nodes_to_evaluate = [ | ||
| n | ||
| for n in graph.roots_down_to_outcome(outcome_node_label) | ||
| if n.label not in values | ||
| ] | ||
| for node in nodes_to_evaluate: | ||
| computed_values[node.label] = node.evaluate(**values, **computed_values) | ||
| return computed_values | ||
|
|
||
|
|
||
| def evaluate( | ||
| graph: Graph, outcome_node_label: str, **values: float | npt.NDArray[float] | ||
|
mscroggs marked this conversation as resolved.
|
||
| ) -> float | npt.NDArray[float]: | ||
| """ | ||
| Evaluate a node. | ||
|
|
||
| Args: | ||
| graph: The graph that the node is contained in. | ||
| outcome_node_label: The label of the node to evaluate. | ||
| values: Values taken by nodes whose value is given | ||
|
|
||
| Returns: | ||
| The evaluation of the node | ||
|
|
||
| """ | ||
| if outcome_node_label in values: | ||
| return values[outcome_node_label] | ||
| return evaluate_down_to(graph, outcome_node_label, **values)[outcome_node_label] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| """Functions to create example graphs.""" | ||
|
|
||
| import inspect | ||
| from collections.abc import Callable | ||
|
|
||
| from causalprog.graph import ( | ||
| ContinuousRandomVariableNode, | ||
| DataNode, | ||
| DiscreteRandomVariableNode, | ||
| Graph, | ||
| ) | ||
|
|
||
|
|
||
| def example_model( | ||
|
mscroggs marked this conversation as resolved.
|
||
| *, | ||
| label: str = "G", | ||
| l_len: int = 1, | ||
| z_len: int = 1, | ||
| k: int = 10, | ||
| compute_u_x: Callable, | ||
| compute_u_y: Callable, | ||
| compute_phi_x: Callable, | ||
| compute_x: Callable, | ||
| compute_y: Callable, | ||
|
willGraham01 marked this conversation as resolved.
|
||
| ) -> Graph: | ||
| """ | ||
| Create a graph representing the example model. | ||
|
|
||
| Args: | ||
| label: The label of the graph. | ||
| l_len: The number of entries in the vector data node L. | ||
| z_len: The number of entries in the vector data node Z. | ||
| k: The maximum value that could be taken by the mixture indicator C. | ||
| compute_u_x: Compute UX given the value of C. | ||
| compute_u_y: Compute UY given the value of C. | ||
| compute_phi_x: Compute PhiX given the value of L. | ||
| compute_x: Compute X given the values of Z, PhiX and UX. | ||
| compute_y: Compute Y given the values of X and UY. | ||
|
|
||
| Returns: | ||
| A graph | ||
|
|
||
| """ | ||
| for p in inspect.signature(compute_u_x).parameters: | ||
| if p != "C": | ||
| msg = f"Invalid input to UX: {p}" | ||
| raise ValueError(msg) | ||
| for p in inspect.signature(compute_u_y).parameters: | ||
| if p != "C": | ||
| msg = f"Invalid input to UX: {p}" | ||
| raise ValueError(msg) | ||
| for p in inspect.signature(compute_phi_x).parameters: | ||
| if p != "L": | ||
| msg = f"Invalid input to PhiX: {p}" | ||
| raise ValueError(msg) | ||
| for p in inspect.signature(compute_x).parameters: | ||
| if p not in ["Z", "PhiX", "UX"]: | ||
| msg = f"Invalid input to X: {p}" | ||
| raise ValueError(msg) | ||
| for p in inspect.signature(compute_y).parameters: | ||
| if p not in ["X", "UY"]: | ||
| msg = f"Invalid input to X: {p}" | ||
| raise ValueError(msg) | ||
|
mscroggs marked this conversation as resolved.
|
||
|
|
||
| graph = Graph(label=label) | ||
|
|
||
| graph.add_node(DataNode(label="L", shape=(l_len,))) | ||
| graph.add_node(DataNode(label="Z", shape=(z_len,))) | ||
| graph.add_node( | ||
| DiscreteRandomVariableNode( | ||
| label="C", values=[float(i) for i in range(1, k + 1)] | ||
| ) | ||
| ) | ||
| graph.add_node(ContinuousRandomVariableNode(label="UX", compute=compute_u_x)) | ||
| graph.add_node(ContinuousRandomVariableNode(label="UY", compute=compute_u_y)) | ||
| graph.add_node(ContinuousRandomVariableNode(label="PhiX", compute=compute_phi_x)) | ||
| graph.add_node(ContinuousRandomVariableNode(label="X", compute=compute_x)) | ||
| graph.add_node(ContinuousRandomVariableNode(label="Y", compute=compute_y)) | ||
|
|
||
| return graph | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,48 +1,109 @@ | ||
| """Tests for evaluate algorithms.""" | ||
|
|
||
| import numpy as np | ||
| import jax.numpy as jnp | ||
| import pytest | ||
| from jax import Array | ||
|
|
||
| from causalprog.graph import ComponentNode, DataNode, DistributionNode | ||
| from causalprog.algorithms import evaluate, evaluate_down_to | ||
| from causalprog.graph import Graph | ||
| from causalprog.graph.special import example_model | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def evaluate_test_graph() -> Graph: | ||
| return example_model( | ||
| z_len=2, | ||
| compute_u_x=lambda C: C, | ||
| compute_u_y=lambda C: C + 1, | ||
| compute_phi_x=lambda L: L[0], | ||
| compute_x=lambda Z, PhiX, UX: Z[0] + UX - PhiX, | ||
| compute_y=lambda X, UY: X * UY, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("node", "kwargs_to_evaluate", "expected_result"), | ||
| ("outcome_node_label", "initial_values", "expected_result"), | ||
| [ | ||
| pytest.param( | ||
| DataNode(label="A"), {"A": 2.0}, 2.0, id="Evaluate DataNode itself" | ||
| "L", | ||
| {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, | ||
| {}, | ||
| id="DataNode evaluation w/ excess information provided", | ||
| ), | ||
| pytest.param( | ||
| ComponentNode("Parent", 1, label="Child"), | ||
| {"Parent": np.arange(4)}, | ||
| 1.0, | ||
| id="Evaluate ComponentNode, given parent", | ||
| "Z", | ||
| {"Z": jnp.array([2.0, 0.0])}, | ||
| {}, | ||
| id="DataNode evaluation", | ||
| ), | ||
| ], | ||
| ) | ||
| def test_evaluate_node(node, kwargs_to_evaluate, expected_result): | ||
| assert np.allclose(node.evaluate(**kwargs_to_evaluate), expected_result) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("node", "kwargs_for_evaluate", "expected_error"), | ||
| [ | ||
| pytest.param( | ||
| DataNode(label="A"), | ||
| "C", | ||
| {"C": 4.0}, | ||
| {}, | ||
| ValueError("Missing input for node: A"), | ||
| id="DataNode missing input value", | ||
| id="DiscreteRVNode evaluation", | ||
| ), | ||
| pytest.param( | ||
| "UX", | ||
| {"C": 4.0}, | ||
| {"UX": 4.0}, | ||
| id="CtsRVNode evaluation", | ||
| ), | ||
| pytest.param( | ||
| DistributionNode(distribution=None, label="A"), | ||
| "UX", | ||
| {"C": 4.0, "UX": 1.0}, | ||
| {}, | ||
| RuntimeError("Cannot evaluate a DistributionNode"), | ||
| id="Attempt to evaluate DistributionNode", | ||
| id="CtsRVNode evaluation, 'given that' overrides computed value", | ||
| ), | ||
| pytest.param( | ||
| "UY", | ||
| {"C": 4.0}, | ||
| {"UY": 5.0}, | ||
| id="CtsRVNode evaluation, with parents that need evaluating", | ||
| ), | ||
| pytest.param( | ||
| "X", | ||
| {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, | ||
| {"UX": 4.0, "PhiX": 5.5, "X": 0.5}, | ||
| id="Multiple paths from different root nodes", | ||
| ), | ||
| pytest.param( | ||
| "X", | ||
| {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0, "PhiX": 0.0}, | ||
| {"UX": 4.0, "X": 6.0}, | ||
| id="Multiple paths from different root nodes, with some given values", | ||
| ), | ||
| pytest.param( | ||
| "Y", | ||
| {"L": jnp.array([5.5]), "Z": jnp.array([2.0, 0.0]), "C": 4.0}, | ||
| {"UX": 4.0, "UY": 5.0, "PhiX": 5.5, "X": 0.5, "Y": 2.5}, | ||
| id="Evaluating the 'outcome' node.", | ||
| ), | ||
| ], | ||
| ) | ||
| def test_evaluate_node_fail_on_missing_data( | ||
| node, kwargs_for_evaluate, expected_error, raises_context | ||
| ): | ||
| with raises_context(expected_error): | ||
| node.evaluate(**kwargs_for_evaluate) | ||
| def test_evaluate( | ||
| evaluate_test_graph: Graph, | ||
| outcome_node_label: str, | ||
| initial_values: dict[str, Array], | ||
| expected_result: dict[str, Array], | ||
| ) -> None: | ||
| computed_result = evaluate_down_to( | ||
| evaluate_test_graph, outcome_node_label, **initial_values | ||
| ) | ||
|
|
||
| # Same number of entries | ||
| assert len(expected_result) == len(computed_result) | ||
| # Keys are correct | ||
| assert set(expected_result.keys()) == set(computed_result.keys()) | ||
|
|
||
| # All entries match to acceptable precision for floats | ||
| for node_label, computed_value in computed_result.items(): | ||
| assert jnp.allclose(computed_value, expected_result[node_label]) | ||
|
|
||
| # Just asking for one value did indeed extract the correct node value | ||
| computed_result_single = evaluate( | ||
| evaluate_test_graph, outcome_node_label, **initial_values | ||
| ) | ||
| if outcome_node_label in initial_values: | ||
| assert jnp.allclose(computed_result_single, initial_values[outcome_node_label]) | ||
| else: | ||
| assert jnp.allclose(computed_result_single, computed_result[outcome_node_label]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.