diff --git a/causy/causal_discovery/constraint/algorithms/pc.py b/causy/causal_discovery/constraint/algorithms/pc.py index f454484..023a99a 100644 --- a/causy/causal_discovery/constraint/algorithms/pc.py +++ b/causy/causal_discovery/constraint/algorithms/pc.py @@ -1,3 +1,7 @@ +from causy.causal_discovery.constraint.independence_tests.conditional_independence_calculations import ( + PearsonStudentsTTest, + FishersZTest, +) from causy.causal_effect_estimation.multivariate_regression import ( ComputeDirectEffectsMultivariateRegression, ) @@ -19,6 +23,7 @@ ) from causy.common_pipeline_steps.calculation import ( CalculatePearsonCorrelations, + CalculateEdgeCorrelations, ) from causy.interfaces import AS_MANY_AS_FIELDS from causy.models import ComparisonSettings, Algorithm @@ -33,6 +38,8 @@ FloatVariable, VariableReference, IntegerVariable, + CausyObjectParameter, + CausyObjectVariable, ) PC_DEFAULT_THRESHOLD = 0.005 @@ -63,18 +70,32 @@ PC = graph_model_factory( Algorithm( pipeline_steps=[ - CalculatePearsonCorrelations(display_name="Calculate Pearson Correlations"), + CalculateEdgeCorrelations( + display_name="Calculate Edge Correlations", + conditional_independence_test=VariableReference( + name="conditional_independence_test" + ), + ), CorrelationCoefficientTest( threshold=VariableReference(name="threshold"), display_name="Correlation Coefficient Test", + conditional_independence_test=VariableReference( + name="conditional_independence_test" + ), ), PartialCorrelationTest( threshold=VariableReference(name="threshold"), display_name="Partial Correlation Test", + conditional_independence_test=VariableReference( + name="conditional_independence_test" + ), ), ExtendedPartialCorrelationTestMatrix( threshold=VariableReference(name="threshold"), display_name="Extended Partial Correlation Test Matrix", + conditional_independence_test=VariableReference( + name="conditional_independence_test" + ), ), *PC_ORIENTATION_RULES, ComputeDirectEffectsMultivariateRegression( @@ -84,7 +105,12 @@ edge_types=PC_EDGE_TYPES, extensions=[PC_GRAPH_UI_EXTENSION], name="PC", - variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], + variables=[ + FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD), + CausyObjectVariable( + name="conditional_independence_test", value=PearsonStudentsTTest() + ), + ], ) ) diff --git a/causy/causal_discovery/constraint/independence_tests/common.py b/causy/causal_discovery/constraint/independence_tests/common.py index 55101ee..df2dd9e 100644 --- a/causy/causal_discovery/constraint/independence_tests/common.py +++ b/causy/causal_discovery/constraint/independence_tests/common.py @@ -4,6 +4,11 @@ import torch +from causy.causal_discovery.constraint.independence_tests.conditional_independence_calculations import ( + FishersZTest, + PearsonStudentsTTest, + ConditionalIndependenceTestInterface, +) from causy.generators import AllCombinationsGenerator, PairsWithNeighboursGenerator from causy.math_utils import get_t_and_critical_t from causy.interfaces import ( @@ -15,7 +20,7 @@ PipelineStepInterfaceType, ) from causy.models import ComparisonSettings, TestResultAction, TestResult -from causy.variables import IntegerParameter, BoolParameter +from causy.variables import IntegerParameter, BoolParameter, CausyObjectParameter logger = logging.getLogger(__name__) @@ -28,6 +33,7 @@ class CorrelationCoefficientTest( ) chunk_size_parallel_processing: IntegerParameter = 1 parallel: BoolParameter = False + conditional_independence_test: CausyObjectParameter = PearsonStudentsTTest() def process( self, nodes: List[str], graph: BaseGraphInterface @@ -40,14 +46,9 @@ def process( x = graph.nodes[nodes[0]] y = graph.nodes[nodes[1]] - # make t test for independency of u and v - sample_size = len(x.values) - nb_of_control_vars = 0 - corr = graph.edge_value(x, y)["correlation"] - t, critical_t = get_t_and_critical_t( - sample_size, nb_of_control_vars, corr, self.threshold - ) - if abs(t) < critical_t: + if self.conditional_independence_test.test( + graph, x.name, y.name, [], self.threshold + ): logger.debug(f"Nodes {x.name} and {y.name} are uncorrelated") return TestResult( u=x, @@ -65,6 +66,7 @@ class PartialCorrelationTest( ) chunk_size_parallel_processing: IntegerParameter = 1 parallel: BoolParameter = False + conditional_independence_test: CausyObjectParameter = PearsonStudentsTTest() def process( self, nodes: Tuple[str], graph: BaseGraphInterface @@ -92,30 +94,9 @@ def process( if not graph.edge_exists(x, y) or (y, x) in already_deleted_edges: continue - try: - cor_xy = graph.edge_value(x, y)["correlation"] - cor_xz = graph.edge_value(x, z)["correlation"] - cor_yz = graph.edge_value(y, z)["correlation"] - except (KeyError, TypeError): - return - - numerator = cor_xy - cor_xz * cor_yz - denominator = ((1 - cor_xz**2) * (1 - cor_yz**2)) ** 0.5 - - # Avoid division by zero - if denominator == 0: - return - - par_corr = numerator / denominator - - # make t test for independency of u and v given z - sample_size = len(x.values) - nb_of_control_vars = len(nodes) - 2 - t, critical_t = get_t_and_critical_t( - sample_size, nb_of_control_vars, par_corr, self.threshold - ) - - if abs(t) < critical_t: + if self.conditional_independence_test.test( + graph, x.name, y.name, [z.name], self.threshold + ): logger.debug( f"Nodes {x.name} and {y.name} are uncorrelated given {z.name}" ) @@ -142,6 +123,7 @@ class ExtendedPartialCorrelationTestMatrix( ) chunk_size_parallel_processing: IntegerParameter = 1000 parallel: BoolParameter = False + conditional_independence_test: CausyObjectParameter = PearsonStudentsTTest() def process( self, nodes: List[str], graph: BaseGraphInterface @@ -170,48 +152,10 @@ def process( if not set(nodes[2:]).issubset(set([on for on in list(other_neighbours)])): return - cov_matrix = torch.cov( - torch.stack([graph.nodes[node].values for node in nodes]) - ) - # check if the covariance matrix is ill-conditioned - if torch.det(cov_matrix) == 0: - logger.warning( - "The covariance matrix is ill-conditioned. The precision matrix is not reliable." - ) - return - - inverse_cov_matrix = torch.inverse(cov_matrix) - - n = inverse_cov_matrix.size(0) - diagonal = torch.diag(inverse_cov_matrix) - diagonal_matrix = torch.zeros((n, n), dtype=torch.float64) - for i in range(n): - diagonal_matrix[i, i] = diagonal[i] - helper = torch.mm(torch.sqrt(diagonal_matrix), inverse_cov_matrix) - precision_matrix = torch.mm(helper, torch.sqrt(diagonal_matrix)) - - sample_size = len(graph.nodes[nodes[0]].values) - nb_of_control_vars = len(nodes) - 2 - - # prevent math domain error - try: - t, critical_t = get_t_and_critical_t( - sample_size, - nb_of_control_vars, - ( - (-1 * precision_matrix[0][1]) - / torch.sqrt(precision_matrix[0][0] * precision_matrix[1][1]) - ).item(), - self.threshold, - ) - except ValueError: - logger.warning( - "Math domain error. The covariance matrix is ill-conditioned. The precision matrix is not reliable." - ) - return - - if abs(t) < critical_t: + if self.conditional_independence_test.test( + graph, nodes[0], nodes[1], nodes[2:], self.threshold + ): logger.debug( f"Nodes {graph.nodes[nodes[0]].name} and {graph.nodes[nodes[1]].name} are uncorrelated given nodes {','.join([graph.nodes[on].name for on in other_neighbours])}" ) diff --git a/causy/causal_discovery/constraint/independence_tests/conditional_independence_calculations.py b/causy/causal_discovery/constraint/independence_tests/conditional_independence_calculations.py new file mode 100644 index 0000000..ec1523a --- /dev/null +++ b/causy/causal_discovery/constraint/independence_tests/conditional_independence_calculations.py @@ -0,0 +1,196 @@ +from abc import ABC +from typing import List, TypeVar, Generic + +import torch +from pydantic import BaseModel, computed_field + +from causy.graph import Graph, Node, logger +from causy.graph_utils import serialize_module_name +from causy.math_utils import get_t_and_critical_t + + +def invert_matrix(matrix: torch.Tensor) -> torch.Tensor: + if torch.det(matrix) == 0: + logger.warning( + "The covariance matrix is ill-conditioned. The precision matrix is not reliable." + ) + return torch.linalg.pinv(matrix) + else: + return torch.inverse(matrix) + + +ConditionalIndependenceTestInterfaceType = TypeVar( + "ConditionalIndependenceTestInterfaceType" +) + + +class ConditionalIndependenceTestInterface( + ABC, BaseModel, Generic[ConditionalIndependenceTestInterfaceType] +): + @computed_field + @property + def name(self) -> str: + return serialize_module_name(self) + + @staticmethod + def calculate_correlation(x: Node, y: Node, z: List[Node]) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def test(graph: Graph, x: str, y: str, z: List[str], threshold: float) -> bool: + raise NotImplementedError + + +class PearsonStudentsTTest( + ConditionalIndependenceTestInterface[ConditionalIndependenceTestInterfaceType], + Generic[ConditionalIndependenceTestInterfaceType], +): + @staticmethod + def calculate_correlation(x: Node, y: Node, z: List[Node]) -> torch.Tensor: + """ + Calculate the correlation between two nodes x and y given a list of control variables z. + It returns a tensor with the t-value and the critical t-value. + It uses the Pearson's t-test for the correlation coefficient. + :param x: + :param y: + :param z: + :param threshold: + :return: + """ + + if len(z) == 0: + cov_xy = torch.mean( + (x.values - x.values.mean()) * (y.values - y.values.mean()) + ) + std_x = x.values.std(unbiased=False) + std_y = y.values.std(unbiased=False) + pearson_correlation = cov_xy / (std_x * std_y) + + correlation = pearson_correlation.item() + + # Clamp the correlation to -1 and 1 to avoid numerical errors + if correlation < -1: + correlation = -1 + elif correlation > 1: + correlation = 1 + + return torch.tensor(correlation) + + cov_matrix = torch.cov( + torch.stack([x.values, y.values, *[zi.values for zi in z]]) + ) + # check if the covariance matrix is ill-conditioned + inverse_cov_matrix = invert_matrix(cov_matrix) + + n = inverse_cov_matrix.size(0) + diagonal = torch.diag(inverse_cov_matrix) + diagonal_matrix = torch.zeros((n, n), dtype=torch.float64) + for i in range(n): + diagonal_matrix[i, i] = diagonal[i] + + helper = torch.mm(torch.sqrt(diagonal_matrix), inverse_cov_matrix) + precision_matrix = torch.mm(helper, torch.sqrt(diagonal_matrix)) + + return (-1 * precision_matrix[0][1]) / torch.sqrt( + precision_matrix[0][0] * precision_matrix[1][1] + ) + + @staticmethod + def test(graph: Graph, x: str, y: str, z: List[str], threshold: float) -> bool: + """ + :param graph: + :param x: + :param y: + :param z: + :return: + """ + x = graph.nodes[x] + y = graph.nodes[y] + z = [graph.nodes[zi] for zi in z] + + res = None + + if len(z) == 0: + edge = graph.edge_value(x, y) + if edge is not None and "correlation" in edge: + res = torch.tensor(edge["correlation"]) + + if res is None: + res = PearsonStudentsTTest.calculate_correlation(x, y, z) + + sample_size = len(x.values) + nb_of_control_vars = len(z) + + # prevent math domain error + try: + t, critical_t = get_t_and_critical_t( + sample_size, nb_of_control_vars, res.item(), threshold + ) + except ValueError: + logger.warning( + "Math domain error. The covariance matrix is ill-conditioned. The precision matrix is not reliable." + ) + return None + + return abs(t) < critical_t + + +class FishersZTest( + ConditionalIndependenceTestInterface[ConditionalIndependenceTestInterfaceType], + Generic[ConditionalIndependenceTestInterfaceType], +): + @staticmethod + def calculate_correlation(x: Node, y: Node, z: List[Node]) -> torch.Tensor: + if len(z) == 0: + r = torch.corrcoef(torch.stack([x.values, y.values]))[0, 1] + else: + sub_corr = torch.corrcoef( + torch.stack([x.values, y.values, *[zi.values for zi in z]]) + ) + r = invert_matrix(sub_corr) + r = -1 * r[0, 1] / torch.sqrt(abs(r[0, 0] * r[1, 1])) + + cut_at = torch.tensor(0.99999) + r = torch.min(cut_at, torch.max(-1 * cut_at, r)) # make r between -1 and 1 + + res = torch.sqrt( + torch.tensor(len(x.values)) - torch.tensor(len(z)) - 3 + ) * torch.atanh(r) + p = 2 * (1 - torch.distributions.Normal(0, 1).cdf(res)) + + return p + + @staticmethod + def test(graph: Graph, x: str, y: str, z: List[str], threshold: float) -> bool: + """ + :param graph: + :param x: + :param y: + :param z: + :return: + """ + x = graph.nodes[x] + y = graph.nodes[y] + z = [graph.nodes[zi] for zi in z] + + p = FishersZTest.calculate_correlation(x, y, z) + p = p.item() + return p < threshold + + +class ChiSquareTest( + ConditionalIndependenceTestInterface[ConditionalIndependenceTestInterfaceType], + Generic[ConditionalIndependenceTestInterfaceType], +): + @staticmethod + def test(graph: Graph, x: str, y: str, z: List[str], threshold: float) -> bool: + return True + + +class G2Test( + ConditionalIndependenceTestInterface[ConditionalIndependenceTestInterfaceType], + Generic[ConditionalIndependenceTestInterfaceType], +): + @staticmethod + def test(graph: Graph, x: str, y: str, z: List[str], threshold: float) -> bool: + return True diff --git a/causy/common_pipeline_steps/calculation.py b/causy/common_pipeline_steps/calculation.py index 60e0ef5..ba23ac5 100644 --- a/causy/common_pipeline_steps/calculation.py +++ b/causy/common_pipeline_steps/calculation.py @@ -2,6 +2,9 @@ import torch +from causy.causal_discovery.constraint.independence_tests.conditional_independence_calculations import ( + PearsonStudentsTTest, +) from causy.generators import AllCombinationsGenerator from causy.interfaces import ( PipelineStepInterface, @@ -10,9 +13,10 @@ PipelineStepInterfaceType, ) from causy.models import ComparisonSettings, TestResultAction, TestResult +from causy.variables import CausyObjectParameter -class CalculatePearsonCorrelations( +class CalculateEdgeCorrelations( PipelineStepInterface[PipelineStepInterfaceType], Generic[PipelineStepInterfaceType] ): generator: Optional[GeneratorInterface] = AllCombinationsGenerator( @@ -20,34 +24,49 @@ class CalculatePearsonCorrelations( ) chunk_size_parallel_processing: int = 1 parallel: bool = False + conditional_independence_test: CausyObjectParameter = PearsonStudentsTTest() def process(self, nodes: Tuple[str], graph: BaseGraphInterface) -> TestResult: """ - Calculate the correlation between each pair of nodes and store it to the respective edge. + Test if u and v are independent and delete edge in graph if they are. :param nodes: list of nodes :return: A TestResult with the action to take """ x = graph.nodes[nodes[0]] y = graph.nodes[nodes[1]] edge_value = graph.edge_value(graph.nodes[nodes[0]], graph.nodes[nodes[1]]) + correlation = self.conditional_independence_test.calculate_correlation(x, y, []) + edge_value["correlation"] = correlation.item() + return TestResult( + u=x, + v=y, + action=TestResultAction.UPDATE_EDGE, + data=edge_value, + ) - x_val = x.values - y_val = y.values - cov_xy = torch.mean((x_val - x_val.mean()) * (y_val - y_val.mean())) - std_x = x_val.std(unbiased=False) - std_y = y_val.std(unbiased=False) - pearson_correlation = cov_xy / (std_x * std_y) +class CalculatePearsonCorrelations( + PipelineStepInterface[PipelineStepInterfaceType], Generic[PipelineStepInterfaceType] +): + generator: Optional[GeneratorInterface] = AllCombinationsGenerator( + comparison_settings=ComparisonSettings(min=2, max=2) + ) + chunk_size_parallel_processing: int = 1 + parallel: bool = False - correlation = pearson_correlation.item() + def process(self, nodes: Tuple[str], graph: BaseGraphInterface) -> TestResult: + """ + Calculate the correlation between each pair of nodes and store it to the respective edge. + :param nodes: list of nodes + :return: A TestResult with the action to take + """ + x = graph.nodes[nodes[0]] + y = graph.nodes[nodes[1]] + edge_value = graph.edge_value(graph.nodes[nodes[0]], graph.nodes[nodes[1]]) - # Clamp the correlation to -1 and 1 to avoid numerical errors - if correlation < -1: - correlation = -1 - elif correlation > 1: - correlation = 1 + correlation = PearsonStudentsTTest.calculate_correlation(x, y, []) - edge_value["correlation"] = correlation + edge_value["correlation"] = correlation.item() return TestResult( u=x, diff --git a/causy/graph_model.py b/causy/graph_model.py index b371274..94cea2c 100644 --- a/causy/graph_model.py +++ b/causy/graph_model.py @@ -502,6 +502,9 @@ def graph_model_factory( :return: the graph model """ original_algorithm = deepcopy(algorithm) + if variables is not None and len(variables) == 0: + variables = None + if variables is None and algorithm.variables is not None: variables = resolve_variables(algorithm.variables, {}) elif variables is None: diff --git a/causy/graph_utils.py b/causy/graph_utils.py index b49c6ca..61ff258 100644 --- a/causy/graph_utils.py +++ b/causy/graph_utils.py @@ -3,6 +3,8 @@ import json from typing import List, Tuple, Dict +from pydantic import BaseModel + from causy.variables import deserialize_variable_references @@ -52,12 +54,21 @@ def retrieve_edges(graph) -> List[Tuple[str, str]]: return edges +class BaseModelEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, BaseModel): + return obj.dict() + # Let the base class default method raise the TypeError + return super().default(obj) + + def hash_dictionary(dct: Dict): """ Hash a dictionary using SHA256 (e.g. for caching) :param dct: :return: """ + return hashlib.sha256( json.dumps( dct, @@ -65,5 +76,6 @@ def hash_dictionary(dct: Dict): sort_keys=True, indent=None, separators=(",", ":"), + cls=BaseModelEncoder, ).encode() ).hexdigest() diff --git a/causy/interfaces.py b/causy/interfaces.py index 2d81c4f..a33d1ef 100644 --- a/causy/interfaces.py +++ b/causy/interfaces.py @@ -17,6 +17,7 @@ IntegerParameter, BoolParameter, FloatParameter, + CausyObjectParameter, ) logger = logging.getLogger(__name__) @@ -25,9 +26,15 @@ AS_MANY_AS_FIELDS = 0 -MetadataBaseType = Union[str, int, float, bool] +MetadataBaseType = Union[str, int, float, bool, BaseModel] MetadataType = Union[ - str, int, float, bool, List[MetadataBaseType], Dict[str, MetadataBaseType] + str, + int, + float, + bool, + List[MetadataBaseType], + Dict[str, MetadataBaseType], + BaseModel, ] @@ -350,6 +357,7 @@ class PipelineStepInterface(ABC, BaseModel, Generic[PipelineStepInterfaceType]): display_name: Optional[StringParameter] = None needs_unapplied_actions: Optional[BoolParameter] = False + conditional_independence_test: Optional[CausyObjectParameter] = None def __init__( self, @@ -379,6 +387,9 @@ def __init__( if threshold: self.threshold = threshold + for key, value in kwargs.items(): + setattr(self, key, value) + @computed_field @property def name(self) -> str: diff --git a/causy/serialization.py b/causy/serialization.py index b9f91ba..87e095f 100644 --- a/causy/serialization.py +++ b/causy/serialization.py @@ -83,7 +83,7 @@ def load_algorithm_by_reference(reference_type: str, algorithm: str): st_function = getattr(module, ref_) if not st_function: raise ValueError(f"Algorithm {algorithm} not found") - return st_function()._original_algorithm + return copy.deepcopy(st_function()._original_algorithm) class CausyJSONEncoder(JSONEncoder): diff --git a/causy/variables.py b/causy/variables.py index bf59625..32dfa18 100644 --- a/causy/variables.py +++ b/causy/variables.py @@ -8,7 +8,7 @@ VariableInterfaceType = TypeVar("VariableInterfaceType") -VariableType = Union[str, int, float, bool] +VariableType = Union[str, int, float, bool, BaseModel] class VariableTypes(enum.Enum): @@ -16,6 +16,7 @@ class VariableTypes(enum.Enum): Integer = "integer" Float = "float" Bool = "bool" + CausyObject = "causy_object" # this is a reference to any object that is a subclass of BaseModel used to reference pipeline/custom objects class BaseVariable(BaseModel, Generic[VariableInterfaceType]): @@ -29,8 +30,8 @@ def __init__(self, **data): self.validate_value(self.value) name: str - value: Union[str, int, float, bool] - choices: Optional[List[Union[str, int, float, bool]]] = None + value: Union[str, int, float, bool, BaseModel] + choices: Optional[List[Union[str, int, float, bool, BaseModel]]] = None def is_valid(self): return self.is_valid_value(self.value) @@ -124,6 +125,21 @@ class BoolVariable(BaseVariable[VariableInterfaceType], Generic[VariableInterfac _PYTHON_TYPE: Optional[type] = bool +class CausyObjectVariable( + BaseVariable[VariableInterfaceType], Generic[VariableInterfaceType] +): + """ + Represents a single causy object variable. + causy object is a reference to any object that is a subclass of BaseModel used to reference pipeline/custom objects + """ + + value: BaseModel + name: str + + _TYPE: str = VariableTypes.CausyObject.value + _PYTHON_TYPE: Optional[type] = BaseModel + + class VariableReference(BaseModel, Generic[VariableInterfaceType]): """ Represents a reference to a variable. @@ -142,13 +158,21 @@ def type(self) -> str: VariableTypes.Integer.value: IntegerVariable, VariableTypes.Float.value: FloatVariable, VariableTypes.Bool.value: BoolVariable, + VariableTypes.CausyObject.value: CausyObjectVariable, } BoolParameter = Union[bool, VariableReference] IntegerParameter = Union[int, VariableReference] FloatParameter = Union[float, VariableReference] StringParameter = Union[str, VariableReference] -CausyParameter = Union[BoolParameter, IntegerParameter, FloatParameter, StringParameter] +CausyObjectParameter = Union[BaseModel, VariableReference] +CausyParameter = Union[ + BoolParameter, + IntegerParameter, + FloatParameter, + StringParameter, + CausyObjectParameter, +] def validate_variable_values(algorithm, variable_values: Dict[str, VariableType]): @@ -256,7 +280,10 @@ def deserialize_variable_references(element: object) -> object: if isinstance(value, dict) and "type" in value and value["type"] == "reference": setattr(element, attribute, VariableReference(name=value["name"])) - if hasattr(value, "__dict__"): + if isinstance(element, enum.Enum): + setattr(element, attribute, value) + + if not isinstance(element, enum.Enum) and hasattr(value, "__dict__"): setattr(element, attribute, deserialize_variable_references(value)) if hasattr(element, "pipeline_steps"):