Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions causal_testing/testing/causal_test_adequacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from itertools import combinations

import pandas as pd
from lifelines.exceptions import ConvergenceError
from numpy.linalg import LinAlgError

from causal_testing.specification.causal_dag import CausalDAG
from causal_testing.testing.causal_test_case import CausalTestCase
Expand Down Expand Up @@ -104,21 +102,9 @@ def measure_adequacy(self):
estimator.df = estimator.df[estimator.df[self.group_by].isin(ids)]
else:
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
try:
result = self.test_case.execute_test(estimator)
outcomes.append(self.test_case.expected_causal_effect.apply(result))
results.append(result.effect_estimate.to_df())
except LinAlgError:
logger.warning("Adequacy LinAlgError")
continue
except ConvergenceError:
logger.warning("Adequacy ConvergenceError")
continue
except ValueError as e:
logger.warning(f"Adequacy ValueError: {e}")
continue
# outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
# results = pd.concat([c.effect_estimate.to_df() for c in results])
result = self.test_case.execute_test(estimator)
outcomes.append(self.test_case.expected_causal_effect.apply(result))
results.append(result.effect_estimate.to_df())
results = pd.concat(results)
results["var"] = results.index
results["passed"] = outcomes
Expand Down
51 changes: 34 additions & 17 deletions causal_testing/testing/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,38 +33,62 @@ def __eq__(self, other):
same_adjustment_set = set(self.adjustment_vars) == set(other.adjustment_vars)
return same_type and same_treatment and same_outcome and same_effect and same_adjustment_set


class ShouldCause(MetamorphicRelation):
"""Class representing a should cause metamorphic relation."""

def to_json_stub(
self,
skip: bool = False,
estimate_type: str = "coefficient",
effect_type: str = "direct",
estimator: str = "LinearRegressionEstimator",
alpha: float = 0.05,
) -> dict:
"""
Convert to a JSON frontend stub string for user customisation.
:param skip: Whether to skip the test (default False).
:param effect_type: The type of causal effect to consider (total or direct)
:param estimate_type: The estimate type to use when evaluating tests
:param estimator: The name of the estimator class to use when evaluating the test
:param alpha: The significance level to use when calculating the confidence intervals
"""
return {
"name": str(self),
"estimator": estimator,
"estimate_type": estimate_type,
"effect": effect_type,
"treatment_variable": self.base_test_case.treatment_variable,
"expected_effect": {self.base_test_case.outcome_variable: "SomeEffect"},
"formula": (
f"{self.base_test_case.outcome_variable} ~ "
f"{' + '.join([self.base_test_case.treatment_variable] + self.adjustment_vars)}"
),
"alpha": alpha,
"skip": skip,
}


class ShouldCause(MetamorphicRelation):
"""Class representing a should cause metamorphic relation."""

def to_json_stub(
self,
skip: bool = False,
estimate_type: str = "coefficient",
effect_type: str = "direct",
estimator: str = "LinearRegressionEstimator",
alpha: float = 0.05,
) -> dict:
"""
Convert to a JSON frontend stub string for user customisation.
:param skip: Whether to skip the test (default False).
:param effect_type: The type of causal effect to consider (total or direct)
:param estimate_type: The estimate type to use when evaluating tests
:param estimator: The name of the estimator class to use when evaluating the test
:param alpha: The significance level to use when calculating the confidence intervals
"""
return super().to_json_stub(
skip=skip, estimate_type=estimate_type, effect_type=effect_type, estimator=estimator, alpha=alpha
) | {
"expected_effect": {self.base_test_case.outcome_variable: "SomeEffect"},
}

def __str__(self):
formatted_str = f"{self.base_test_case.treatment_variable} --> {self.base_test_case.outcome_variable}"
if self.adjustment_vars:
Expand All @@ -81,27 +105,20 @@ def to_json_stub(
estimate_type: str = "coefficient",
effect_type: str = "direct",
estimator: str = "LinearRegressionEstimator",
alpha: float = 0.05,
) -> dict:
"""
Convert to a JSON frontend stub string for user customisation.
:param skip: Whether to skip the test (default False).
:param effect_type: The type of causal effect to consider (total or direct)
:param estimate_type: The estimate type to use when evaluating tests
:param estimator: The name of the estimator class to use when evaluating the test
:param alpha: The significance level to use when calculating the confidence intervals
"""
return {
"name": str(self),
"estimator": estimator,
"estimate_type": estimate_type,
"effect": effect_type,
"treatment_variable": self.base_test_case.treatment_variable,
return super().to_json_stub(
skip=skip, estimate_type=estimate_type, effect_type=effect_type, estimator=estimator, alpha=alpha
) | {
"expected_effect": {self.base_test_case.outcome_variable: "NoEffect"},
"formula": (
f"{self.base_test_case.outcome_variable} ~ "
f"{' + '.join([self.base_test_case.treatment_variable] + self.adjustment_vars)}"
),
"alpha": 0.05,
"skip": skip,
}

def __str__(self):
Expand Down
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"lifelines~=0.30.0",
"networkx>=3.4,<3.5",
"numpy>=1.26.0,<=2.2.0",
"pandas>=2.1",
"pandas>=2.1,<3",
"scikit_learn~=1.4",
"scipy>=1.12.0,<=1.16.2",
"statsmodels~=0.14",
Expand Down Expand Up @@ -100,8 +100,12 @@ skip_missing_interpreters = false # fail if devs don’t have all required Pytho
description = "Run pytest under {base_python}"
extras = ["dev","test"]
deps = ["pytest"]
commands = [["pytest"]]

commands = [
[
"pytest",
"{posargs:tests}",
],
]
# Automatically test for type-checking (TODO: enable type checking in env_list in the future)
[tool.tox.env.type]
description = "Run type checks with mypy on the codebase"
Expand Down
2 changes: 2 additions & 0 deletions tests/testing_tests/test_metamorphic_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def test_should_cause_json_stub(self):
"formula": "Z ~ X1",
"treatment_variable": "X1",
"name": "X1 --> Z",
"alpha": 0.05,
"skip": False,
},
)
Expand All @@ -130,6 +131,7 @@ def test_should_cause_logistic_json_stub(self):
"formula": "Z ~ X1",
"treatment_variable": "X1",
"name": "X1 --> Z",
"alpha": 0.05,
"skip": False,
},
)
Expand Down