From 5274228f04aa0d0dcccacae8d853affb46e7f6e5 Mon Sep 17 00:00:00 2001 From: Sofia Faltenbacher Date: Sun, 9 Feb 2025 13:53:05 +0100 Subject: [PATCH 1/3] feat(NonColliderTest): throw orientation conflict if it would lead to a cycle --- .../constraint/orientation_rules/pc.py | 8 +- causy/graph.py | 23 +++++- tests/test_effect_estimation.py | 2 +- tests/test_graph.py | 81 +++++++++++++++++++ 4 files changed, 108 insertions(+), 6 deletions(-) diff --git a/causy/causal_discovery/constraint/orientation_rules/pc.py b/causy/causal_discovery/constraint/orientation_rules/pc.py index aa0e1c1..c3e7598 100644 --- a/causy/causal_discovery/constraint/orientation_rules/pc.py +++ b/causy/causal_discovery/constraint/orientation_rules/pc.py @@ -3,6 +3,7 @@ from typing import Tuple, List, Optional, Generic import itertools +from causy.edge_types import DirectedEdge from causy.generators import AllCombinationsGenerator from causy.interfaces import ( BaseGraphInterface, @@ -254,12 +255,15 @@ def process( if graph.only_directed_edge_exists(x, z) and graph.undirected_edge_exists( z, y ): + # orientation conflict if rule would lead to a new collider or a cycle for node in graph.nodes: if graph.only_directed_edge_exists( graph.nodes[node], y ) and not graph.edge_exists(graph.nodes[node], z): breakflag = True break + if graph.directed_path_exists(y, x): + breakflag = True if breakflag is True: return TestResult( u=y, @@ -280,13 +284,15 @@ def process( if graph.only_directed_edge_exists(y, z) and graph.undirected_edge_exists( z, x ): + # orientation conflict if rule would lead to a new collider or a cycle for node in graph.nodes: if graph.only_directed_edge_exists( graph.nodes[node], x ) and not graph.edge_exists(graph.nodes[node], z): breakflag = True break - + if graph.directed_path_exists(x, y): + breakflag = True if breakflag is True: return TestResult( u=x, diff --git a/causy/graph.py b/causy/graph.py index 9e22e05..7c1ad29 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -254,11 
+254,12 @@ def retrieve_edge_history( return [i for i in self.edge_history[(u, v)] if i.action == action] - def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool: + def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str], visited: Set[str] = None) -> bool: """ Check if a directed path from u to v exists :param u: node u :param v: node v + :param visited: set of visited nodes to avoid infinite recursion :return: True if a directed path exists, False otherwise """ @@ -267,11 +268,25 @@ def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool if isinstance(v, Node): v = v.id - if self.directed_edge_exists(u, v): + if visited is None: + visited = set() + + # If already visited, return False to avoid infinite loop + if u in visited: + return False + + # Mark the node as visited + visited.add(u) + + # Direct edge check + if self.edge_of_type_exists(u, v, DirectedEdge()): return True - for w in self.edges[u]: - if self.directed_path_exists(self.nodes[w], v): + + # Recursive DFS through neighbors + for w in self.edges.get(u, []): # Use .get() to avoid KeyError if u is not in self.edges + if self.directed_path_exists(w, v, visited): return True + return False def edge_of_type_exists( diff --git a/tests/test_effect_estimation.py b/tests/test_effect_estimation.py index 77c99fb..3798a32 100644 --- a/tests/test_effect_estimation.py +++ b/tests/test_effect_estimation.py @@ -205,7 +205,7 @@ def test_direct_effect_estimation_partially_directed(self): ) tst = PC() - sample_size = 100000 + sample_size = 10000 test_data, graph = model.generate(sample_size) tst.create_graph_from_data(test_data) tst.create_all_possible_edges() diff --git a/tests/test_graph.py b/tests/test_graph.py index 63b55f8..8e2fcc6 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -678,3 +678,84 @@ def test_are_nodes_d_separated_cpdag_three_nodes_fully_connected_undirected_fals graph.add_edge(node1, node3, {"test": "test"}) 
self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [])) self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node2])) + + def test_directed_path_exists(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + self.assertTrue(graph.directed_path_exists(node1, node2)) + self.assertFalse(graph.directed_path_exists(node2, node1)) + + def test_directed_path_exists_cycle(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node3, node1, {"test": "test"}) + self.assertTrue(graph.directed_path_exists(node1, node2)) + self.assertTrue(graph.directed_path_exists(node2, node1)) + + def test_directed_path_mediated_path(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + self.assertTrue(graph.directed_path_exists(node1, node4)) + self.assertFalse(graph.directed_path_exists(node4, node1)) + self.assertTrue(graph.directed_path_exists(node1, node3)) + self.assertFalse(graph.directed_path_exists(node3, node1)) + self.assertTrue(graph.directed_path_exists(node2, node4)) + self.assertFalse(graph.directed_path_exists(node4, node2)) + self.assertTrue(graph.directed_path_exists(node2, node3)) + self.assertFalse(graph.directed_path_exists(node3, node2)) + + def test_directed_path_mediated_path_several_mediated_paths(self): + graph = GraphManager() + node1 = 
graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + node5 = graph.add_node("test5", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + graph.add_directed_edge(node4, node5, {"test": "test"}) + graph.add_directed_edge(node1, node5, {"test": "test"}) + graph.add_directed_edge(node1, node4, {"test": "test"}) + graph.add_directed_edge(node2, node5, {"test": "test"}) + + self.assertTrue(graph.directed_path_exists(node1, node4)) + self.assertFalse(graph.directed_path_exists(node4, node1)) + self.assertTrue(graph.directed_path_exists(node1, node3)) + self.assertFalse(graph.directed_path_exists(node3, node1)) + self.assertTrue(graph.directed_path_exists(node2, node4)) + self.assertFalse(graph.directed_path_exists(node4, node2)) + self.assertTrue(graph.directed_path_exists(node2, node3)) + self.assertFalse(graph.directed_path_exists(node3, node2)) + self.assertTrue(graph.directed_path_exists(node1, node4)) + self.assertFalse(graph.directed_path_exists(node4, node1)) + + def test_directed_path_mediated_path_undirected_edges(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_edge(node3, node4, {"test": "test"}) + self.assertFalse(graph.directed_path_exists(node1, node4)) + self.assertFalse(graph.directed_path_exists(node4, node1)) + self.assertFalse(graph.directed_path_exists(node1, node3)) + self.assertFalse(graph.directed_path_exists(node3, node1)) + self.assertFalse(graph.directed_path_exists(node2, node4)) + 
self.assertFalse(graph.directed_path_exists(node4, node2)) + self.assertFalse(graph.directed_path_exists(node2, node3)) + self.assertFalse(graph.directed_path_exists(node3, node2)) From d70de26a3446940c8760209d4e518cb850ed5fdc Mon Sep 17 00:00:00 2001 From: Sofia Faltenbacher Date: Sun, 9 Feb 2025 14:40:17 +0100 Subject: [PATCH 2/3] refactor(graph): directed_path_exists --- causy/graph.py | 2 +- tests/test_effect_estimation.py | 49 ++++-------------- tests/test_graph.py | 88 +++++++++++++++++++++++++++++++-- tests/test_orientation_tests.py | 72 +++++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 43 deletions(-) diff --git a/causy/graph.py b/causy/graph.py index 7c1ad29..80701d7 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -284,7 +284,7 @@ def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str], visited # Recursive DFS through neighbors for w in self.edges.get(u, []): # Use .get() to avoid KeyError if u is not in self.edges - if self.directed_path_exists(w, v, visited): + if self.edge_of_type_exists(u, w, DirectedEdge()) and self.directed_path_exists(w, v, visited): return True return False diff --git a/tests/test_effect_estimation.py b/tests/test_effect_estimation.py index 3798a32..d2c3580 100644 --- a/tests/test_effect_estimation.py +++ b/tests/test_effect_estimation.py @@ -2,7 +2,7 @@ PC_ORIENTATION_RULES, PC_EDGE_TYPES, PC_GRAPH_UI_EXTENSION, - PC_DEFAULT_THRESHOLD, + PC_DEFAULT_THRESHOLD, PC, ) from causy.causal_discovery.constraint.independence_tests.common import ( CorrelationCoefficientTest, @@ -85,47 +85,18 @@ def test_direct_effect_estimation_trivial_case(self): ) def test_direct_effect_estimation_basic_example(self): - PC = graph_model_factory( - Algorithm( - pipeline_steps=[ - CalculatePearsonCorrelations( - display_name="Calculate Pearson Correlations" - ), - CorrelationCoefficientTest( - threshold=VariableReference(name="threshold"), - display_name="Correlation Coefficient Test", - ), - PartialCorrelationTest( - 
threshold=VariableReference(name="threshold"), - display_name="Partial Correlation Test", - ), - ExtendedPartialCorrelationTestMatrix( - threshold=VariableReference(name="threshold"), - display_name="Extended Partial Correlation Test Matrix", - ), - *PC_ORIENTATION_RULES, - ComputeDirectEffectsInDAGsMultivariateRegression( - display_name="Compute Direct Effects" - ), - ], - edge_types=PC_EDGE_TYPES, - extensions=[PC_GRAPH_UI_EXTENSION], - name="PC", - variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], - ) - ) model = IIDSampleGenerator( edges=[ - SampleEdge(NodeReference("X"), NodeReference("Z"), 5), - SampleEdge(NodeReference("Y"), NodeReference("Z"), 6), - SampleEdge(NodeReference("Z"), NodeReference("V"), 3), - SampleEdge(NodeReference("Z"), NodeReference("W"), 4), + SampleEdge(NodeReference("X"), NodeReference("Z"), 1), + SampleEdge(NodeReference("Y"), NodeReference("Z"), 1), + SampleEdge(NodeReference("Z"), NodeReference("V"), 1), + SampleEdge(NodeReference("Z"), NodeReference("W"), 1), ], ) tst = PC() - sample_size = 1000000 + sample_size = 10000 test_data, graph = model.generate(sample_size) tst.create_graph_from_data(test_data) tst.create_all_possible_edges() @@ -137,14 +108,14 @@ def test_direct_effect_estimation_basic_example(self): tst.graph.edge_value(tst.graph.nodes["X"], tst.graph.nodes["Z"])[ "direct_effect" ], - 5.0, + 1.0, 0, ) self.assertAlmostEqual( tst.graph.edge_value(tst.graph.nodes["Y"], tst.graph.nodes["Z"])[ "direct_effect" ], - 6.0, + 1.0, 0, ) @@ -152,7 +123,7 @@ def test_direct_effect_estimation_basic_example(self): tst.graph.edge_value(tst.graph.nodes["Z"], tst.graph.nodes["V"])[ "direct_effect" ], - 3.0, + 1.0, 0, ) @@ -160,7 +131,7 @@ def test_direct_effect_estimation_basic_example(self): tst.graph.edge_value(tst.graph.nodes["Z"], tst.graph.nodes["W"])[ "direct_effect" ], - 4.0, + 1.0, 0, ) diff --git a/tests/test_graph.py b/tests/test_graph.py index 8e2fcc6..82daea4 100644 --- a/tests/test_graph.py +++ 
b/tests/test_graph.py @@ -697,8 +697,12 @@ def test_directed_path_exists_cycle(self): graph.add_directed_edge(node3, node1, {"test": "test"}) self.assertTrue(graph.directed_path_exists(node1, node2)) self.assertTrue(graph.directed_path_exists(node2, node1)) + self.assertTrue(graph.directed_path_exists(node1, node3)) + self.assertTrue(graph.directed_path_exists(node3, node1)) + self.assertTrue(graph.directed_path_exists(node2, node3)) + self.assertTrue(graph.directed_path_exists(node3, node2)) - def test_directed_path_mediated_path(self): + def test_directed_path_exists_mediated_path(self): graph = GraphManager() node1 = graph.add_node("test1", [1, 2, 3]) node2 = graph.add_node("test2", [1, 2, 3]) @@ -716,7 +720,7 @@ def test_directed_path_mediated_path(self): self.assertTrue(graph.directed_path_exists(node2, node3)) self.assertFalse(graph.directed_path_exists(node3, node2)) - def test_directed_path_mediated_path_several_mediated_paths(self): + def test_directed_path_exists_mediated_path_several_mediated_paths(self): graph = GraphManager() node1 = graph.add_node("test1", [1, 2, 3]) node2 = graph.add_node("test2", [1, 2, 3]) @@ -742,7 +746,7 @@ def test_directed_path_mediated_path_several_mediated_paths(self): self.assertTrue(graph.directed_path_exists(node1, node4)) self.assertFalse(graph.directed_path_exists(node4, node1)) - def test_directed_path_mediated_path_undirected_edges(self): + def test_directed_path_exists_mediated_path_undirected_edges(self): graph = GraphManager() node1 = graph.add_node("test1", [1, 2, 3]) node2 = graph.add_node("test2", [1, 2, 3]) @@ -759,3 +763,81 @@ def test_directed_path_mediated_path_undirected_edges(self): self.assertFalse(graph.directed_path_exists(node4, node2)) self.assertFalse(graph.directed_path_exists(node2, node3)) self.assertFalse(graph.directed_path_exists(node3, node2)) + + def test_directed_path_exists_mediated_path_undirected_and_directed_edges(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 
3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + + self.assertFalse(graph.directed_path_exists(node1, node3)) + + def test_directed_path_exists_mediated_path_undirected_and_directed_edges_2(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + + self.assertFalse(graph.directed_path_exists(node1, node3)) + + def test_directed_path_exists_mediated_path_undirected_and_directed_edges_3(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + node5 = graph.add_node("test5", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + graph.add_directed_edge(node4, node5, {"test": "test"}) + + self.assertFalse(graph.directed_path_exists(node1, node5)) + self.assertFalse(graph.directed_path_exists(node1, node4)) + self.assertFalse(graph.directed_path_exists(node1, node3)) + self.assertTrue(graph.directed_path_exists(node1, node2)) + self.assertTrue(graph.directed_path_exists(node3, node5)) + + def test_directed_path_exists_mediated_path_undirected_and_directed_edges_4(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + node5 = graph.add_node("test5", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) 
+ graph.add_edge(node3, node4, {"test": "test"}) + graph.add_directed_edge(node4, node5, {"test": "test"}) + + self.assertFalse(graph.directed_path_exists(node1, node5)) + self.assertFalse(graph.directed_path_exists(node1, node4)) + self.assertFalse(graph.directed_path_exists(node2, node5)) + self.assertTrue(graph.directed_path_exists(node1, node2)) + self.assertTrue(graph.directed_path_exists(node1, node3)) + + def test_directed_path_exists_mediated_paths_undirected_and_directed_edges(self): + graph = GraphManager() + X = graph.add_node("X", [1, 2, 3]) + Y = graph.add_node("Y", [1, 2, 3]) + Z = graph.add_node("Z", [1, 2, 3]) + W = graph.add_node("W", [1, 2, 3]) + V = graph.add_node("V", [1, 2, 3]) + + graph.add_directed_edge(X, Y, {"test": "test"}) + graph.add_edge(Y, Z, {"test": "test"}) + graph.add_edge(Z, W, {"test": "test"}) + graph.add_directed_edge(W, X, {"test": "test"}) + graph.add_directed_edge(V, X, {"test": "test"}) + + self.assertTrue(graph.directed_path_exists(V, X)) + self.assertTrue(graph.directed_path_exists(V, Y)) + self.assertFalse(graph.directed_path_exists(V, Z)) + self.assertFalse(graph.directed_path_exists(Z, X)) + + + diff --git a/tests/test_orientation_tests.py b/tests/test_orientation_tests.py index f65a277..5207e13 100644 --- a/tests/test_orientation_tests.py +++ b/tests/test_orientation_tests.py @@ -788,3 +788,75 @@ def test_further_orient_quadruple_test(self): self.assertFalse(model.graph.undirected_edge_exists(x, z)) self.assertFalse(model.graph.only_directed_edge_exists(z, x)) self.assertTrue(model.graph.only_directed_edge_exists(x, z)) + + def test_avoid_cycles_four_nodes(self): + pipeline = [Loop( + pipeline_steps=[ + NonColliderTest(), + FurtherOrientTripleTest(), + OrientQuadrupleTest(), + FurtherOrientQuadrupleTest(), + ], + exit_condition=ExitOnNoActions(), + ),] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[DirectedEdge(), UndirectedEdge()], + name="TestNonColliderAvoidCycles", + ) + )() + 
model.graph = GraphManager() + x = model.graph.add_node("X", []) + y = model.graph.add_node("Y", []) + z = model.graph.add_node("Z", []) + w = model.graph.add_node("W", []) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(y, z, {}) + model.graph.add_directed_edge(z, x, {}) + model.graph.add_directed_edge(w, x, {}) + model.execute_pipeline_steps() + # sanity check + self.assertTrue(model.graph.edge_of_type_exists(z, x, DirectedEdge())) + self.assertTrue(model.graph.edge_of_type_exists(w, x, DirectedEdge())) + + self.assertTrue(model.graph.edge_of_type_exists(x, y, DirectedEdge())) + self.assertTrue(model.graph.edge_of_type_exists(z, y, DirectedEdge())) + + + def test_avoid_cycles_five_nodes(self): + pipeline = [Loop( + pipeline_steps=[ + NonColliderTest(), + FurtherOrientTripleTest(), + OrientQuadrupleTest(), + FurtherOrientQuadrupleTest(), + ], + exit_condition=ExitOnNoActions(), + ),] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[DirectedEdge(), UndirectedEdge()], + name="TestNonColliderAvoidCycles", + ) + )() + model.graph = GraphManager() + x = model.graph.add_node("X", []) + y = model.graph.add_node("Y", []) + z = model.graph.add_node("Z", []) + w = model.graph.add_node("W", []) + v = model.graph.add_node("V", []) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(y, z, {}) + model.graph.add_edge(z, w, {}) + model.graph.add_directed_edge(w, x, {}) + model.graph.add_directed_edge(v, x, {}) + model.execute_pipeline_steps() + # sanity check + self.assertTrue(model.graph.edge_of_type_exists(w, x, DirectedEdge())) + self.assertTrue(model.graph.edge_of_type_exists(v, x, DirectedEdge())) + + self.assertTrue(model.graph.edge_of_type_exists(x, y, DirectedEdge())) + self.assertTrue(model.graph.edge_of_type_exists(y, z, DirectedEdge())) + self.assertFalse(model.graph.edge_of_type_exists(z, w, DirectedEdge())) \ No newline at end of file From cb581b18f33d3221e5764a0e0ee89f00c8d6a3c7 Mon Sep 17 00:00:00 2001 From: Sofia 
Faltenbacher Date: Fri, 28 Mar 2025 23:01:28 +0100 Subject: [PATCH 3/3] wip --- tests/liliths_wrapper_demo.py | 54 ++++++++++++++++++++ tests/test_graph.py | 14 ++++++ tests/test_pc_e2e.py | 95 ++++++++++++++++++++++++++++++++++- 3 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 tests/liliths_wrapper_demo.py diff --git a/tests/liliths_wrapper_demo.py b/tests/liliths_wrapper_demo.py new file mode 100644 index 0000000..cb22f2a --- /dev/null +++ b/tests/liliths_wrapper_demo.py @@ -0,0 +1,54 @@ +from typing import Generic, List, Optional + +from causy.causal_discovery.constraint.orientation_rules.pc import FurtherOrientQuadrupleTest, ColliderTest +from causy.interfaces import PipelineStepInterface, PipelineStepInterfaceType, BaseGraphInterface, TestResultInterface, \ + logger + + +def actual_fn(a,b): + return a + b + + +def wrapper_fn(a,b): + result = actual_fn(a, b) + return result * -1 + + + +class CustomOrientationRule(PipelineStepInterface[PipelineStepInterfaceType], + Generic[PipelineStepInterfaceType]): + + _inner_step_cls = ColliderTest + + def __init__(self, *args, **kwargs): + self.inner_step = self._inner_step_cls(*args, **kwargs) + super().__init__(*args, **kwargs) + + def process(self, nodes: List[str], graph: BaseGraphInterface, + unapplied_actions: Optional[List[TestResultInterface]] = None) -> Optional[TestResultInterface]: + result = self.inner_step.process(nodes, graph, unapplied_actions=unapplied_actions) + + # get all unshielded triples + + + for proposed_action in result.all_proposed_actions: + if "separatedBy" in proposed_action.data: + return result + + def __call__( + self, + nodes: List[str], + graph: BaseGraphInterface, + unapplied_actions: Optional[List[TestResultInterface]] = None, + ) -> Optional[TestResultInterface]: + if self.needs_unapplied_actions and unapplied_actions is None: + logger.warn( + f"Pipeline step {self.name} needs unapplied actions but none were provided" + ) + elif self.needs_unapplied_actions and
unapplied_actions is not None: + return self.process(nodes, graph, unapplied_actions) + + return self.process(nodes, graph) + + + diff --git a/tests/test_graph.py b/tests/test_graph.py index 82daea4..143ae7d 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -666,6 +666,20 @@ def test_are_nodes_d_separated_cpdag_four_nodes_with_colliders(self): self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node2])) self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node3, [])) + def test_are_nodes_d_separated_cpdag_condition_on_child_of_collider(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + graph.add_directed_edge(node2, node4, {"test": "test"}) + self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node3, [])) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node2])) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node4])) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node2, node4])) + def test_are_nodes_d_separated_cpdag_three_nodes_fully_connected_undirected_false( self, ): diff --git a/tests/test_pc_e2e.py b/tests/test_pc_e2e.py index 83acf52..fadb57c 100644 --- a/tests/test_pc_e2e.py +++ b/tests/test_pc_e2e.py @@ -80,7 +80,100 @@ def test_pc_e2e_auto_mpg(self): variables=[FloatVariable(name="threshold", value=0.05)], ) ) - pc = PC_LOCAL() + pc = PC() + pc.create_graph_from_data(auto_mpg_data_set) + pc.create_all_possible_edges() + pc.execute_pipeline_steps() + + # skeleton + self.assertEqual(pc.graph.edge_exists("mpg", "weight"), True) + self.assertEqual(pc.graph.edge_exists("mpg", "horsepower"), True) + self.assertEqual(pc.graph.edge_exists("weight", "displacement"), True) + 
self.assertEqual(pc.graph.edge_exists("weight", "horsepower"), True) + self.assertEqual(pc.graph.edge_exists("displacement", "cylinders"), True) + self.assertEqual(pc.graph.edge_exists("displacement", "acceleration"), True) + self.assertEqual(pc.graph.edge_exists("displacement", "horsepower"), True) + self.assertEqual(pc.graph.edge_exists("horsepower", "acceleration"), True) + + # assert all other edges are not present + self.assertEqual(pc.graph.edge_exists("mpg", "displacement"), False) + self.assertEqual(pc.graph.edge_exists("mpg", "cylinders"), False) + self.assertEqual(pc.graph.edge_exists("mpg", "acceleration"), False) + self.assertEqual(pc.graph.edge_exists("weight", "cylinders"), False) + self.assertEqual(pc.graph.edge_exists("weight", "acceleration"), False) + self.assertEqual(pc.graph.edge_exists("acceleration", "cylinders"), False) + self.assertEqual(pc.graph.edge_exists("horsepower", "cylinders"), False) + + # directions + self.assertEqual( + pc.graph.edge_of_type_exists("mpg", "weight", UndirectedEdge()), True + ) + self.assertEqual( + pc.graph.edge_of_type_exists("weight", "horsepower", DirectedEdge()), True + ) + self.assertEqual( + pc.graph.edge_of_type_exists("weight", "displacement", DirectedEdge()), True + ) + self.assertEqual( + pc.graph.edge_of_type_exists("mpg", "horsepower", DirectedEdge()), True + ) + self.assertEqual( + pc.graph.edge_of_type_exists("acceleration", "horsepower", DirectedEdge()), + True, + ) + self.assertEqual( + pc.graph.edge_of_type_exists( + "acceleration", "displacement", DirectedEdge() + ), + True, + ) + self.assertEqual( + pc.graph.edge_of_type_exists("displacement", "cylinders", DirectedEdge()), + True, + ) + # due to order-dependency, two cases are possible + self.assertEqual( + pc.graph.edge_of_type_exists("horsepower", "displacement", DirectedEdge()) or pc.graph.edge_of_type_exists("displacement", "horsepower", DirectedEdge()), + True, + ) + + def test_pc_classic_e2e_auto_mpg(self): + script_dir = 
os.path.dirname(os.path.abspath(__file__)) + folder_auto_mpg = os.path.join(script_dir, "fixtures/auto_mpg/") + with open(f"{folder_auto_mpg}auto_mpg.json", "r") as f: + auto_mpg_data_set = json.load(f) + PC_CLASSIC_LOCAL = graph_model_factory( + Algorithm( + pipeline_steps=[ + CalculatePearsonCorrelations( + display_name="Calculate Pearson Correlations" + ), + CorrelationCoefficientTest( + threshold=VariableReference(name="threshold"), + display_name="Correlation Coefficient Test", + ), + ExtendedPartialCorrelationTestMatrix( + threshold=VariableReference(name="threshold"), + display_name="Extended Partial Correlation Test Matrix", + generator=PairsWithNeighboursGenerator( + comparison_settings=ComparisonSettings( + min=3, max=AS_MANY_AS_FIELDS + ), + shuffle_combinations=False, + ), + ), + *PC_ORIENTATION_RULES, + ComputeDirectEffectsInDAGsMultivariateRegression( + display_name="Compute Direct Effects in DAGs Multivariate Regression" + ), + ], + edge_types=PC_EDGE_TYPES, + extensions=[PC_GRAPH_UI_EXTENSION], + name="PC", + variables=[FloatVariable(name="threshold", value=0.05)], + ) + ) + pc = PCClassic() pc.create_graph_from_data(auto_mpg_data_set) pc.create_all_possible_edges() pc.execute_pipeline_steps()