Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion causy/causal_discovery/constraint/orientation_rules/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple, List, Optional, Generic
import itertools

from causy.edge_types import DirectedEdge
from causy.generators import AllCombinationsGenerator
from causy.interfaces import (
BaseGraphInterface,
Expand Down Expand Up @@ -254,12 +255,15 @@ def process(
if graph.only_directed_edge_exists(x, z) and graph.undirected_edge_exists(
z, y
):
# orientation conflict if rule would lead to a new collider or a cycle
for node in graph.nodes:
if graph.only_directed_edge_exists(
graph.nodes[node], y
) and not graph.edge_exists(graph.nodes[node], z):
breakflag = True
break
if graph.directed_path_exists(y, x):
breakflag = True
if breakflag is True:
return TestResult(
u=y,
Expand All @@ -280,13 +284,15 @@ def process(
if graph.only_directed_edge_exists(y, z) and graph.undirected_edge_exists(
z, x
):
# orientation conflict if rule would lead to a new collider or a cycle
for node in graph.nodes:
if graph.only_directed_edge_exists(
graph.nodes[node], x
) and not graph.edge_exists(graph.nodes[node], z):
breakflag = True
break

if graph.directed_path_exists(x, y):
breakflag = True
if breakflag is True:
return TestResult(
u=x,
Expand Down
23 changes: 19 additions & 4 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,12 @@ def retrieve_edge_history(

return [i for i in self.edge_history[(u, v)] if i.action == action]

def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool:
def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str], visited: Set[str] = None) -> bool:
"""
Check if a directed path from u to v exists
:param u: node u
:param v: node v
:param visited: set of visited nodes to avoid infinite recursion
:return: True if a directed path exists, False otherwise
"""

Expand All @@ -267,11 +268,25 @@ def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool
if isinstance(v, Node):
v = v.id

if self.directed_edge_exists(u, v):
if visited is None:
visited = set()

# If already visited, return False to avoid infinite loop
if u in visited:
return False

# Mark the node as visited
visited.add(u)

# Direct edge check
if self.edge_of_type_exists(u, v, DirectedEdge()):
return True
for w in self.edges[u]:
if self.directed_path_exists(self.nodes[w], v):

# Recursive DFS through neighbors
for w in self.edges.get(u, []): # Use .get() to avoid KeyError if u is not in self.edges
if self.edge_of_type_exists(u, w, DirectedEdge()) and self.directed_path_exists(w, v, visited):
return True

return False

def edge_of_type_exists(
Expand Down
54 changes: 54 additions & 0 deletions tests/liliths_wrapper_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Generic, List, Optional

from causy.causal_discovery.constraint.orientation_rules.pc import FurtherOrientQuadrupleTest, ColliderTest
from causy.interfaces import PipelineStepInterface, PipelineStepInterfaceType, BaseGraphInterface, TestResultInterface, \
logger


def actual_fn(a,b):
return a + b


def wrapper_fn(a,b):
result = actual_fn(a, b)
return result * -1



class CustomOrientationRule(PipelineStepInterface[PipelineStepInterfaceType],
Generic[PipelineStepInterfaceType]):

_inner_step_cls = ColliderTest

def __int__(self, *args, **kwargs):
self.inner_step = self._inner_step_cls(*args, **kwargs)
super().__init__(*args, **kwargs)

def process(self, nodes: List[str], graph: BaseGraphInterface,
unapplied_actions: Optional[List[TestResultInterface]] = None) -> Optional[TestResultInterface]:
result = self.inner_step.process(nodes, graph, unapplied_actions=unapplied_actions)

# get all unshielded triples


for proposed_action in result.all_proposed_actions:
if "separatedBy" in proposed_action.data:
return result

def __call__(
self,
nodes: List[str],
graph: BaseGraphInterface,
unapplied_actions: Optional[List[TestResultInterface]] = None,
) -> Optional[TestResultInterface]:
if self.needs_unapplied_actions and unapplied_actions is None:
logger.warn(
f"Pipeline step {self.name} needs unapplied actions but none were provided"
)
elif self.needs_unapplied_actions and unapplied_actions is not None:
return self.process(nodes, graph, unapplied_actions)

return self.process(nodes, graph)



51 changes: 11 additions & 40 deletions tests/test_effect_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
PC_ORIENTATION_RULES,
PC_EDGE_TYPES,
PC_GRAPH_UI_EXTENSION,
PC_DEFAULT_THRESHOLD,
PC_DEFAULT_THRESHOLD, PC,
)
from causy.causal_discovery.constraint.independence_tests.common import (
CorrelationCoefficientTest,
Expand Down Expand Up @@ -85,47 +85,18 @@ def test_direct_effect_estimation_trivial_case(self):
)

def test_direct_effect_estimation_basic_example(self):
PC = graph_model_factory(
Algorithm(
pipeline_steps=[
CalculatePearsonCorrelations(
display_name="Calculate Pearson Correlations"
),
CorrelationCoefficientTest(
threshold=VariableReference(name="threshold"),
display_name="Correlation Coefficient Test",
),
PartialCorrelationTest(
threshold=VariableReference(name="threshold"),
display_name="Partial Correlation Test",
),
ExtendedPartialCorrelationTestMatrix(
threshold=VariableReference(name="threshold"),
display_name="Extended Partial Correlation Test Matrix",
),
*PC_ORIENTATION_RULES,
ComputeDirectEffectsInDAGsMultivariateRegression(
display_name="Compute Direct Effects"
),
],
edge_types=PC_EDGE_TYPES,
extensions=[PC_GRAPH_UI_EXTENSION],
name="PC",
variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)],
)
)

model = IIDSampleGenerator(
edges=[
SampleEdge(NodeReference("X"), NodeReference("Z"), 5),
SampleEdge(NodeReference("Y"), NodeReference("Z"), 6),
SampleEdge(NodeReference("Z"), NodeReference("V"), 3),
SampleEdge(NodeReference("Z"), NodeReference("W"), 4),
SampleEdge(NodeReference("X"), NodeReference("Z"), 1),
SampleEdge(NodeReference("Y"), NodeReference("Z"), 1),
SampleEdge(NodeReference("Z"), NodeReference("V"), 1),
SampleEdge(NodeReference("Z"), NodeReference("W"), 1),
],
)

tst = PC()
sample_size = 1000000
sample_size = 10000
test_data, graph = model.generate(sample_size)
tst.create_graph_from_data(test_data)
tst.create_all_possible_edges()
Expand All @@ -137,30 +108,30 @@ def test_direct_effect_estimation_basic_example(self):
tst.graph.edge_value(tst.graph.nodes["X"], tst.graph.nodes["Z"])[
"direct_effect"
],
5.0,
1.0,
0,
)
self.assertAlmostEqual(
tst.graph.edge_value(tst.graph.nodes["Y"], tst.graph.nodes["Z"])[
"direct_effect"
],
6.0,
1.0,
0,
)

self.assertAlmostEqual(
tst.graph.edge_value(tst.graph.nodes["Z"], tst.graph.nodes["V"])[
"direct_effect"
],
3.0,
1.0,
0,
)

self.assertAlmostEqual(
tst.graph.edge_value(tst.graph.nodes["Z"], tst.graph.nodes["W"])[
"direct_effect"
],
4.0,
1.0,
0,
)

Expand Down Expand Up @@ -205,7 +176,7 @@ def test_direct_effect_estimation_partially_directed(self):
)

tst = PC()
sample_size = 100000
sample_size = 10000
test_data, graph = model.generate(sample_size)
tst.create_graph_from_data(test_data)
tst.create_all_possible_edges()
Expand Down
Loading
Loading