From 27201223d028229169adb9d17c1dcd31f68122e7 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 11:55:17 +0100 Subject: [PATCH 01/13] Estimators now as entrypoints --- causal_testing/main.py | 16 ++++++++++------ pyproject.toml | 7 +++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/causal_testing/main.py b/causal_testing/main.py index 19a7f3d0..fc9c9079 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -7,6 +7,7 @@ from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Union +from importlib.metadata import entry_points import numpy as np import pandas as pd @@ -267,17 +268,20 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC "Negative": Negative(), } - # Map estimator string to estimator class - estimator_map = { - "LinearRegressionEstimator": LinearRegressionEstimator, - "LogisticRegressionEstimator": LogisticRegressionEstimator, - } + estimator_map = {ff.name: ff for ff in entry_points(group="estimators")} if "estimator" not in test: raise ValueError("Test configuration must specify an estimator") + if test["estimator"] not in estimator_map: + print( + f"Unsupported estimator {estimator}. Supported: {sorted(estimators)}" + "If you have implemented a custom estimator, you will need to add this to your entrypoints via your " + "pyproject.toml file." + ) + # Get the estimator class - estimator_class = estimator_map.get(test["estimator"]) + estimator_class = estimator_map.get(test["estimator"]).load() if estimator_class is None: raise ValueError(f"Unknown estimator: {test['estimator']}") diff --git a/pyproject.toml b/pyproject.toml index d7da1465..d65c02ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,13 @@ Documentation = "https://causal-testing-framework.readthedocs.io/" Repository = "https://github.com/CITCOM-project/CausalTestingFramework" Issues = "https://github.com/CITCOM-project/CausalTestingFramework/issues" +[project.entry-points."estimators"] +LinearRegressionEstimator = "causal_testing.estimation.linear_regression_estimator:LinearRegressionEstimator" +LogisticRegressionEstimator = "causal_testing.estimation.linear_regression_estimator:LogisticRegressionEstimator" +CubicSplineEstimator = "causal_testing.estimation.linear_regression_estimator:CubicSplineEstimator" +InstrumentalVariableEstimator = "causal_testing.estimation.linear_regression_estimator:InstrumentalVariableEstimator" +IPCWEstimator = "causal_testing.estimation.linear_regression_estimator:IPCWEstimator" + [tool.setuptools.packages] find = {} From 3651a2f8e9f76dfa3029061b336decbe6fd4f59c Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 12:04:29 +0100 Subject: [PATCH 02/13] Causal effects now as entry points --- causal_testing/main.py | 21 ++++++++++----------- pyproject.toml | 7 +++++++ 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/causal_testing/main.py b/causal_testing/main.py index fc9c9079..45fa6bd3 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -260,22 +260,15 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC :return: CausalTestCase object :raises: ValueError if invalid estimator or configuration is provided """ - # Map effect string to effect class - effect_map = { - "NoEffect": NoEffect(), - "SomeEffect": SomeEffect(), - "Positive": Positive(), - "Negative": Negative(), - } - estimator_map = {ff.name: ff for ff in entry_points(group="estimators")} + effect_map = {ff.name: ff for ff in entry_points(group="causal_effects")} if "estimator" not in test: raise ValueError("Test configuration must specify an estimator") if test["estimator"] not in estimator_map: - print( - f"Unsupported estimator {estimator}. Supported: {sorted(estimators)}" + raise ValueError( + f"Unsupported estimator {estimator}. Supported: {sorted(estimator_map)}" "If you have implemented a custom estimator, you will need to add this to your entrypoints via your " "pyproject.toml file." ) @@ -322,7 +315,13 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC # Get effect type and create expected effect effect_type = test["expected_effect"][base_test.outcome_variable.name] - expected_effect = effect_map[effect_type] + if test["estimator"] not in estimator_map: + raise ValueError( + f"Unsupported causal effect {effect_type}. Supported: {sorted(effect_map)}" + "If you have implemented a custom causal effect, you will need to add this to your entrypoints via your " + "pyproject.toml file." + ) + expected_effect = effect_map[effect_type].load()() return CausalTestCase( base_test_case=base_test, diff --git a/pyproject.toml b/pyproject.toml index d65c02ad..a755c5c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,13 @@ CubicSplineEstimator = "causal_testing.estimation.linear_regression_estimator:Cu InstrumentalVariableEstimator = "causal_testing.estimation.linear_regression_estimator:InstrumentalVariableEstimator" IPCWEstimator = "causal_testing.estimation.linear_regression_estimator:IPCWEstimator" +[project.entry-points."causal_effects"] +NoEffect = "causal_testing.testing.causal_effect:NoEffect" +SomeEffect = "causal_testing.testing.causal_effect:SomeEffect" +Positive = "causal_testing.testing.causal_effect:Positive" +Negative = "causal_testing.testing.causal_effect:Negative" +ExactValue = "causal_testing.testing.causal_effect:ExactValue" + [tool.setuptools.packages] find = {} From 35d141956ddf6d63676aebec19c4638afedb8fbb Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 13:16:40 +0100 Subject: [PATCH 03/13] Added support for estimator and effect kwargs --- causal_testing/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/causal_testing/main.py b/causal_testing/main.py index 45fa6bd3..c6409fab 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -311,6 +311,7 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC formula=test.get("formula"), alpha=test.get("alpha", 0.05), query=combined_query, + **test.get("estimator_kwargs", {}), ) # Get effect type and create expected effect @@ -329,6 +330,7 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC estimate_type=test.get("estimate_type", "ate"), estimate_params=test.get("estimate_params"), estimator=estimator, + **test.get("effect_kwargs", {}), ) def run_tests_in_batches( From dde436c3a7c4917150584b34089a6b92819b5b02 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 13:42:22 +0100 Subject: [PATCH 04/13] Pytests --- causal_testing/main.py | 12 ++++-------- tests/main_tests/test_main.py | 34 ++++++++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/causal_testing/main.py b/causal_testing/main.py index c6409fab..456ded52 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -268,16 +268,11 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC if test["estimator"] not in estimator_map: raise ValueError( - f"Unsupported estimator {estimator}. Supported: {sorted(estimator_map)}" + f"Unsupported estimator {test['estimator']}. Supported: {sorted(estimator_map)}. " "If you have implemented a custom estimator, you will need to add this to your entrypoints via your " "pyproject.toml file." ) - # Get the estimator class - estimator_class = estimator_map.get(test["estimator"]).load() - if estimator_class is None: - raise ValueError(f"Unknown estimator: {test['estimator']}") - # Handle combined queries (global and test-specific) test_query = test.get("query") combined_query = None @@ -298,6 +293,7 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC filtered_df = self.data.query(combined_query) if combined_query else self.data # Create the estimator with correct parameters + estimator_class = estimator_map.get(test["estimator"]).load() estimator = estimator_class( base_test_case=base_test, treatment_value=test.get("treatment_value"), @@ -316,9 +312,9 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC # Get effect type and create expected effect effect_type = test["expected_effect"][base_test.outcome_variable.name] - if test["estimator"] not in estimator_map: + if effect_type not in effect_map: raise ValueError( - f"Unsupported causal effect {effect_type}. Supported: {sorted(effect_map)}" + f"Unsupported causal effect {effect_type}. Supported: {sorted(effect_map)}. " "If you have implemented a custom causal effect, you will need to add this to your entrypoints via your " "pyproject.toml file." ) diff --git a/tests/main_tests/test_main.py b/tests/main_tests/test_main.py index 232cf0e7..7d7cfb1a 100644 --- a/tests/main_tests/test_main.py +++ b/tests/main_tests/test_main.py @@ -112,12 +112,37 @@ def test_create_base_test_case_missing_estimator(self): framework.create_causal_test({}, None) self.assertEqual("Test configuration must specify an estimator", str(e.exception)) - def test_create_base_test_case_invalid_estimator(self): + def test_create_test_case_invalid_estimator(self): framework = CausalTestingFramework(self.paths) framework.setup() with self.assertRaises(ValueError) as e: framework.create_causal_test({"estimator": "InvalidEstimator"}, None) - self.assertEqual("Unknown estimator: InvalidEstimator", str(e.exception)) + self.assertEqual( + f"Unsupported estimator InvalidEstimator. Supported: ['CubicSplineEstimator', 'IPCWEstimator', 'InstrumentalVariableEstimator', 'LinearRegressionEstimator', 'LogisticRegressionEstimator']. " + "If you have implemented a custom estimator, you will need to add this to your entrypoints via your " + "pyproject.toml file.", + str(e.exception), + ) + + def test_create_test_case_invalid_effect(self): + framework = CausalTestingFramework(self.paths) + framework.setup() + test = { + "name": "test1", + "treatment_variable": "test_input", + "estimator": "LinearRegressionEstimator", + "estimate_type": "coefficient", + "expected_effect": {"test_output": "InvalidEffect"}, + } + base_test_case = framework.create_base_test(test) + with self.assertRaises(ValueError) as e: + framework.create_causal_test(test, base_test_case) + self.assertEqual( + f"Unsupported causal effect InvalidEffect. Supported: ['ExactValue', 'Negative', 'NoEffect', 'Positive', 'SomeEffect']. " + "If you have implemented a custom causal effect, you will need to add this to your entrypoints via your " + "pyproject.toml file.", + str(e.exception), + ) def test_create_base_test_case_missing_outcome(self): framework = CausalTestingFramework(self.paths) @@ -166,7 +191,8 @@ def test_ctf(self): test_passed = ( test_case.expected_causal_effect.apply(framework_result) - if framework_result.effect_estimate is not None else False + if framework_result.effect_estimate is not None + else False ) self.assertEqual(result["passed"], test_passed) @@ -518,4 +544,4 @@ def test_parse_args_generation_non_default(self): def tearDown(self): if self.output_path.parent.exists(): - shutil.rmtree(self.output_path.parent) \ No newline at end of file + shutil.rmtree(self.output_path.parent) From 740f840cafe51f86749a2af4eaebe5c9c3577c41 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 14:32:25 +0100 Subject: [PATCH 05/13] Fixed estimator kwarg formatting --- causal_testing/main.py | 34 ++++++------------- .../testing/metamorphic_relation.py | 10 +++--- .../causal_test_results.json | 6 ++-- pyproject.toml | 8 ++--- tests/main_tests/test_main.py | 30 ++++++++++++++++ tests/resources/data/tests.json | 16 +++++---- 6 files changed, 63 insertions(+), 41 deletions(-) diff --git a/causal_testing/main.py b/causal_testing/main.py index 456ded52..60ce2bb7 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -273,27 +273,17 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC "pyproject.toml file." ) - # Handle combined queries (global and test-specific) - test_query = test.get("query") - combined_query = None - - if self.query and test_query: - combined_query = f"({self.query}) and ({test_query})" - logger.info( - f"Combining global query '{self.query}' with test-specific query " - f"'{test_query}' for test '{test['name']}'" - ) - elif test_query: - combined_query = test_query - logger.info(f"Using test-specific query for '{test['name']}': {test_query}") - elif self.query: - combined_query = self.query - logger.info(f"Using global query for '{test['name']}': {self.query}") - - filtered_df = self.data.query(combined_query) if combined_query else self.data + # Handle global queries + # Test-specific queries are handled by the estimator as not all estimators support them + filtered_df = self.data + if self.query: + filtered_df = self.data.query(self.query) # Create the estimator with correct parameters estimator_class = estimator_map.get(test["estimator"]).load() + estimator_kwargs = test.get("estimator_kwargs", {}) + if "query" in test: + estimator_kwargs["query"] = test["query"] estimator = estimator_class( base_test_case=base_test, treatment_value=test.get("treatment_value"), @@ -303,11 +293,8 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC self.dag.identification(base_test, self.scenario.hidden_variables()), ), df=filtered_df, - effect_modifiers=None, - formula=test.get("formula"), alpha=test.get("alpha", 0.05), - query=combined_query, - **test.get("estimator_kwargs", {}), + **estimator_kwargs, ) # Get effect type and create expected effect @@ -318,7 +305,7 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC "If you have implemented a custom causal effect, you will need to add this to your entrypoints via your " "pyproject.toml file." ) - expected_effect = effect_map[effect_type].load()() + expected_effect = effect_map[effect_type].load()(**test.get("effect_kwargs", {})) return CausalTestCase( base_test_case=base_test, @@ -326,7 +313,6 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC estimate_type=test.get("estimate_type", "ate"), estimate_params=test.get("estimate_params"), estimator=estimator, - **test.get("effect_kwargs", {}), ) def run_tests_in_batches( diff --git a/causal_testing/testing/metamorphic_relation.py b/causal_testing/testing/metamorphic_relation.py index 7f1e2c8d..942de00a 100644 --- a/causal_testing/testing/metamorphic_relation.py +++ b/causal_testing/testing/metamorphic_relation.py @@ -55,12 +55,14 @@ def to_json_stub( "estimate_type": estimate_type, "effect": effect_type, "treatment_variable": self.base_test_case.treatment_variable, - "formula": ( - f"{self.base_test_case.outcome_variable} ~ " - f"{' + '.join([self.base_test_case.treatment_variable] + self.adjustment_vars)}" - ), "alpha": alpha, "skip": skip, + "estimator_kwargs": { + "formula": ( + f"{self.base_test_case.outcome_variable} ~ " + f"{' + '.join([self.base_test_case.treatment_variable] + self.adjustment_vars)}" + ), + }, } diff --git a/docs/source/tutorials/vaccinating_elderly/causal_test_results.json b/docs/source/tutorials/vaccinating_elderly/causal_test_results.json index 969dcf17..22ec8a8b 100644 --- a/docs/source/tutorials/vaccinating_elderly/causal_test_results.json +++ b/docs/source/tutorials/vaccinating_elderly/causal_test_results.json @@ -176,7 +176,7 @@ "cum_vaccinated": "NoEffect" }, "alpha": 0.05, - "formula": "cum_vaccinated ~ cum_vaccinations + vaccine", + "formula": "cum_vaccinated ~ cum_vaccinations+vaccine", "skip": false, "passed": false, "result": { @@ -206,7 +206,7 @@ "cum_infections": "NoEffect" }, "alpha": 0.05, - "formula": "cum_infections ~ cum_vaccinations + vaccine", + "formula": "cum_infections ~ cum_vaccinations+vaccine", "skip": false, "passed": true, "result": { @@ -236,7 +236,7 @@ "cum_infections": "NoEffect" }, "alpha": 0.05, - "formula": "cum_infections ~ cum_vaccinated + vaccine", + "formula": "cum_infections ~ cum_vaccinated+vaccine", "skip": false, "passed": true, "result": { diff --git a/pyproject.toml b/pyproject.toml index a755c5c8..b876ff83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,10 +67,10 @@ Issues = "https://github.com/CITCOM-project/CausalTestingFramework/issues" [project.entry-points."estimators"] LinearRegressionEstimator = "causal_testing.estimation.linear_regression_estimator:LinearRegressionEstimator" -LogisticRegressionEstimator = "causal_testing.estimation.linear_regression_estimator:LogisticRegressionEstimator" -CubicSplineEstimator = "causal_testing.estimation.linear_regression_estimator:CubicSplineEstimator" -InstrumentalVariableEstimator = "causal_testing.estimation.linear_regression_estimator:InstrumentalVariableEstimator" -IPCWEstimator = "causal_testing.estimation.linear_regression_estimator:IPCWEstimator" +LogisticRegressionEstimator = "causal_testing.estimation.logistic_regression_estimator:LogisticRegressionEstimator" +CubicSplineEstimator = "causal_testing.estimation.cubic_spline_estimator:CubicSplineEstimator" +InstrumentalVariableEstimator = "causal_testing.estimation.instrumental_variable_estimator:InstrumentalVariableEstimator" +IPCWEstimator = "causal_testing.estimation.ipcw_estimator:IPCWEstimator" [project.entry-points."causal_effects"] NoEffect = "causal_testing.testing.causal_effect:NoEffect" diff --git a/tests/main_tests/test_main.py b/tests/main_tests/test_main.py index 7d7cfb1a..14d92f88 100644 --- a/tests/main_tests/test_main.py +++ b/tests/main_tests/test_main.py @@ -144,6 +144,36 @@ def test_create_test_case_invalid_effect(self): str(e.exception), ) + def test_create_test_case_effect_kwargs(self): + framework = CausalTestingFramework(self.paths) + framework.setup() + test = { + "name": "test1", + "treatment_variable": "test_input", + "estimator": "LinearRegressionEstimator", + "estimate_type": "coefficient", + "expected_effect": {"test_output": "ExactValue"}, + "effect_kwargs": {"value": 4}, + } + base_test_case = framework.create_base_test(test) + test_case = framework.create_causal_test(test, base_test_case) + self.assertEqual(test_case.expected_causal_effect.value, 4) + + def test_create_test_case_estimator_kwargs(self): + framework = CausalTestingFramework(self.paths) + framework.setup() + test = { + "name": "test1", + "treatment_variable": "test_input", + "estimator": "InstrumentalVariableEstimator", + "estimate_type": "coefficient", + "expected_effect": {"test_output": "SomeEffect"}, + "estimator_kwargs": {"instrument": "instrumental_variable"}, + } + base_test_case = framework.create_base_test(test) + test_case = framework.create_causal_test(test, base_test_case) + self.assertEqual(test_case.estimator.instrument, "instrumental_variable") + def test_create_base_test_case_missing_outcome(self): framework = CausalTestingFramework(self.paths) framework.setup() diff --git a/tests/resources/data/tests.json b/tests/resources/data/tests.json index 1a92cb2d..8354a71a 100644 --- a/tests/resources/data/tests.json +++ b/tests/resources/data/tests.json @@ -4,19 +4,23 @@ "treatment_variable": "test_input", "estimator": "LinearRegressionEstimator", "estimate_type": "coefficient", - "effect_modifiers": [], - "query": "test_input > 0", "expected_effect": {"test_output": "NoEffect"}, - "skip": false + "skip": false, + "estimator_kwargs": { + "query": "test_input > 0", + "effect_modifiers": [] + } }, { "name": "test2", "treatment_variable": "test_input", "estimator": "LinearRegressionEstimator", "estimate_type": "coefficient", - "effect_modifiers": [], - "query": "test_input <= 5", "expected_effect": {"test_output": "NoEffect"}, - "skip": true + "skip": true, + "estimator_kwargs": { + "effect_modifiers": [], + "query": "test_input <= 5" + } }] } From fb199ac0efc4ec8ffc959f56fe1f9699e626b12f Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 14:49:21 +0100 Subject: [PATCH 06/13] MetamorphicRelation estimator validation and testing --- .../testing/metamorphic_relation.py | 6 ++++++ .../test_metamorphic_relations.py | 19 +++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/causal_testing/testing/metamorphic_relation.py b/causal_testing/testing/metamorphic_relation.py index 942de00a..10300136 100644 --- a/causal_testing/testing/metamorphic_relation.py +++ b/causal_testing/testing/metamorphic_relation.py @@ -49,6 +49,12 @@ def to_json_stub( :param estimator: The name of the estimator class to use when evaluating the test :param alpha: The significance level to use when calculating the confidence intervals """ + if estimator not in ["LinearRegressionEstimator", "LogisticRegressionEstimator"]: + raise ValueError( + f"Unsupported estimator {estimator}. " + "We only support autogeneration using LinearRegressionEstimator or LogisticRegressionEstimator." + "More advanced estimators require careful thought that cannot be easily automated." + ) return { "name": str(self), "estimator": estimator, diff --git a/tests/testing_tests/test_metamorphic_relations.py b/tests/testing_tests/test_metamorphic_relations.py index 709a891b..7c88a775 100644 --- a/tests/testing_tests/test_metamorphic_relations.py +++ b/tests/testing_tests/test_metamorphic_relations.py @@ -41,6 +41,17 @@ def setUp(self) -> None: def tearDown(self) -> None: shutil.rmtree(self.temp_dir_path) + def test_json_stub_invalid_estimator(self): + """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program + and there is only a single input.""" + causal_dag = CausalDAG(self.dag_dot_path) + causal_dag.remove_nodes_from(["X2", "X3"]) + adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) + should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set) + with self.assertRaises(ValueError) as e: + should_not_cause_mr.to_json_stub(estimator="InvalidEstimator") + self.assertTrue(e.exception.startswith("Unsupported estimator estimator InvalidEstimator.")) + def test_should_not_cause_json_stub(self): """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program and there is only a single input.""" @@ -57,7 +68,7 @@ def test_should_not_cause_json_stub(self): "expected_effect": {"Z": "NoEffect"}, "treatment_variable": "X1", "name": "X1 _||_ Z", - "formula": "Z ~ X1", + "estimator_kwargs": {"formula": "Z ~ X1"}, "alpha": 0.05, "skip": False, }, @@ -81,7 +92,7 @@ def test_should_not_cause_logistic_json_stub(self): "expected_effect": {"Z": "NoEffect"}, "treatment_variable": "X1", "name": "X1 _||_ Z", - "formula": "Z ~ X1", + "estimator_kwargs": {"formula": "Z ~ X1"}, "alpha": 0.05, "skip": False, }, @@ -101,7 +112,7 @@ def test_should_cause_json_stub(self): "estimate_type": "coefficient", "estimator": "LinearRegressionEstimator", "expected_effect": {"Z": "SomeEffect"}, - "formula": "Z ~ X1", + "estimator_kwargs": {"formula": "Z ~ X1"}, "treatment_variable": "X1", "name": "X1 --> Z", "alpha": 0.05, @@ -128,7 +139,7 @@ def test_should_cause_logistic_json_stub(self): "estimate_type": "unit_odds_ratio", "estimator": "LogisticRegressionEstimator", "expected_effect": {"Z": "SomeEffect"}, - "formula": "Z ~ X1", + "estimator_kwargs": {"formula": "Z ~ X1"}, "treatment_variable": "X1", "name": "X1 --> Z", "alpha": 0.05, From 09be6766577e7b1e0d2a8b031f60111430933ff5 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 14:56:00 +0100 Subject: [PATCH 07/13] Docs --- docs/source/modules/causal_testing.rst | 2 +- .../vaccinating_elderly_tutorial.ipynb | 41 ++++++++++--------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/docs/source/modules/causal_testing.rst b/docs/source/modules/causal_testing.rst index 18b180a9..3d4e4e6d 100644 --- a/docs/source/modules/causal_testing.rst +++ b/docs/source/modules/causal_testing.rst @@ -34,7 +34,7 @@ In the following sections, we describe the end-to-end process of ``causal testin In particular, suppose we're interested in how various precautions, such as hand-washing and mask-wearing, can prevent the spread of a virus within a classroom. 1. Modelling Scenario ----------------- +--------------------- For our modelling scenario, suppose we define the scenario with the following constraints: diff --git a/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb b/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb index 1188c222..d3bd8420 100644 --- a/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb +++ b/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb @@ -98,7 +98,9 @@ " \"expected_effect\": {\n", " \"cum_vaccinations\": \"NoEffect\"\n", " },\n", - " \"formula\": \"cum_vaccinations ~ max_doses\",\n", + " \"estimator_kwargs\": {\n", + " \"formula\": \"cum_vaccinations ~ max_doses\",\n", + " },\n", " \"alpha\": 0.05,\n", " \"skip\": false\n", " },\n", @@ -119,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "e80a3064-9f5f-446e-b443-e7dd1d54d6a4", "metadata": {}, "outputs": [ @@ -127,20 +129,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "Namespace(command='test', dag_path='dag.dot', output='causal_test_results.json', ignore_cycles=False, data_paths=['simulated_data.csv'], test_config='causal_tests.json', verbose=False, query=None, adequacy=False, bootstrap_size=None, silent=False, batch_size=0)\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Setting up Causal Testing Framework...\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Loading DAG from dag.dot\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - DAG loaded with 5 nodes and 3 edges\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Loading data from 1 source(s)\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Initial data shape: (60, 16)\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Setup completed successfully\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Loading test configurations from causal_tests.json\n", - "2025-12-12 11:21:57 - root - INFO - Running tests in regular mode\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Running causal tests...\n", - "100%|████████████████████████████████████████████| 9/9 [00:00<00:00, 358.51it/s]\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Saving results to causal_test_results.json\n", - "2025-12-12 11:21:57 - causal_testing.main - INFO - Results saved successfully\n", - "2025-12-12 11:21:57 - root - INFO - Causal testing completed successfully.\n" + "2026-04-29 14:55:38 - causal_testing.main - INFO - Setting up Causal Testing Framework...\n", + "2026-04-29 14:55:38 - causal_testing.main - INFO - Loading DAG from dag.dot\n", + "2026-04-29 14:55:38 - causal_testing.main - INFO - DAG loaded with 5 nodes and 3 edges\n", + "2026-04-29 14:55:38 - causal_testing.main - INFO - Loading data from 1 source(s)\n", + "2026-04-29 14:55:38 - causal_testing.main - INFO - Initial data shape: (60, 16)\n", + "2026-04-29 14:55:38 - causal_testing.main - INFO - Setup completed successfully\n", + "2026-04-29 14:55:38 - causal_testing.main - INFO - Loading test configurations from causal_tests.json\n", + "2026-04-29 14:55:39 - root - INFO - Running tests in regular mode\n", + "2026-04-29 14:55:39 - causal_testing.main - INFO - Running causal tests...\n", + "100%|████████████████████████████████████████████| 9/9 [00:00<00:00, 741.98it/s]\n", + "2026-04-29 14:55:39 - causal_testing.main - INFO - Saving results to causal_test_results.json\n", + "2026-04-29 14:55:39 - causal_testing.main - INFO - Results saved successfully\n", + "2026-04-29 14:55:39 - root - INFO - Causal testing completed successfully.\n" ] } ], @@ -172,7 +173,9 @@ " \"expected_effect\": {\n", " \"cum_vaccinations\": \"NoEffect\"\n", " },\n", - " \"formula\": \"cum_vaccinations ~ max_doses\",\n", + " \"estimator_kwargs\": {\n", + " \"formula\": \"cum_vaccinations ~ max_doses\",\n", + " },\n", " \"alpha\": 0.05,\n", " \"skip\": false,\n", " \"passed\": false,\n", @@ -232,7 +235,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3 (CI)", "language": "python", "name": "python3" }, @@ -246,7 +249,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.15" } }, "nbformat": 4, From 145819291538b9184f9ec11e6eb4f2b4c8aaf444 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 15:04:05 +0100 Subject: [PATCH 08/13] Pylint imports --- causal_testing/main.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/causal_testing/main.py b/causal_testing/main.py index 60ce2bb7..9fa3fd55 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -13,13 +13,10 @@ import pandas as pd from tqdm import tqdm -from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator -from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator from causal_testing.specification.causal_dag import CausalDAG from causal_testing.specification.scenario import Scenario from causal_testing.specification.variable import Input, Output from causal_testing.testing.base_test_case import BaseTestCase -from causal_testing.testing.causal_effect import Negative, NoEffect, Positive, SomeEffect from causal_testing.testing.causal_test_adequacy import DataAdequacy from causal_testing.testing.causal_test_case import CausalTestCase from causal_testing.testing.causal_test_result import CausalTestResult @@ -302,8 +299,8 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC if effect_type not in effect_map: raise ValueError( f"Unsupported causal effect {effect_type}. Supported: {sorted(effect_map)}. " - "If you have implemented a custom causal effect, you will need to add this to your entrypoints via your " - "pyproject.toml file." + "If you have implemented a custom causal effect, you will need to add this to your entrypoints via " + "your pyproject.toml file." ) expected_effect = effect_map[effect_type].load()(**test.get("effect_kwargs", {})) From d204e2f5137968475ede8be3f0c8171b270680fe Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 29 Apr 2026 15:25:41 +0100 Subject: [PATCH 09/13] Added a top level `causal-testing` entrypoint --- README.md | 4 ++-- causal_testing/testing/metamorphic_relation.py | 2 +- dafni/entrypoint.sh | 6 +++--- docs/source/installation.rst | 4 ++-- .../vaccinating_elderly/vaccinating_elderly_tutorial.ipynb | 4 ++-- examples/poisson-line-process/README.md | 2 +- pyproject.toml | 3 +++ 7 files changed, 14 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 4805a6b4..96d5ae81 100644 --- a/README.md +++ b/README.md @@ -89,12 +89,12 @@ For more information on how to use the Causal Testing Framework, please refer to 2. If you do not already have causal test cases, you can convert your causal DAG to causal tests by running the following command. ``` -python -m causal_testing generate --dag-path $PATH_TO_DAG --output $PATH_TO_TESTS +causal-testing generate --dag-path $PATH_TO_DAG --output $PATH_TO_TESTS ``` 3. You can now execute your tests by running the following command. ``` -python -m causal_testing test --dag-path $PATH_TO_DAG --data-paths $PATH_TO_DATA --test-config $PATH_TO_TESTS --output $OUTPUT +causal-testing test --dag-path $PATH_TO_DAG --data-paths $PATH_TO_DATA --test-config $PATH_TO_TESTS --output $OUTPUT ``` The results will be saved for inspection in a JSON file located at `$OUTPUT`. In the future, we hope to add a visualisation tool to assist with this. diff --git a/causal_testing/testing/metamorphic_relation.py b/causal_testing/testing/metamorphic_relation.py index 10300136..1b6b5c84 100644 --- a/causal_testing/testing/metamorphic_relation.py +++ b/causal_testing/testing/metamorphic_relation.py @@ -279,7 +279,7 @@ def generate_causal_tests( logger.warning( "The skip parameter is hard-coded to False during test generation for better integration with the " - "causal testing component (python -m causal_testing test ...)" + "causal testing component (causal-testing test ...)" "Please carefully review the generated tests and decide which to skip." ) diff --git a/dafni/entrypoint.sh b/dafni/entrypoint.sh index 502dfea2..71dfde8e 100644 --- a/dafni/entrypoint.sh +++ b/dafni/entrypoint.sh @@ -74,7 +74,7 @@ if [ "$EXECUTION_MODE" = "generate" ]; then echo "Running causal_testing GENERATE..." echo "Will write causal tests to: $CAUSAL_TESTS_OUTPUT_PATH" - python -m causal_testing generate \ + causal-testing generate \ -D "$DAG_PATH" \ -o "$CAUSAL_TESTS_OUTPUT_PATH" \ -e "$ESTIMATOR" \ @@ -107,7 +107,7 @@ elif [ "$EXECUTION_MODE" = "test" ]; then # Build command with adequacy flags only when ADEQUACY is true if [ "$ADEQUACY" = "true" ]; then echo "DEBUG: Executing WITH adequacy flags" - python -m causal_testing test \ + causal-testing test \ -D "$DAG_PATH" \ -d $DATA_PATHS \ -t "$CAUSAL_TESTS_INPUT_PATH" \ @@ -120,7 +120,7 @@ elif [ "$EXECUTION_MODE" = "test" ]; then $([ "$BATCH_SIZE" != "0" ] && echo "--batch-size $BATCH_SIZE") else echo "DEBUG: Executing WITHOUT adequacy flags" - python -m causal_testing test \ + causal-testing test \ -D "$DAG_PATH" \ -d $DATA_PATHS \ -t "$CAUSAL_TESTS_INPUT_PATH" \ diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 74e5ee6a..139398b5 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -95,11 +95,11 @@ Next Steps * Read about :doc:`modules/causal_specification` to understand causal specifications and :doc:`modules/causal_testing` for the end-to-end causal testing process. * Run the command for guidance on how to generate your causal tests directly from your input DAG:: - python -m causal_testing generate --help + causal-testing generate --help * and the command on guidance on how to execute your causal tests:: - python -m causal_testing test --help + causal-testing test --help Using the CTF on DAFNI diff --git a/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb b/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb index d3bd8420..717d8131 100644 --- a/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb +++ b/docs/source/tutorials/vaccinating_elderly/vaccinating_elderly_tutorial.ipynb @@ -48,7 +48,7 @@ "**Note:** If you haven't created your own causal tests, it's possible to utilise the CTF to automatically generate tests based on your input DAG using the following command:\n", "\n", "```python\n", - "python -m causal_testing generate --dag_path dag.dot --output_path causal_tests.json\n", + "causal-testing generate --dag_path dag.dot --output_path causal_tests.json\n", "```" ] }, @@ -146,7 +146,7 @@ } ], "source": [ - "!python -m causal_testing test --data-paths simulated_data.csv --dag-path dag.dot --test-config causal_tests.json --output causal_test_results.json" + "!causal-testing test --data-paths simulated_data.csv --dag-path dag.dot --test-config causal_tests.json --output causal_test_results.json" ] }, { diff --git a/examples/poisson-line-process/README.md b/examples/poisson-line-process/README.md index e50ff26f..8c725bbe 100644 --- a/examples/poisson-line-process/README.md +++ b/examples/poisson-line-process/README.md @@ -14,5 +14,5 @@ This should print a series of causal test results and produce two CSV files. `in You should be able to run the main entrypoint by simply running the following command from within this directory: ``` -python -m causal_testing --dag_path dag.dot --data_paths data/random/data_random_1000.csv --test_config causal_tests.json --output results/test_results.json +causal-testing --dag_path dag.dot --data_paths data/random/data_random_1000.csv --test_config causal_tests.json --output results/test_results.json ``` diff --git a/pyproject.toml b/pyproject.toml index b876ff83..d02dffb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,9 @@ Documentation = "https://causal-testing-framework.readthedocs.io/" Repository = "https://github.com/CITCOM-project/CausalTestingFramework" Issues = "https://github.com/CITCOM-project/CausalTestingFramework/issues" +[project.scripts] +causal-testing = "causal_testing.__main__:main" + [project.entry-points."estimators"] LinearRegressionEstimator = "causal_testing.estimation.linear_regression_estimator:LinearRegressionEstimator" LogisticRegressionEstimator = "causal_testing.estimation.logistic_regression_estimator:LogisticRegressionEstimator" From 8c99927f2865b4655a6bb72d6f4811a39fd4c3e9 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Thu, 30 Apr 2026 10:58:17 +0100 Subject: [PATCH 10/13] Added custom estimation docs --- docs/source/modules/custom_estimators.rst | 62 +++++++++++++++++++ docs/source/modules/estimators.rst | 2 + .../example_pure_python.py | 40 ++++++------ 3 files changed, 86 insertions(+), 18 deletions(-) create mode 100644 docs/source/modules/custom_estimators.rst diff --git a/docs/source/modules/custom_estimators.rst b/docs/source/modules/custom_estimators.rst new file mode 100644 index 00000000..009e2fea --- /dev/null +++ b/docs/source/modules/custom_estimators.rst @@ -0,0 +1,62 @@ +Custom Estimators +================= + +If the supported :ref:`estimators` are not sufficient for your needs, you can implement your own custom estimator by extending the :code:`Estimator` class and implementing the abstract :code:`add_modelling_assumptions` method and the estimation method for the causal effect measure you wish to calculate. +For example, if you wished to estimate the ATE using the empirical mean of the recorded outcome under the control and treatment values, you would need to implement a method called :code:`estimate_ate`. +If you wished to estimate the risk ratio, you would need to call your method :code:`estimate_risk_ratio`. +The code for the :code:`EmpiricalMeanEstimator` is shown below. + +.. code-block:: python + + from causal_testing.estimation.abstract_estimator import Estimator + from scipy.stats import bootstrap + + class EmpiricalMeanEstimator(Estimator): + """ + Custom estimator class to estimate the causal effect based on the empirical mean. + """ + + def add_modelling_assumptions(self): + """ + Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that + must hold if the resulting causal inference is to be considered valid. + """ + self.modelling_assumptions += "The data must contain runs with the exact configuration of interest." + + def estimate_ate(self) -> EffectEstimate: + """Estimate the outcomes under control and treatment. + :return: The empirical average treatment effect. + """ + treatment_variable = self.base_test_case.treatment_variable.name + outcome_variable = self.base_test_case.outcome_variable.name + + control_results = self.df.where(self.df[treatment_variable] == self.control_value)[outcome_variable].dropna() + treatment_results = self.df.where(self.df[treatment_variable] == self.treatment_value)[ + outcome_variable + ].dropna() + + def risk_ratio(sample1, sample2): + return sample1.mean() - sample2.mean() + + bootstraps = bootstrap((treatment_results, control_results), risk_ratio, confidence_level=self.alpha) + return EffectEstimate( + type="risk_ratio", + value=risk_ratio(treatment_results, control_results), + ci_low=bootstraps.confidence_interval.low, + ci_high=bootstraps.confidence_interval.high, + ) + +Once you have implemented your estimator, you will need to register it as an extra entry point in your :code:`pyproject.toml` file so that the Causal Testing Framework can find it. +For example, if you had defined your :code:`EmpiricalMeanEstimator` class in a module called :code:`empirical_mean_estimator` in a folder called :code:`custom_estimators`, you would register it as follows. + +.. code-block:: ini + + [project.entry-points."estimators"] + CustomFlakefighter = "custom_estimators.empirical_mean_estimator:EmpiricalMeanEstimator" + +Of course, for this to work, your module needs to be discoverable on your python path. +That is, you should be able to execute :code:`from custom_estimators.empirical_mean_estimator import EmpiricalMeanEstimator` successfully from within the current working directory. + +You can also add your custom estimator to causal test cases specified in JSON. +To do so, you can simply set the :code:`estimator` property to the name of your estimator class and the :code:`estimate_type` property to the name of your causal effect measure. +In the above :code:`EmpiricalMeanEstimator` example, :code:`estimator` would be set to :code:`"EmpiricalMeanEstimator"` and :code:`estimate_type` would be set to :code:`"ate"`. diff --git a/docs/source/modules/estimators.rst b/docs/source/modules/estimators.rst index fc072113..22bd5955 100644 --- a/docs/source/modules/estimators.rst +++ b/docs/source/modules/estimators.rst @@ -1,3 +1,5 @@ +.. _estimators: + Estimators Overview =================== diff --git a/examples/poisson-line-process/example_pure_python.py b/examples/poisson-line-process/example_pure_python.py index 4bbaa878..544ae8c3 100644 --- a/examples/poisson-line-process/example_pure_python.py +++ b/examples/poisson-line-process/example_pure_python.py @@ -2,6 +2,7 @@ import logging import pandas as pd +from scipy.stats import bootstrap from causal_testing.specification.causal_dag import CausalDAG from causal_testing.specification.scenario import Scenario @@ -19,6 +20,10 @@ class EmpiricalMeanEstimator(Estimator): + """ + Custom estimator class to estimate the causal effect based on the empirical mean. + """ + def add_modelling_assumptions(self): """ Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that @@ -26,29 +31,28 @@ def add_modelling_assumptions(self): """ self.modelling_assumptions += "The data must contain runs with the exact configuration of interest." - def estimate_ate(self) -> EffectEstimate: + def estimate_risk_ratio(self) -> EffectEstimate: """Estimate the outcomes under control and treatment. :return: The empirical average treatment effect. """ - control_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.control_value)[ - self.base_test_case.outcome_variable.name - ].dropna() - treatment_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.treatment_value)[ - self.base_test_case.outcome_variable.name - ].dropna() - return EffectEstimate("ate", treatment_results.mean() - control_results.mean()) + treatment_variable = self.base_test_case.treatment_variable.name + outcome_variable = self.base_test_case.outcome_variable.name - def estimate_risk_ratio(self) -> float: - """Estimate the outcomes under control and treatment. - :return: The empirical average treatment effect. - """ - control_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.control_value)[ - self.base_test_case.outcome_variable.name - ].dropna() - treatment_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.treatment_value)[ - self.base_test_case.outcome_variable.name + control_results = self.df.where(self.df[treatment_variable] == self.control_value)[outcome_variable].dropna() + treatment_results = self.df.where(self.df[treatment_variable] == self.treatment_value)[ + outcome_variable ].dropna() - return EffectEstimate("risk_ratio", treatment_results.mean() / control_results.mean()) + + def risk_ratio(sample1, sample2): + return sample1.mean() / sample2.mean() + + bootstraps = bootstrap((treatment_results, control_results), risk_ratio, confidence_level=self.alpha) + return EffectEstimate( + type="risk_ratio", + value=risk_ratio(treatment_results, control_results), + ci_low=bootstraps.confidence_interval.low, + ci_high=bootstraps.confidence_interval.high, + ) # 1. Read in the Causal DAG From 161c96494b367ed6ca96688cba44ad798cb6f39c Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Thu, 30 Apr 2026 11:02:15 +0100 Subject: [PATCH 11/13] Added custom estimation docs --- docs/source/index.rst | 1 + docs/source/modules/custom_estimators.rst | 62 +++++++++++++++++++ docs/source/modules/estimators.rst | 2 + .../example_pure_python.py | 40 ++++++------ 4 files changed, 87 insertions(+), 18 deletions(-) create mode 100644 docs/source/modules/custom_estimators.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 530c4042..bd7a8b51 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -41,6 +41,7 @@ If you have any questions about our framework, you can also reach us by `email < /modules/causal_specification /modules/estimators + /modules/custom_estimators /modules/causal_testing .. toctree:: diff --git a/docs/source/modules/custom_estimators.rst b/docs/source/modules/custom_estimators.rst new file mode 100644 index 00000000..009e2fea --- /dev/null +++ b/docs/source/modules/custom_estimators.rst @@ -0,0 +1,62 @@ +Custom Estimators +================= + +If the supported :ref:`estimators` are not sufficient for your needs, you can implement your own custom estimator by extending the :code:`Estimator` class and implementing the abstract :code:`add_modelling_assumptions` method and the estimation method for the causal effect measure you wish to calculate. +For example, if you wished to estimate the ATE using the empirical mean of the recorded outcome under the control and treatment values, you would need to implement a method called :code:`estimate_ate`. +If you wished to estimate the risk ratio, you would need to call your method :code:`estimate_risk_ratio`. +The code for the :code:`EmpiricalMeanEstimator` is shown below. + +.. code-block:: python + + from causal_testing.estimation.abstract_estimator import Estimator + from scipy.stats import bootstrap + + class EmpiricalMeanEstimator(Estimator): + """ + Custom estimator class to estimate the causal effect based on the empirical mean. + """ + + def add_modelling_assumptions(self): + """ + Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that + must hold if the resulting causal inference is to be considered valid. + """ + self.modelling_assumptions += "The data must contain runs with the exact configuration of interest." + + def estimate_ate(self) -> EffectEstimate: + """Estimate the outcomes under control and treatment. + :return: The empirical average treatment effect. + """ + treatment_variable = self.base_test_case.treatment_variable.name + outcome_variable = self.base_test_case.outcome_variable.name + + control_results = self.df.where(self.df[treatment_variable] == self.control_value)[outcome_variable].dropna() + treatment_results = self.df.where(self.df[treatment_variable] == self.treatment_value)[ + outcome_variable + ].dropna() + + def risk_ratio(sample1, sample2): + return sample1.mean() - sample2.mean() + + bootstraps = bootstrap((treatment_results, control_results), risk_ratio, confidence_level=self.alpha) + return EffectEstimate( + type="risk_ratio", + value=risk_ratio(treatment_results, control_results), + ci_low=bootstraps.confidence_interval.low, + ci_high=bootstraps.confidence_interval.high, + ) + +Once you have implemented your estimator, you will need to register it as an extra entry point in your :code:`pyproject.toml` file so that the Causal Testing Framework can find it. +For example, if you had defined your :code:`EmpiricalMeanEstimator` class in a module called :code:`empirical_mean_estimator` in a folder called :code:`custom_estimators`, you would register it as follows. + +.. code-block:: ini + + [project.entry-points."estimators"] + CustomFlakefighter = "custom_estimators.empirical_mean_estimator:EmpiricalMeanEstimator" + +Of course, for this to work, your module needs to be discoverable on your python path. +That is, you should be able to execute :code:`from custom_estimators.empirical_mean_estimator import EmpiricalMeanEstimator` successfully from within the current working directory. + +You can also add your custom estimator to causal test cases specified in JSON. +To do so, you can simply set the :code:`estimator` property to the name of your estimator class and the :code:`estimate_type` property to the name of your causal effect measure. +In the above :code:`EmpiricalMeanEstimator` example, :code:`estimator` would be set to :code:`"EmpiricalMeanEstimator"` and :code:`estimate_type` would be set to :code:`"ate"`. diff --git a/docs/source/modules/estimators.rst b/docs/source/modules/estimators.rst index fc072113..22bd5955 100644 --- a/docs/source/modules/estimators.rst +++ b/docs/source/modules/estimators.rst @@ -1,3 +1,5 @@ +.. _estimators: + Estimators Overview =================== diff --git a/examples/poisson-line-process/example_pure_python.py b/examples/poisson-line-process/example_pure_python.py index 4bbaa878..544ae8c3 100644 --- a/examples/poisson-line-process/example_pure_python.py +++ b/examples/poisson-line-process/example_pure_python.py @@ -2,6 +2,7 @@ import logging import pandas as pd +from scipy.stats import bootstrap from causal_testing.specification.causal_dag import CausalDAG from causal_testing.specification.scenario import Scenario @@ -19,6 +20,10 @@ class EmpiricalMeanEstimator(Estimator): + """ + Custom estimator class to estimate the causal effect based on the empirical mean. + """ + def add_modelling_assumptions(self): """ Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that @@ -26,29 +31,28 @@ def add_modelling_assumptions(self): """ self.modelling_assumptions += "The data must contain runs with the exact configuration of interest." - def estimate_ate(self) -> EffectEstimate: + def estimate_risk_ratio(self) -> EffectEstimate: """Estimate the outcomes under control and treatment. :return: The empirical average treatment effect. """ - control_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.control_value)[ - self.base_test_case.outcome_variable.name - ].dropna() - treatment_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.treatment_value)[ - self.base_test_case.outcome_variable.name - ].dropna() - return EffectEstimate("ate", treatment_results.mean() - control_results.mean()) + treatment_variable = self.base_test_case.treatment_variable.name + outcome_variable = self.base_test_case.outcome_variable.name - def estimate_risk_ratio(self) -> float: - """Estimate the outcomes under control and treatment. - :return: The empirical average treatment effect. - """ - control_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.control_value)[ - self.base_test_case.outcome_variable.name - ].dropna() - treatment_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.treatment_value)[ - self.base_test_case.outcome_variable.name + control_results = self.df.where(self.df[treatment_variable] == self.control_value)[outcome_variable].dropna() + treatment_results = self.df.where(self.df[treatment_variable] == self.treatment_value)[ + outcome_variable ].dropna() - return EffectEstimate("risk_ratio", treatment_results.mean() / control_results.mean()) + + def risk_ratio(sample1, sample2): + return sample1.mean() / sample2.mean() + + bootstraps = bootstrap((treatment_results, control_results), risk_ratio, confidence_level=self.alpha) + return EffectEstimate( + type="risk_ratio", + value=risk_ratio(treatment_results, control_results), + ci_low=bootstraps.confidence_interval.low, + ci_high=bootstraps.confidence_interval.high, + ) # 1. Read in the Causal DAG From b9e1d042a3d4e0ee01b4181e90d5d637acf367e1 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Thu, 30 Apr 2026 11:07:11 +0100 Subject: [PATCH 12/13] Fixed title level misalighment error --- .../poisson_line_process_tutorial.ipynb | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/docs/source/tutorials/poisson_line_process/poisson_line_process_tutorial.ipynb b/docs/source/tutorials/poisson_line_process/poisson_line_process_tutorial.ipynb index 56d2ffc7..353622a7 100644 --- a/docs/source/tutorials/poisson_line_process/poisson_line_process_tutorial.ipynb +++ b/docs/source/tutorials/poisson_line_process/poisson_line_process_tutorial.ipynb @@ -5,7 +5,7 @@ "id": "5adf7cdc-fd96-47a4-a194-f1f060a4c0c5", "metadata": {}, "source": [ - "## Overview" + "# Statistical Metamorphic Testing using the API" ] }, { @@ -26,14 +26,6 @@ "Before diving into the details, a good first step is to define your file paths, including your input configurations:" ] }, - { - "cell_type": "markdown", - "id": "56965fba-b90b-4233-a819-bb747ecd9d81", - "metadata": {}, - "source": [ - "# Statistical Metamorphic Testing using the API" - ] - }, { "cell_type": "code", "execution_count": 1, @@ -841,7 +833,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.14" + "version": "3.11.15" } }, "nbformat": 4, From 99fb9abea839e9ee1cd310dad2d74d4860c589a8 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 1 May 2026 13:54:31 +0100 Subject: [PATCH 13/13] Added endpoint reinstall note --- docs/source/modules/custom_estimators.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/modules/custom_estimators.rst b/docs/source/modules/custom_estimators.rst index 009e2fea..54d470fa 100644 --- a/docs/source/modules/custom_estimators.rst +++ b/docs/source/modules/custom_estimators.rst @@ -46,8 +46,11 @@ The code for the :code:`EmpiricalMeanEstimator` is shown below. ci_high=bootstraps.confidence_interval.high, ) -Once you have implemented your estimator, you will need to register it as an extra entry point in your :code:`pyproject.toml` file so that the Causal Testing Framework can find it. +Once you have implemented your estimator, you will need to register it as an extra entry point in your project's :code:`pyproject.toml` file so that the Causal Testing Framework can find it. For example, if you had defined your :code:`EmpiricalMeanEstimator` class in a module called :code:`empirical_mean_estimator` in a folder called :code:`custom_estimators`, you would register it as follows. +You will also need to reinstall your project, e.g. with :code:`pip install -e .` each time you add a new estimator to your :code:`pyproject.toml`. +You do not need to reinstall each time you edit your project for source code edits. + .. code-block:: ini