From 5af282fe61d3d2c2c4b9fdfe743163542fd621fc Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 10 Mar 2025 09:18:12 +0100 Subject: [PATCH 01/16] fix: Decimal value and float --- execution_engine/converter/criterion.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/execution_engine/converter/criterion.py b/execution_engine/converter/criterion.py index b70b7b98..a88c9c84 100644 --- a/execution_engine/converter/criterion.py +++ b/execution_engine/converter/criterion.py @@ -99,10 +99,14 @@ def parse_value( eps = float(value_numeric) / 1e5 match value.comparator: case "<=" | "<": - value_max = value_numeric - (eps if value.comparator == "<" else 0) + value_max = float(value_numeric) - ( + eps if value.comparator == "<" else 0 + ) value_numeric = None case ">=" | ">": - value_min = value_numeric + (eps if value.comparator == "<" else 0) + value_min = float(value_numeric) + ( + eps if value.comparator == ">" else 0 + ) value_numeric = None case _: raise ValueError(f'Unknown quantity operator: "{value.comparator}"') From 35a582162f04e924fe0420e0af6025c41c466055 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 10 Mar 2025 09:18:46 +0100 Subject: [PATCH 02/16] fix: ValueScalar as float --- execution_engine/util/value/value.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/execution_engine/util/value/value.py b/execution_engine/util/value/value.py index bc6c8f93..72c892c7 100644 --- a/execution_engine/util/value/value.py +++ b/execution_engine/util/value/value.py @@ -314,9 +314,6 @@ class ValueScalar(ValueNumeric[float, None]): unit: None = None - _validate_value = field_validator("value", mode="before")(check_int) - _validate_value_min = field_validator("value_min", mode="before")(check_int) - _validate_value_max = field_validator("value_max", mode="before")(check_int) _validate_no_unit = field_validator("unit", mode="before")(check_unit_none) def supports_units(self) -> bool: From 233b745db5d8ebc4062599b4200e0a63d6824bd4 Mon 
Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 10 Mar 2025 09:19:05 +0100 Subject: [PATCH 03/16] feat: add ICD10CM --- execution_engine/omop/vocabulary.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/execution_engine/omop/vocabulary.py b/execution_engine/omop/vocabulary.py index 8c212ee1..d55156f9 100644 --- a/execution_engine/omop/vocabulary.py +++ b/execution_engine/omop/vocabulary.py @@ -110,6 +110,15 @@ class ICD10GM(AbstractStandardVocabulary): omop_vocab_name = "ICD10GM" +class ICD10CM(AbstractStandardVocabulary): + """ + ICD10 Clinical Modification + """ + + system_uri = "http://hl7.org/fhir/sid/icd-10-cm" + omop_vocab_name = "ICD10CM" + + class UCUM(AbstractStandardVocabulary): """ UCUM vocabulary. @@ -232,6 +241,7 @@ def init(self) -> None: self.register(UCUM) self.register(ATCDE) self.register(ICD10GM) + self.register(ICD10CM) self.register(CODEXCELIDA) def register(self, vocabulary: Type[AbstractVocabulary]) -> None: From b3079ee2e2abc3745c7ef25328bc8ad5fbf998e5 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 10 Mar 2025 09:19:35 +0100 Subject: [PATCH 04/16] fix: too many clients sql error --- execution_engine/converter/parser/fhir_parser_v1.py | 3 +++ execution_engine/omop/sqlclient.py | 6 +++++- execution_engine/task/task.py | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/execution_engine/converter/parser/fhir_parser_v1.py b/execution_engine/converter/parser/fhir_parser_v1.py index 40b21d7d..bd4bd277 100644 --- a/execution_engine/converter/parser/fhir_parser_v1.py +++ b/execution_engine/converter/parser/fhir_parser_v1.py @@ -79,6 +79,9 @@ def parse_time_from_event( new_combo = converter.to_temporal_combination(combo) + if not isinstance(new_combo, CriterionCombination): + raise ValueError(f"Expected CriterionCombination, got {type(new_combo)}") + return new_combo def parse_characteristics(self, ev: EvidenceVariable) -> CriterionCombination: diff --git a/execution_engine/omop/sqlclient.py 
b/execution_engine/omop/sqlclient.py index eace72ae..b5488029 100644 --- a/execution_engine/omop/sqlclient.py +++ b/execution_engine/omop/sqlclient.py @@ -4,7 +4,7 @@ import pandas as pd import sqlalchemy -from sqlalchemy import and_, bindparam, event, func, select, text +from sqlalchemy import NullPool, and_, bindparam, event, func, select, text from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.pool import ConnectionPoolEntry from sqlalchemy.sql import Insert, Select @@ -83,6 +83,7 @@ def __init__( result_schema: str, timezone: str = "Europe/Berlin", disable_triggers: bool = False, + null_pool: bool = False, ) -> None: """Initialize the OMOP SQL client.""" @@ -97,6 +98,9 @@ def __init__( connection_string, connect_args={"options": "-csearch_path={}".format(self._data_schema)}, future=True, + poolclass=( + NullPool if null_pool else None + ), # <--- ensures no persistent pool ) if disable_triggers: diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 0968fbda..36493593 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -27,6 +27,7 @@ def get_engine() -> OMOPSQLClient: return OMOPSQLClient( **get_config().omop.model_dump(by_alias=True), timezone=get_config().timezone, + null_pool=True, ) From 0b26e6a602a73768ef2789fba694ee8c9906f119 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 10 Mar 2025 17:33:32 +0100 Subject: [PATCH 05/16] refactor: remove criterion combination --- .flake8 | 2 +- execution_engine/converter/action/abstract.py | 15 +- .../converter/action/body_positioning.py | 15 +- .../converter/action/drug_administration.py | 50 +- .../converter/action/procedure.py | 8 +- .../converter/action/ventilator_management.py | 6 +- .../converter/characteristic/abstract.py | 20 +- .../characteristic/codeable_concept.py | 14 +- .../converter/characteristic/value.py | 13 +- execution_engine/converter/criterion.py | 26 +- .../converter/goal/assessment_scale.py | 12 +- 
.../converter/goal/laboratory_value.py | 12 +- .../converter/goal/ventilator_management.py | 18 +- execution_engine/converter/parser/base.py | 9 +- .../converter/parser/fhir_parser_v1.py | 144 ++--- .../converter/parser/fhir_parser_v2.py | 32 +- .../converter/recommendation_factory.py | 24 +- .../converter/time_from_event/abstract.py | 37 +- execution_engine/execution_engine.py | 21 +- execution_engine/execution_graph/graph.py | 291 +++-------- execution_engine/omop/cohort/__init__.py | 2 +- .../cohort/population_intervention_pair.py | 332 ++---------- .../omop/cohort/recommendation.py | 143 +++-- execution_engine/omop/criterion/abstract.py | 5 +- .../omop/criterion/combination/combination.py | 296 ----------- .../omop/criterion/combination/logical.py | 292 ----------- .../omop/criterion/combination/temporal.py | 490 ------------------ execution_engine/omop/criterion/factory.py | 29 +- execution_engine/task/creator.py | 3 +- execution_engine/task/runner.py | 7 +- execution_engine/task/task.py | 8 +- .../util/{cohort_logic.py => logic.py} | 177 +++++-- .../converter/action/test_assessment.py | 2 +- .../action/test_drug_administration.py | 4 +- .../converter/test_converter.py | 7 +- .../omop/criterion/test_criterion.py | 10 +- .../util/test_cohort_logic.py | 2 +- 37 files changed, 606 insertions(+), 1972 deletions(-) delete mode 100644 execution_engine/omop/criterion/combination/combination.py delete mode 100644 execution_engine/omop/criterion/combination/logical.py rename execution_engine/util/{cohort_logic.py => logic.py} (82%) diff --git a/.flake8 b/.flake8 index 0314f154..21e3b235 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -ignore = E203, E266, E501, W503, F403, F401, E402, C901, F405 +ignore = E203, E266, E501, W503, F403, F401, E402, C901, F405, E731 max-line-length = 88 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/execution_engine/converter/action/abstract.py b/execution_engine/converter/action/abstract.py index b9466b8b..eaf4ae19 
100644 --- a/execution_engine/converter/action/abstract.py +++ b/execution_engine/converter/action/abstract.py @@ -8,11 +8,8 @@ from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.fhir.util import get_coding from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.vocabulary import AbstractVocabulary -from execution_engine.util import AbstractPrivateMethods +from execution_engine.util import AbstractPrivateMethods, logic from execution_engine.util.types import Timing from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod @@ -144,16 +141,16 @@ def process_timing(cls, timing: FHIRTiming) -> Timing: ) @abstractmethod - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.BaseExpr | None: """Converts this action to a Criterion.""" raise NotImplementedError() @final - def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: + def to_positive_expression(self) -> logic.BaseExpr: """ Converts this action to a criterion. 
""" - action = self._to_criterion() + action = self._to_expression() if action is None: assert ( @@ -161,10 +158,10 @@ def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: ), "Action without explicit criterion must have at least one goal" if self.goals: - criteria = [goal.to_criterion() for goal in self.goals] + criteria = [goal.to_expression() for goal in self.goals] if action is not None: criteria.append(action) - return LogicalCriterionCombination.And(*criteria) + return logic.And(*criteria) else: return action # type: ignore diff --git a/execution_engine/converter/action/body_positioning.py b/execution_engine/converter/action/body_positioning.py index e029b1ea..f0c5ae46 100644 --- a/execution_engine/converter/action/body_positioning.py +++ b/execution_engine/converter/action/body_positioning.py @@ -4,12 +4,9 @@ from execution_engine.converter.criterion import parse_code from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.types import Timing @@ -48,10 +45,12 @@ def from_fhir(cls, action_def: RecommendationPlan.Action) -> Self: return cls(exclude=exclude, code=code, timing=timing) - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" - return ProcedureOccurrence( - concept=self._code, - timing=self._timing, + return logic.Symbol( + criterion=ProcedureOccurrence( + concept=self._code, + timing=self._timing, + ) ) diff --git a/execution_engine/converter/action/drug_administration.py 
b/execution_engine/converter/action/drug_administration.py index b25011aa..c556c528 100644 --- a/execution_engine/converter/action/drug_administration.py +++ b/execution_engine/converter/action/drug_administration.py @@ -15,12 +15,6 @@ ) from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) from execution_engine.omop.criterion.drug_exposure import DrugExposure from execution_engine.omop.criterion.point_in_time import PointInTimeCriterion from execution_engine.omop.vocabulary import ( @@ -28,6 +22,7 @@ VocabularyFactory, standard_vocabulary, ) +from execution_engine.util import logic from execution_engine.util.types import Dosage from execution_engine.util.value import Value, ValueNumber @@ -278,26 +273,30 @@ def filter_same_unit(cls, df: pd.DataFrame, unit: Concept) -> pd.DataFrame: return df_filtered - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.BaseExpr: """ Returns a criterion that represents this action. 
""" - drug_actions: list[Criterion | LogicalCriterionCombination] = [] + drug_actions: list[logic.BaseExpr] = [] if not self._dosages: # no dosages, just return the drug exposure - return DrugExposure( - ingredient_concept=self._ingredient_concept, - dose=None, - route=None, + return logic.Symbol( + criterion=DrugExposure( + ingredient_concept=self._ingredient_concept, + dose=None, + route=None, + ) ) for dosage in self._dosages: - drug_action = DrugExposure( - ingredient_concept=self._ingredient_concept, - dose=dosage["dose"], - route=dosage.get("route", None), + drug_action = logic.Symbol( + criterion=DrugExposure( + ingredient_concept=self._ingredient_concept, + dose=dosage["dose"], + route=dosage.get("route", None), + ) ) extensions = dosage.get("extensions", None) @@ -316,28 +315,29 @@ def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: f"Extension type {extension['type']} not supported yet" ) - ext_criterion = PointInTimeCriterion( - concept=extension["code"], - value=extension["value"], + ext_criterion = logic.Symbol( + criterion=PointInTimeCriterion( + concept=extension["code"], + value=extension["value"], + ) ) # A Conditional Filter returns `right` iff left is POSITIVE, otherwise it returns NEGATIVE # rational: "conditional" extensions are some conditions for dosage, such as body weight ranges. # Thus, the actual drug administration (drug_action, "right") must only be fulfilled if the # condition (ext_criterion, "left") is fulfilled. Thus, we here add this conditional filter. 
- comb = NonCommutativeLogicalCriterionCombination.ConditionalFilter( + comb = logic.ConditionalFilter( left=ext_criterion, right=drug_action, ) drug_actions.append(comb) - result: Criterion | CriterionCombination + result: logic.BaseExpr + if len(drug_actions) == 1: result = drug_actions[0] else: - result = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator("OR"), - ) - result.add_all(drug_actions) + result = logic.Or(*drug_actions) + return result diff --git a/execution_engine/converter/action/procedure.py b/execution_engine/converter/action/procedure.py index 7fb827d5..c8176f13 100644 --- a/execution_engine/converter/action/procedure.py +++ b/execution_engine/converter/action/procedure.py @@ -3,13 +3,11 @@ from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.criterion.observation import Observation from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.types import Timing @@ -57,7 +55,7 @@ def from_fhir(cls, action_def: RecommendationPlan.Action) -> AbstractAction: return cls(exclude=exclude, code=code, timing=timing) - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" criterion: Criterion @@ -89,4 +87,4 @@ def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: f"Concept domain {self._code.domain_id} is not supported for {self.__class__.__name__}]" ) - return criterion + return logic.Symbol(criterion=criterion) diff --git 
a/execution_engine/converter/action/ventilator_management.py b/execution_engine/converter/action/ventilator_management.py index cebc7051..1fa6585d 100644 --- a/execution_engine/converter/action/ventilator_management.py +++ b/execution_engine/converter/action/ventilator_management.py @@ -2,10 +2,6 @@ from execution_engine.converter.action.abstract import AbstractAction from execution_engine.fhir.recommendation import RecommendationPlan -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.vocabulary import SNOMEDCT @@ -29,6 +25,6 @@ def from_fhir(cls, action_def: RecommendationPlan.Action) -> Self: exclude=False, ) # fixme: no way to exclude goals (e.g. "do not ventilate") - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> None: """Converts this characteristic to a Criterion.""" return None diff --git a/execution_engine/converter/characteristic/abstract.py b/execution_engine/converter/characteristic/abstract.py index 899522ce..176d9302 100644 --- a/execution_engine/converter/characteristic/abstract.py +++ b/execution_engine/converter/characteristic/abstract.py @@ -8,6 +8,7 @@ from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.vocabulary import standard_vocabulary +from execution_engine.util import logic as logic from execution_engine.util.value import Value @@ -30,7 +31,7 @@ class AbstractCharacteristic(CriterionConverter, ABC): Subclasses must define the following methods: - valid: returns True if the supplied characteristic falls within the scope of the subclass - from_fhir: creates a new instance of the subclass from a FHIR EvidenceVariable.characteristic element - - to_criterion(): converts the characteristic to a Criterion + - to_expression(): converts the characteristic to a 
Criterion """ _criterion_class: Type[Criterion] @@ -82,6 +83,13 @@ def valid( """Checks if the given FHIR EvidenceVariable is a valid characteristic.""" raise NotImplementedError() + @abstractmethod + def to_positive_expression(self) -> logic.BaseExpr: + """ + Converts this characteristic to a positive expression (i.e. neglecting the exlude flag). + """ + raise NotImplementedError() + @staticmethod def get_standard_concept(cc: Coding) -> Concept: """ @@ -95,13 +103,3 @@ def get_concept(cc: Coding, standard: bool = True) -> Concept: Get the OMOP Standard Vocabulary standard concept for the given code in the given vocabulary. """ return standard_vocabulary.get_concept(cc.system, cc.code, standard=standard) - - @abstractmethod - def to_positive_criterion(self) -> Criterion: - """ - Converts this characteristic to a "Positive" Criterion. - - Positive criterion means that a possible excluded flag is disregarded. Instead, the exclusion - is later introduced (in the to_criterion() method) via a LogicalCriterionCombination.Not). 
- """ - raise NotImplementedError() diff --git a/execution_engine/converter/characteristic/codeable_concept.py b/execution_engine/converter/characteristic/codeable_concept.py index 1dac1b58..57d6bef4 100644 --- a/execution_engine/converter/characteristic/codeable_concept.py +++ b/execution_engine/converter/characteristic/codeable_concept.py @@ -6,9 +6,9 @@ from execution_engine.converter.characteristic.abstract import AbstractCharacteristic from execution_engine.fhir.util import get_coding from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.concept import ConceptCriterion from execution_engine.omop.vocabulary import AbstractVocabulary +from execution_engine.util import logic class AbstractCodeableConceptCharacteristic(AbstractCharacteristic): @@ -78,10 +78,12 @@ def from_fhir( return c - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.BaseExpr: """Converts this characteristic to a Criterion.""" - return self._criterion_class( - concept=self.value, - value=None, - static=self._concept_value_static, + return logic.Symbol( + criterion=self._criterion_class( + concept=self.value, + value=None, + static=self._concept_value_static, + ) ) diff --git a/execution_engine/converter/characteristic/value.py b/execution_engine/converter/characteristic/value.py index 59e8b121..4c84cf58 100644 --- a/execution_engine/converter/characteristic/value.py +++ b/execution_engine/converter/characteristic/value.py @@ -6,6 +6,7 @@ from execution_engine.converter.characteristic.abstract import AbstractCharacteristic from execution_engine.converter.criterion import parse_code, parse_value from execution_engine.omop.criterion.concept import ConceptCriterion +from execution_engine.util import logic class AbstractValueCharacteristic(AbstractCharacteristic, ABC): @@ -32,10 +33,12 @@ def from_fhir( return c - def to_positive_criterion(self) -> 
ConceptCriterion: + def to_positive_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" - return self._criterion_class( - concept=self.type, - value=self.value, - static=self._concept_value_static, + return logic.Symbol( + criterion=self._criterion_class( + concept=self.type, + value=self.value, + static=self._concept_value_static, + ) ) diff --git a/execution_engine/converter/criterion.py b/execution_engine/converter/criterion.py index a88c9c84..9cb43e57 100644 --- a/execution_engine/converter/criterion.py +++ b/execution_engine/converter/criterion.py @@ -1,6 +1,6 @@ import json from abc import ABC, abstractmethod -from typing import Tuple, Type +from typing import Tuple, Type, final from fhir.resources.codeableconcept import CodeableConcept from fhir.resources.element import Element @@ -11,11 +11,8 @@ from execution_engine.fhir.util import get_coding from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.vocabulary import standard_vocabulary +from execution_engine.util import logic as logic from execution_engine.util.value import Value, ValueConcept, ValueNumber from execution_engine.util.value.value import ValueScalar @@ -186,22 +183,23 @@ def valid(cls, fhir_definition: Element) -> bool: raise NotImplementedError() @abstractmethod - def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: + def to_positive_expression(self) -> logic.BaseExpr: """Converts this characteristic to a Criterion or a combination of criteria but no negation.""" raise NotImplementedError() - def to_criterion(self) -> Criterion | LogicalCriterionCombination: + @final + def to_expression(self) -> logic.BaseExpr: """ - Converts this characteristic to a Criterion or a - combination of criteria. 
The result may be a "negative" - criterion, that is the result of to_positive_criterion wrapped - in a LogicalCriterionCombination with operator NOT. + Converts this characteristic to an expression. The result may be a "negative" + criterion, that is the result of to_positive_expression wrapped + in a logic.Not. """ - positive_criterion = self.to_positive_criterion() + positive_expression = self.to_positive_expression() + if self._exclude: - return LogicalCriterionCombination.Not(positive_criterion) + return logic.Not(positive_expression) else: - return positive_criterion + return positive_expression class CriterionConverterFactory: diff --git a/execution_engine/converter/goal/assessment_scale.py b/execution_engine/converter/goal/assessment_scale.py index dd8bfb92..ab23c898 100644 --- a/execution_engine/converter/goal/assessment_scale.py +++ b/execution_engine/converter/goal/assessment_scale.py @@ -4,9 +4,9 @@ from execution_engine.converter.criterion import parse_code_value from execution_engine.converter.goal.abstract import Goal from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.value import Value @@ -47,11 +47,13 @@ def from_fhir(cls, goal: PlanDefinitionGoal) -> "AssessmentScaleGoal": return cls(code.concept_name, exclude=False, code=code, value=value) - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. 
""" - return Measurement( - concept=self._code, - value=self._value, + return logic.Symbol( + criterion=Measurement( + concept=self._code, + value=self._value, + ) ) diff --git a/execution_engine/converter/goal/laboratory_value.py b/execution_engine/converter/goal/laboratory_value.py index a1fd8119..f2e1bcc3 100644 --- a/execution_engine/converter/goal/laboratory_value.py +++ b/execution_engine/converter/goal/laboratory_value.py @@ -4,9 +4,9 @@ from execution_engine.converter.criterion import parse_code_value from execution_engine.converter.goal.abstract import Goal from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.value import Value @@ -46,11 +46,13 @@ def from_fhir(cls, goal: PlanDefinitionGoal) -> "LaboratoryValueGoal": return cls(exclude=False, code=code, value=value) - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. 
""" - return Measurement( - concept=self._code, - value=self._value, + return logic.Symbol( + Measurement( + concept=self._code, + value=self._value, + ) ) diff --git a/execution_engine/converter/goal/ventilator_management.py b/execution_engine/converter/goal/ventilator_management.py index 25b08e11..e7a24001 100644 --- a/execution_engine/converter/goal/ventilator_management.py +++ b/execution_engine/converter/goal/ventilator_management.py @@ -6,12 +6,12 @@ from execution_engine.converter.criterion import parse_code, parse_value from execution_engine.converter.goal.abstract import Goal from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.custom.tidal_volume import ( TidalVolumePerIdealBodyWeight, ) from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.vocabulary import CODEXCELIDA, SNOMEDCT +from execution_engine.util import logic from execution_engine.util.value import Value CUSTOM_GOALS: dict[Concept, Type] = { @@ -56,18 +56,22 @@ def from_fhir(cls, goal: PlanDefinitionGoal) -> "VentilatorManagementGoal": return cls(exclude=False, code=code, value=value) - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. 
""" if self._code in CUSTOM_GOALS: cls = CUSTOM_GOALS[self._code] - return cls( + return logic.Symbol( + cls( + concept=self._code, + value=self._value, + ) + ) + + return logic.Symbol( + Measurement( concept=self._code, value=self._value, ) - - return Measurement( - concept=self._code, - value=self._value, ) diff --git a/execution_engine/converter/parser/base.py b/execution_engine/converter/parser/base.py index d8ac4a7e..8f75aad1 100644 --- a/execution_engine/converter/parser/base.py +++ b/execution_engine/converter/parser/base.py @@ -5,7 +5,7 @@ from execution_engine import fhir from execution_engine.converter.criterion import CriterionConverterFactory from execution_engine.converter.temporal import TemporalIndicatorConverterFactory -from execution_engine.omop.criterion.combination.combination import CriterionCombination +from execution_engine.util import logic as logic class FhirRecommendationParserInterface(ABC): @@ -24,10 +24,9 @@ def __init__( self.time_from_event_converters = time_from_event_converters @abstractmethod - def parse_characteristics(self, ev: EvidenceVariable) -> CriterionCombination: + def parse_characteristics(self, ev: EvidenceVariable) -> logic.BooleanFunction: """ - Parses the EvidenceVariable characteristics and returns either a single Criterion - or a LogicalCriterionCombination + Parses the EvidenceVariable characteristics and returns a BooleanFunction """ raise NotImplementedError() @@ -36,7 +35,7 @@ def parse_actions( self, actions_def: list[fhir.RecommendationPlan.Action], rec_plan: fhir.RecommendationPlan, - ) -> CriterionCombination: + ) -> logic.BooleanFunction: """ Parses the actions of a Recommendation (PlanDefinition) and returns a list of Action objects and the corresponding action selection behavior. 
diff --git a/execution_engine/converter/parser/fhir_parser_v1.py b/execution_engine/converter/parser/fhir_parser_v1.py index bd4bd277..43a145f8 100644 --- a/execution_engine/converter/parser/fhir_parser_v1.py +++ b/execution_engine/converter/parser/fhir_parser_v1.py @@ -1,4 +1,4 @@ -from typing import Union, cast +from typing import Callable, Type, cast from fhir.resources.evidencevariable import ( EvidenceVariable, @@ -12,33 +12,37 @@ from execution_engine.converter.characteristic.abstract import AbstractCharacteristic from execution_engine.converter.goal.abstract import Goal from execution_engine.converter.parser.base import FhirRecommendationParserInterface -from execution_engine.omop.criterion.abstract import AbstractCriterion, Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) +from execution_engine.util import logic as logic -def characteristic_code_to_criterion_combination_operator( +def characteristic_code_to_expression( code: str, threshold: int | None = None -) -> LogicalCriterionCombination.Operator: +) -> Type[logic.BooleanFunction] | Callable: """ Convert a characteristic combination code (from FHIR) to a criterion combination operator (for OMOP). 
""" - mapping = { - "all-of": LogicalCriterionCombination.Operator.AND, - "any-of": LogicalCriterionCombination.Operator.OR, - "at-least": LogicalCriterionCombination.Operator.AT_LEAST, - "at-most": LogicalCriterionCombination.Operator.AT_MOST, - "exactly": LogicalCriterionCombination.Operator.EXACTLY, - "all-or-none": LogicalCriterionCombination.Operator.ALL_OR_NONE, + simple_ops = { + "all-of": logic.And, + "any-of": logic.Or, + "all-or-none": logic.AllOrNone, + } + + count_ops = { + "at-least": logic.MinCount, + "at-most": logic.MaxCount, + "exactly": logic.ExactCount, + "all-or-none": logic.AllOrNone, } - if code not in mapping: - raise NotImplementedError(f"Unknown combination code: {code}") - return LogicalCriterionCombination.Operator( - operator=mapping[code], threshold=threshold - ) + if code in simple_ops: + return simple_ops[code] + + if code in count_ops: + if threshold is None: + raise ValueError(f"Threshold must be set for operator {code}") + return lambda *args, category: count_ops[code](*args, threshold=threshold) + + raise NotImplementedError(f'Combination "{code}" not implemented') class FhirRecommendationParserV1(FhirRecommendationParserInterface): @@ -55,17 +59,17 @@ class FhirRecommendationParserV1(FhirRecommendationParserInterface): def parse_time_from_event( self, tfes: list[EvidenceVariableCharacteristicTimeFromEvent], - combo: CriterionCombination, - ) -> CriterionCombination: + combo: logic.BooleanFunction, + ) -> logic.BooleanFunction: """ - Parses the timeFromEvent elements and updates the CriterionCombination. + Parses the timeFromEvent elements and updates the logic.BooleanFunction. Root element timeFromEvent specifies the time within which the population is valid. Non-root elements act as filters for the criteria they are attached to, meaning only criteria within the timeFromEvent timeframe will be observed to determine if the patient is part of the population. 
Args: tfes (list[EvidenceVariableCharacteristicTimeFromEvent]): List of timeFromEvent elements. - combo (CriterionCombination): The criterion combination to update. + combo (logic.BooleanFunction): The criterion combination to update. Returns: TemporalIndicatorCombination: Updated criterion combination. @@ -79,31 +83,31 @@ def parse_time_from_event( new_combo = converter.to_temporal_combination(combo) - if not isinstance(new_combo, CriterionCombination): - raise ValueError(f"Expected CriterionCombination, got {type(new_combo)}") + if not isinstance(new_combo, logic.BooleanFunction): + raise ValueError(f"Expected BooleanFunction, got {type(new_combo)}") return new_combo - def parse_characteristics(self, ev: EvidenceVariable) -> CriterionCombination: + def parse_characteristics(self, ev: EvidenceVariable) -> logic.BooleanFunction: """ - Parses the EvidenceVariable characteristics and returns either a single Criterion - or a CriterionCombination. + Parses the EvidenceVariable characteristics and returns either a BooleanFunction. Root element timeFromEvent specifies the time within which the population is valid. - Non-root elements act as filters for the criteria they are attached to, meaning only criteria within the timeFromEvent timeframe will be observed to determine if the patient is part of the population. + Non-root elements act as filters for the criteria they are attached to, meaning only criteria within the + timeFromEvent timeframe will be observed to determine if the patient is part of the population. Args: ev (EvidenceVariable): The evidence variable to parse. Returns: - CriterionCombination: The parsed criterion combination. + BooleanFunction: The parsed criterion combination. 
""" def build_criterion( characteristic: EvidenceVariableCharacteristic, is_root: bool - ) -> Union[Criterion, CriterionCombination]: + ) -> logic.BaseExpr: """ - Recursively build Criterion or CriterionCombination from a single + Recursively build Symbol or BooleanFunction from a single EvidenceVariableCharacteristic. Args: @@ -111,26 +115,27 @@ def build_criterion( is_root (bool): Indicates if the characteristic is the root element. Returns: - Union[Criterion, CriterionCombination]: The built criterion or criterion combination. + Union[Symbol, BooleanFunction]: The built criterion or criterion combination. """ - combo: CriterionCombination + combo: logic.BooleanFunction # If this characteristic is itself a combination if characteristic.definitionByCombination is not None: - operator = characteristic_code_to_criterion_combination_operator( + expr = characteristic_code_to_expression( characteristic.definitionByCombination.code, threshold=None, # or parse an actual threshold if needed ) - combo = LogicalCriterionCombination( - operator=operator, - ) + + children = [] for sub_char in characteristic.definitionByCombination.characteristic: - combo.add(build_criterion(sub_char, is_root=False)) + children.append(build_criterion(sub_char, is_root=False)) + + combo = expr(*children) if characteristic.exclude: - combo = LogicalCriterionCombination.Not( + combo = logic.Not( combo, ) @@ -144,10 +149,10 @@ def build_criterion( # Else it's a single characteristic converter = self.characteristics_converters.get(characteristic) converter = cast(AbstractCharacteristic, converter) - crit = converter.to_criterion() + crit = converter.to_expression() - if not isinstance(crit, AbstractCriterion): - raise ValueError(f"Expected AbstractCriterion, got {type(crit)}") + if not isinstance(crit, logic.BaseExpr): + raise ValueError(f"Expected BaseExpr, got {type(crit)}") return crit @@ -158,17 +163,16 @@ def build_criterion( and ev.characteristic[0].definitionByCombination is not None ): 
combo = build_criterion(ev.characteristic[0], is_root=True) - assert isinstance(combo, CriterionCombination) + assert isinstance(combo, logic.BooleanFunction) return combo # Otherwise, gather them under an ALL_OF (AND) combination by default: - combo = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ), - ) + children = [] + for c in ev.characteristic: - combo.add(build_criterion(c, is_root=True)) + children.append(build_criterion(c, is_root=True)) + + combo = logic.And(*children) return combo @@ -176,7 +180,7 @@ def parse_actions( self, actions_def: list[fhir.RecommendationPlan.Action], rec_plan: fhir.RecommendationPlan, - ) -> CriterionCombination: + ) -> logic.BooleanFunction: """ Parses the actions of a Recommendation (PlanDefinition) and returns a list of Action objects and the corresponding action selection behavior. @@ -185,9 +189,9 @@ def parse_actions( def action_to_combination( actions_def: list[fhir.RecommendationPlan.Action], parent: fhir.RecommendationPlan | fhir.RecommendationPlan.Action, - ) -> CriterionCombination: + ) -> logic.BooleanFunction: # loop through PlanDefinition.action elements and find the corresponding Action object (by action.code) - actions: list[Criterion | CriterionCombination] = [] + actions: list[logic.BaseExpr] = [] for action_def in actions_def: # check if is combination of actions @@ -208,23 +212,25 @@ def action_to_combination( goal = cast(Goal, goal) action_conv.goals.append(goal) - actions.append(action_conv.to_criterion()) + actions.append(action_conv.to_expression()) - action_combination = self.parse_action_combination_method(parent.fhir()) + action_combination_expr = self.parse_action_combination_method( + parent.fhir() + ) for action_criterion in actions: - if not isinstance(action_criterion, (CriterionCombination, Criterion)): + if not isinstance( + action_criterion, (logic.Symbol, logic.BooleanFunction) + ): raise ValueError(f"Invalid action 
type: {type(action_criterion)}") - action_combination.add(action_criterion) - - return action_combination + return action_combination_expr(*actions) return action_to_combination(actions_def, rec_plan) def parse_action_combination_method( self, action_parent: PlanDefinition | PlanDefinitionAction - ) -> CriterionCombination: + ) -> Type[logic.BooleanFunction] | Callable: """ Get the correct action combination based on the action selection behavior. """ @@ -237,22 +243,22 @@ def parse_action_combination_method( selection_behavior = selection_behaviors[0] + expr: Type[logic.BooleanFunction] | Callable + match selection_behavior: case "any" | "one-or-more": - operator = LogicalCriterionCombination.Operator("OR") + expr = logic.Or case "all": - operator = LogicalCriterionCombination.Operator("AND") + expr = logic.And case "all-or-none": - operator = LogicalCriterionCombination.Operator("ALL_OR_NONE") + expr = logic.AllOrNone case "exactly-one": - operator = LogicalCriterionCombination.Operator("EXACTLY", threshold=1) + expr = lambda *args: logic.ExactCount(*args, threshold=1) case "at-most-one": - operator = LogicalCriterionCombination.Operator("AT_MOST", threshold=1) + expr = lambda *args: logic.MaxCount(*args, threshold=1) case _: raise NotImplementedError( f"Selection behavior {selection_behavior} not implemented." 
) - return LogicalCriterionCombination( - operator=operator, - ) + return expr diff --git a/execution_engine/converter/parser/fhir_parser_v2.py b/execution_engine/converter/parser/fhir_parser_v2.py index 83d9c41c..86eb0b61 100644 --- a/execution_engine/converter/parser/fhir_parser_v2.py +++ b/execution_engine/converter/parser/fhir_parser_v2.py @@ -1,3 +1,5 @@ +from typing import Callable, Type + from fhir.resources.codeableconcept import CodeableConcept from fhir.resources.plandefinition import PlanDefinition, PlanDefinitionAction @@ -5,9 +7,7 @@ from execution_engine.converter.criterion import get_extension_by_url from execution_engine.converter.parser.fhir_parser_v1 import FhirRecommendationParserV1 from execution_engine.fhir.util import get_coding -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) +from execution_engine.util import logic as logic class FhirRecommendationParserV2(FhirRecommendationParserV1): @@ -20,7 +20,7 @@ class FhirRecommendationParserV2(FhirRecommendationParserV1): def parse_action_combination_method( self, action_parent: PlanDefinition | PlanDefinitionAction - ) -> LogicalCriterionCombination: + ) -> Type[logic.BooleanFunction] | Callable: """ Parses the action combination method from an extension to a PlanDefinition or PlanDefinitionAction. 
""" @@ -41,28 +41,22 @@ def parse_action_combination_method( except ValueError: threshold = None + expr: Type[logic.BooleanFunction] | Callable + match method_code: case "all": - operator = LogicalCriterionCombination.Operator("AND") + expr = logic.And case "any": - operator = LogicalCriterionCombination.Operator("OR") + expr = logic.Or case "at-most": - operator = LogicalCriterionCombination.Operator( - "AT_MOST", threshold=threshold - ) + expr = lambda *args: logic.MaxCount(*args, threshold=threshold) case "exactly": - operator = LogicalCriterionCombination.Operator( - "EXACTLY", threshold=threshold - ) + expr = lambda *args: logic.ExactCount(*args, threshold=threshold) case "at-least": - operator = LogicalCriterionCombination.Operator( - "AT_LEAST", threshold=threshold - ) + expr = lambda *args: logic.MinCount(*args, threshold=threshold) case "one-or-more": - operator = LogicalCriterionCombination.Operator("AT_LEAST", threshold=1) + expr = lambda *args: logic.MinCount(*args, threshold=1) case _: raise ValueError(f"Invalid action combination method: {method_code}") - return LogicalCriterionCombination( - operator=operator, - ) + return expr diff --git a/execution_engine/converter/recommendation_factory.py b/execution_engine/converter/recommendation_factory.py index 15f95372..e80a9ce1 100644 --- a/execution_engine/converter/recommendation_factory.py +++ b/execution_engine/converter/recommendation_factory.py @@ -3,8 +3,11 @@ from execution_engine.converter.parser.factory import FhirRecommendationParserFactory from execution_engine.fhir.client import FHIRClient from execution_engine.omop import cohort -from execution_engine.omop.cohort import PopulationInterventionPair +from execution_engine.omop.cohort.population_intervention_pair import ( + PopulationInterventionPairExpr, +) from execution_engine.omop.criterion.visit_occurrence import PatientsActiveDuringPeriod +from execution_engine.util import logic as logic class FhirToRecommendationFactory: @@ -57,29 +60,30 @@ 
def parse_recommendation_from_url( fhir_connector=fhir_client, ) - pi_pairs: list[PopulationInterventionPair] = [] + pi_pairs: list[PopulationInterventionPairExpr] = [] base_criterion = PatientsActiveDuringPeriod() for rec_plan in rec.plans(): - pi_pair = PopulationInterventionPair( - name=rec_plan.name, - url=rec_plan.url, - base_criterion=base_criterion, - ) # parse population and create criteria population_criteria = parser.parse_characteristics(rec_plan.population) - pi_pair.set_population(population_criteria) # parse intervention and create criteria actions = parser.parse_actions(rec_plan.actions, rec_plan) - pi_pair.add_intervention(actions) + + pi_pair = PopulationInterventionPairExpr( + population_expr=population_criteria, + intervention_expr=actions, + name=rec_plan.name, + url=rec_plan.url, + base_criterion=base_criterion, + ) pi_pairs.append(pi_pair) recommendation = cohort.Recommendation( - pi_pairs, + expr=logic.Or(*pi_pairs), base_criterion=base_criterion, url=rec.url, name=rec.name, diff --git a/execution_engine/converter/time_from_event/abstract.py b/execution_engine/converter/time_from_event/abstract.py index 566f4f7c..5163b48f 100644 --- a/execution_engine/converter/time_from_event/abstract.py +++ b/execution_engine/converter/time_from_event/abstract.py @@ -7,39 +7,32 @@ from execution_engine.converter.criterion import parse_value from execution_engine.fhir.util import get_coding from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - TemporalIndicatorCombination, -) from execution_engine.omop.vocabulary import AbstractVocabulary +from execution_engine.util import logic from execution_engine.util.value import ValueNumeric def _wrap_criteria_with_factory( - combo: CriterionCombination, 
- factory: Callable[[Criterion | CriterionCombination], TemporalIndicatorCombination], -) -> CriterionCombination: + combo: logic.BooleanFunction, + factory: Callable[[logic.BaseExpr], logic.TemporalCount], +) -> logic.BooleanFunction: """ Recursively wraps all Criterion instances within a combination using the factory. """ - # Create a new combination of the same type with the same operator - new_combo = combo.__class__(operator=combo.operator) - + children = [] # Loop through all elements - for element in combo: - if isinstance(element, LogicalCriterionCombination): + for element in combo.args: + if isinstance(element, logic.BooleanFunction): # Recursively wrap nested combinations - new_combo.add(_wrap_criteria_with_factory(element, factory)) + children.append(_wrap_criteria_with_factory(element, factory)) elif isinstance(element, Criterion): # Wrap individual criteria with the factory - new_combo.add(factory(element)) + children.append(factory(element)) else: raise ValueError(f"Unexpected element type: {type(element)}") - return new_combo + # Create a new combination of the same type with the same operator + return combo.__class__(children) class TemporalIndicator(ABC): @@ -62,9 +55,7 @@ def valid(cls, fhir: Element) -> bool: raise NotImplementedError("must be implemented by class") @abstractmethod - def to_temporal_combination( - self, combo: Criterion | CriterionCombination - ) -> CriterionCombination: + def to_temporal_combination(self, combo: logic.BaseExpr) -> logic.TemporalCount: """ Wraps Criterion/CriterionCombinaion with a TemporalIndicatorCombination """ @@ -131,9 +122,7 @@ def valid(cls, fhir: Element) -> bool: return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code @abstractmethod - def to_temporal_combination( - self, combo: Criterion | CriterionCombination - ) -> CriterionCombination: + def to_temporal_combination(self, combo: logic.BaseExpr) -> logic.TemporalCount: """ Wraps Criterion/CriterionCombinaion with a 
TemporalIndicatorCombination """ diff --git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 5af0d781..919c6dac 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -14,7 +14,7 @@ FhirToRecommendationFactory, ) from execution_engine.omop import cohort -from execution_engine.omop.cohort import PopulationInterventionPair +from execution_engine.omop.cohort import PopulationInterventionPairExpr from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.db.celida import tables as result_db from execution_engine.omop.serializable import Serializable @@ -208,8 +208,8 @@ def load_recommendation_from_database( pi_pair, rec_db.recommendation_id ) - for criterion in pi_pair.flatten(): - self.register_criterion(criterion) + for criterion in recommendation.flatten(): + self.register_criterion(criterion) # All objects in the deserialized object graph must have an id. assert recommendation.id is not None @@ -219,8 +219,8 @@ def load_recommendation_from_database( for pi_pair in recommendation.population_intervention_pairs(): assert pi_pair.id is not None - for criterion in pi_pair.flatten(): - assert criterion.id is not None + for criterion in recommendation.flatten(): + assert criterion.id is not None return recommendation @@ -293,8 +293,9 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None self.register_population_intervention_pair( pi_pair, recommendation_id=recommendation.id ) - for criterion in pi_pair.flatten(): - self.register_criterion(criterion) + + for criterion in recommendation.flatten(): + self.register_criterion(criterion) assert recommendation.id is not None assert recommendation.base_criterion is not None @@ -303,8 +304,8 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None for pi_pair in recommendation.population_intervention_pairs(): assert pi_pair.id is not None - for criterion in 
pi_pair.flatten(): - assert criterion.id is not None + for criterion in recommendation.flatten(): + assert criterion.id is not None # Update the recommendation in the database with the final # JSON representation and execution graph (now that @@ -329,7 +330,7 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None con.execute(update_query) def register_population_intervention_pair( - self, pi_pair: PopulationInterventionPair, recommendation_id: int + self, pi_pair: PopulationInterventionPairExpr, recommendation_id: int ) -> None: """ Registers the Population/Intervention Pair in the result database. diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index c012a767..53893d46 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -1,20 +1,11 @@ -from typing import Any, Callable, Type +import copy +from typing import Any import networkx as nx -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - FixedWindowTemporalIndicatorCombination, - PersonalWindowTemporalIndicatorCombination, - TemporalIndicatorCombination, -) class ExecutionGraph(nx.DiGraph): @@ -28,30 +19,6 @@ def __add__(self, other: "ExecutionGraph") -> "ExecutionGraph": """ return nx.compose(self, other) - @classmethod - def from_criterion_combination( - cls, - population: CriterionCombination, - intervention: CriterionCombination, - base_criterion: Criterion, - ) -> "ExecutionGraph": - """ - Create a graph from a population 
and intervention criterion combination. - """ - from execution_engine.omop.cohort import PopulationInterventionPair - - p = cls.combination_to_expression(population, CohortCategory.POPULATION) - i = cls.combination_to_expression(intervention, CohortCategory.INTERVENTION) - i_filtered = PopulationInterventionPair.filter_symbols(i, p) - - pi = logic.LeftDependentToggle( - p, - i_filtered, - category=CohortCategory.POPULATION_INTERVENTION, - ) - - return cls.from_expression(pi, base_criterion) - def is_sink_of_category( self, expr: logic.Expr, graph: "ExecutionGraph", category: CohortCategory ) -> bool: @@ -74,36 +41,43 @@ def is_sink(self, expr: logic.Expr) -> bool: return self.out_degree(expr) == 0 @classmethod - def from_expression( - cls, expr: logic.Expr, base_criterion: Criterion - ) -> "ExecutionGraph": + def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: """ - Create a graph from a cohort query expression. + Filter (=AND-combine) all symbols by the applied filter function + + Used to filter all intervention criteria (symbols) by the population output in order to exclude + all intervention events outside the population intervals, which may otherwise interfere with corrected + determination of temporal combination, i.e. the presence of an intervention event during some time window. 
""" - def expression_to_graph( - expr: logic.Expr, - graph: ExecutionGraph, - parent: logic.Expr | None = None, - ) -> ExecutionGraph: - graph.add_node(expr, category=expr.category, store_result=False) + if isinstance(node, logic.Symbol): + return logic.LeftDependentToggle(left=filter_, right=node) - if expr.is_Atom: - graph.nodes[expr]["store_result"] = True - graph.add_edge(base_node, expr) + if hasattr(node, "args") and isinstance(node.args, tuple): + converted_args = [cls.filter_symbols(a, filter_) for a in node.args] - if parent is not None: - graph.add_edge(expr, parent) + if any(a is not b for a, b in zip(node.args, converted_args)): + node.args = tuple(converted_args) - for child in expr.args: - expression_to_graph(child, graph, expr) + return node - return graph + @classmethod + def from_expression( + cls, expr: logic.Expr, base_criterion: Criterion, category: CohortCategory + ) -> "ExecutionGraph": + """ + Create a graph from a cohort query expression. + """ + + from execution_engine.omop.cohort import PopulationInterventionPairExpr + + # we might make changes to the expression (e.g. filtering), so we must preserve + # the original expression from the caller + expr = copy.deepcopy(expr) graph = cls() base_node = logic.Symbol( criterion=base_criterion, - category=CohortCategory.BASE, ) graph.add_node( base_node, @@ -111,7 +85,54 @@ def expression_to_graph( store_result=True, ) - expression_to_graph(expr, graph=graph) + def traverse( + expr: logic.Expr, + parent: logic.Expr | None = None, + category: CohortCategory = category, + ) -> None: + if isinstance(expr, PopulationInterventionPairExpr): + # special case for PopulationInterventionPairExpr: + # we need explicitly set the category of the population and intervention nodes + + p, i = expr.left, expr.right + + # filter all intervention criteria by the output of the population - this is performed to filter out + # intervention events that outside of the population intervals (i.e. 
the time windows during which + # patients are part of the population) as otherwise events outside of the population time may be picked up + # by Temporal criteria that determine the presence of some event or condition during a specific time window. + + # the following command changes expr, i.e. we must not add expr before this command to the graph + + i = cls.filter_symbols(i, filter_=p) + + traverse(i, parent=expr, category=CohortCategory.INTERVENTION) + traverse(p, parent=expr, category=CohortCategory.POPULATION) + + graph.add_node(expr, category=category, store_result=False) + + if parent is not None: + graph.add_edge(expr, parent) + + subgraph = graph.subgraph(nx.ancestors(graph, expr) | {expr}) + + subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr._id)) + + # children are already traversed + return + + graph.add_node(expr, category=category, store_result=False) + + if parent is not None: + graph.add_edge(expr, parent) + + if expr.is_Atom: + graph.nodes[expr]["store_result"] = True + graph.add_edge(base_node, expr) + else: + for child in expr.args: + traverse(child, parent=expr, category=category) + + traverse(expr, category=category) return graph @@ -149,7 +170,7 @@ def combine_from(cls, *graphs: "ExecutionGraph") -> "ExecutionGraph": return combined_graph - def sink_node(self, category: CohortCategory | None = None) -> logic.Expr: + def sink_nodes(self, category: CohortCategory | None = None) -> list[logic.Expr]: """ Get the sink node of the graph. 
@@ -165,12 +186,7 @@ def sink_node(self, category: CohortCategory | None = None) -> logic.Expr: if self.is_sink_of_category(node, self, category) ] - if len(sink_nodes) != 1: - raise ValueError( - f"There must be exactly one sink node, but there are {len(sink_nodes)}" - ) - - return sink_nodes[0] + return sink_nodes def to_cytoscape_dict(self) -> dict: """ @@ -357,7 +373,7 @@ def set_predecessors_store( if hops_remaining <= 0: return for predecessor in graph.predecessors(expr): - if predecessor.category == expr.category: + if self.nodes[predecessor]["category"] == self.nodes[expr]["category"]: set_predecessors_store(predecessor, graph, hops_remaining - 1) if desired_category is not None: @@ -382,153 +398,6 @@ def set_predecessors_store( for sink_node in sink_nodes_of_desired_category: set_predecessors_store(sink_node, self, hops) - @staticmethod - def combination_to_expression( - comb: CriterionCombination, category: CohortCategory - ) -> logic.Expr: - """ - Convert the CriterionCombination into an expression of And, Not, Or objects (and possibly more operators). - - :param comb: The criterion combination. - :param category: The CohortCategory of the expression. - :return: The expression. - """ - - def conjunction_from_combination( - comb: CriterionCombination, - ) -> Type[logic.BooleanFunction] | Callable: - """ - Convert the criterion's operator into a logical conjunction (Not or And or Or) - """ - if isinstance(comb, LogicalCriterionCombination): - if comb.is_root(): - # This is a hack to make the root node a non-simplifiable And node - otherwise, using the - # logic.And, the root node would be simplified to the criterion if there is only one criterion. - # The problem is that we need a non-criterion sink node of the intervention and population in order - # to store the results to the database without the criterion_id (as the result of the whole - # intervention or population of this population/intervention pair). 
- - if ( - comb.operator.operator - != LogicalCriterionCombination.Operator.AND - ): - raise AssertionError( - f"Invalid operator {comb.operator} for root node. Expected AND." - ) - return logic.NonSimplifiableAnd - - # Handle non-commutative combinations. - if isinstance(comb, NonCommutativeLogicalCriterionCombination): - return logic.ConditionalFilter - - op = comb.operator.operator - - # Mapping of simple logical operators. - simple_ops = { - LogicalCriterionCombination.Operator.NOT: logic.Not, - LogicalCriterionCombination.Operator.AND: logic.And, - LogicalCriterionCombination.Operator.OR: logic.Or, - LogicalCriterionCombination.Operator.ALL_OR_NONE: logic.AllOrNone, - } - if op in simple_ops: - return simple_ops[op] - - # Mapping of count-based operators. - count_ops = { - LogicalCriterionCombination.Operator.AT_LEAST: logic.MinCount, - LogicalCriterionCombination.Operator.AT_MOST: logic.MaxCount, - LogicalCriterionCombination.Operator.EXACTLY: logic.ExactCount, - LogicalCriterionCombination.Operator.CAPPED_AT_LEAST: logic.CappedMinCount, - } - if op in count_ops: - if comb.operator.threshold is None: - raise ValueError( - f"Threshold must be set for operator {comb.operator.operator}" - ) - return lambda *args, category: count_ops[op]( - *args, threshold=comb.operator.threshold, category=category - ) - - raise NotImplementedError(f'Logical combination operator "{comb.operator}" not implemented') - - ################################################################################### - elif isinstance(comb, TemporalIndicatorCombination): - - interval_criterion: logic.BaseExpr | None = None - start_time = None - end_time = None - interval_type = None - - if isinstance(comb, PersonalWindowTemporalIndicatorCombination): - - if isinstance(comb.interval_criterion, CriterionCombination): - interval_criterion = _traverse(comb.interval_criterion) - elif isinstance(comb.interval_criterion, Criterion): - interval_criterion = logic.Symbol( - comb.interval_criterion, 
category=category - ) - else: - raise ValueError( - f"Invalid interval criterion type: {type(comb.interval_criterion)}" - ) - - elif isinstance(comb, FixedWindowTemporalIndicatorCombination): - start_time = comb.start_time - end_time = comb.end_time - interval_type = comb.interval_type - - # Ensure a threshold is set. - if comb.operator.threshold is None: - raise ValueError( - f"Threshold must be set for operator {comb.operator.operator}" - ) - - # Map the operator to the corresponding logic function. - op_map = { - TemporalIndicatorCombination.Operator.AT_LEAST: logic.TemporalMinCount, - TemporalIndicatorCombination.Operator.AT_MOST: logic.TemporalMaxCount, - TemporalIndicatorCombination.Operator.EXACTLY: logic.TemporalExactCount, - } - op_func = op_map.get(comb.operator.operator, None) - if op_func is None: - raise NotImplementedError( - f'Temporal combination operator "{str(comb.operator)}" not implemented' - ) - - return lambda *args, category: op_func( - *args, - threshold=comb.operator.threshold, - category=category, - start_time=start_time, - end_time=end_time, - interval_type=interval_type, - interval_criterion=interval_criterion, - ) - - else: - raise ValueError(f"Invalid combination type: {type(comb)}") - - def _traverse(comb: CriterionCombination) -> logic.Expr: - """ - Traverse the criterion combination and creates a collection of logical conjunctions from it. - """ - conjunction = conjunction_from_combination(comb) - components: list[logic.Expr | logic.Symbol] = [] - - for entry in comb: - if isinstance(entry, CriterionCombination): - components.append(_traverse(entry)) - elif isinstance(entry, Criterion): - components.append(logic.Symbol(entry, category=category)) - else: - raise ValueError(f"Invalid entry type: {type(entry)}") - - return conjunction(*components, category=category) - - expression = _traverse(comb) - - return expression - def __eq__(self, other: Any) -> bool: """ Check if two graphs are equal. 
diff --git a/execution_engine/omop/cohort/__init__.py b/execution_engine/omop/cohort/__init__.py index 81706fdd..bc7fcb5a 100644 --- a/execution_engine/omop/cohort/__init__.py +++ b/execution_engine/omop/cohort/__init__.py @@ -1,2 +1,2 @@ -from .population_intervention_pair import PopulationInterventionPair +from .population_intervention_pair import PopulationInterventionPairExpr from .recommendation import Recommendation diff --git a/execution_engine/omop/cohort/population_intervention_pair.py b/execution_engine/omop/cohort/population_intervention_pair.py index cab375f0..bdcbb535 100644 --- a/execution_engine/omop/cohort/population_intervention_pair.py +++ b/execution_engine/omop/cohort/population_intervention_pair.py @@ -1,32 +1,15 @@ -from typing import Any, Dict, cast +from typing import cast -from sqlalchemy.sql import ( - Alias, - CompoundSelect, - Insert, - Join, - Select, - Subquery, - TableClause, -) -from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList -from sqlalchemy.sql.selectable import CTE - -import execution_engine.util.cohort_logic as logic -from execution_engine.constants import CohortCategory -from execution_engine.execution_graph import ExecutionGraph +import execution_engine.util.logic as logic from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) -from execution_engine.omop.criterion.factory import criterion_factory from execution_engine.omop.serializable import Serializable -from execution_engine.util.sql import SelectInto -class PopulationInterventionPair(Serializable): +class PopulationInterventionPairExpr(logic.LeftDependentToggle, Serializable): """ + A logical expression that ties together a population expression and an intervention expression, + plus any extra info like name, url, base_criterion, etc. 
+ A population/intervention pair in OMOP as a collection of separate criteria. A population/intervention pair represents an individual recommendation plan (i.e. one part of a single recommendation), @@ -39,293 +22,78 @@ class PopulationInterventionPair(Serializable): """ _name: str - _population: CriterionCombination - _intervention: CriterionCombination - - def __init__( - self, + _url: str + _base_criterion: Criterion + + def __new__( + cls, + population_expr: logic.BaseExpr, + intervention_expr: logic.BaseExpr, + *, name: str, url: str, base_criterion: Criterion, - population: CriterionCombination | None = None, - intervention: CriterionCombination | None = None, - ) -> None: - self._name = name - self._url = url - self._base_criterion = base_criterion - - self.set_criteria(CohortCategory.POPULATION, population) - self.set_criteria(CohortCategory.INTERVENTION, intervention) - - def __repr__(self) -> str: + **kwargs: dict, + ) -> "PopulationInterventionPairExpr": """ - Get the string representation of the population/intervention pair. + Create a new PopulationInterventionPairExpr object. """ - return ( - f"{self.__class__.__name__}(\n" - f" name={repr(self._name)},\n" - f" url={repr(self._url)},\n" - f" base_criterion={repr(self._base_criterion)},\n" - f" population={self._population._repr_pretty(level=1).strip()},\n" - f" intervention={self._intervention._repr_pretty(level=1).strip()}\n" - f")" - ) - - def set_criteria( - self, category: CohortCategory, criteria: CriterionCombination | None - ) -> None: - """ - Set the criteria (either population or intervention) of the population/intervention pair. - - :param category: The category of the criteria. - :param criteria: The criteria. 
- """ - - root_combination = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator("AND"), - root_combination=True, + self = cast( + PopulationInterventionPairExpr, + super().__new__( + cls, left=population_expr, right=intervention_expr, **kwargs + ), ) - if criteria is not None: - root_combination.add(criteria) + self._name = name + self._url = url + self._base_criterion = base_criterion - if category == CohortCategory.POPULATION: - self._population = root_combination - elif category == CohortCategory.INTERVENTION: - self._intervention = root_combination - else: - raise ValueError(f"Invalid category {category}") + return self @property def name(self) -> str: """ - Get the name of the population/intervention pair. + The name of the population/intervention pair. """ return self._name @property def url(self) -> str: """ - Get the url of the population/intervention pair. + The URL of the population/intervention pair. """ return self._url - @classmethod - def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: - """ - Filter (=AND-combine) all symbols by the applied filter function - - Used to filter all intervention criteria (symbols) by the population output in order to exclude - all intervention events outside the population intervals, which may otherwise interfere with corrected - determination of temporal combination, i.e. the presence of an intervention event during some time window. - """ - - if isinstance(node, logic.Symbol): - return logic.LeftDependentToggle( - left=filter_, right=node, category=CohortCategory.INTERVENTION - ) - - if hasattr(node, "args") and isinstance(node.args, tuple): - converted_args = [cls.filter_symbols(a, filter_) for a in node.args] - - if any(a is not b for a, b in zip(node.args, converted_args)): - node.args = tuple(converted_args) - - return node - - def execution_graph(self) -> ExecutionGraph: - """ - Get the execution graph for the population/intervention pair. 
- """ - - p = ExecutionGraph.combination_to_expression( - self._population, category=CohortCategory.POPULATION - ) - i = ExecutionGraph.combination_to_expression( - self._intervention, category=CohortCategory.INTERVENTION - ) - - # filter all intervention criteria by the output of the population - this is performed to filter out - # intervention events that outside of the population intervals (i.e. the time windows during which - # patients are part of the population) as otherwise events outside of the population time may be picked up - # by Temporal criteria that determine the presence of some event or condition during a specific time window. - i = self.filter_symbols(i, filter_=p) - - pi = logic.LeftDependentToggle( - p, - i, - category=CohortCategory.POPULATION_INTERVENTION, - ) - pi_graph = ExecutionGraph.from_expression(pi, self._base_criterion) - - if self._id is None: - raise ValueError("Population/intervention pair ID not set") - - # todo: should we supply self instead of self._id? - pi_graph.set_sink_nodes_store(bind_params=dict(pi_pair_id=self._id)) - - return pi_graph - - def set_population(self, combination: CriterionCombination) -> None: - """ - Set the population criteria. - """ - combination.set_root() - self._population = combination - - def add_population(self, criterion: Criterion | CriterionCombination) -> None: - """ - Add a criterion to the population of the population/intervention pair. - """ - self._population.add(criterion) - - def add_intervention(self, criterion: Criterion | CriterionCombination) -> None: - """ - Add a criterion to the intervention of the population/intervention pair. - """ - self._intervention.add(criterion) - - def criteria(self) -> CriterionCombination: + @property + def base_criterion(self) -> Criterion: """ - Get the criteria of the population/intervention pair. + The base criterion of the population/intervention pair. 
""" - """return self._criteria""" - raise NotImplementedError() + return self._base_criterion - def flatten(self) -> list[Criterion]: + def __repr__(self) -> str: """ - Retrieve all criteria in a flat list (i.e. no nested criterion combinations). - - Includes the base criterion, population and intervention. - - :return: A list of individual criteria. + Return a string representation of the object. """ - - def _traverse(comb: CriterionCombination) -> list[Criterion]: - criteria = [] - for element in comb: - if isinstance(element, CriterionCombination): - criteria += _traverse(element) - else: - criteria.append(element) - return criteria - return ( - [self._base_criterion] - + _traverse(self._population) - + _traverse(self._intervention) + f"{self.__class__.__name__}(" + f"name={self.name!r}, " + f"url={self.url!r}, " + f"base_criterion={self.base_criterion!r}, " + f"left={self.left!r}, right={self.right!r}, " ) - @staticmethod - def _assert_base_table_in_select( - sql: CompoundSelect | Select | SelectInto, base_table_out: str - ) -> None: - """ - Assert that the base table is used in the select statement. - - Joining the base table ensures that always just a subset of patients is selected, - not all. - """ - if isinstance(sql, SelectInto) or isinstance(sql, Insert): - sql = sql.select - - def _base_table_in_select(sql_select: Join | Select | Alias) -> bool: - """ - Check if the base table is used in the select statement. 
- """ - if isinstance(sql_select, Join): - return _base_table_in_select(sql_select.left) or _base_table_in_select( - sql_select.right - ) - elif isinstance(sql_select, Select): - return any( - _base_table_in_select(f) for f in sql_select.get_final_froms() - ) or any(_base_table_in_select(w) for w in sql_select.whereclause) - elif isinstance(sql_select, Alias): - return sql_select.original.name == base_table_out - elif isinstance(sql_select, TableClause): - return sql_select.name == base_table_out - elif isinstance(sql_select, CTE): - return _base_table_in_select(sql_select.original) - elif isinstance(sql_select, BooleanClauseList): - return any( - w.right.element.froms[0].name == base_table_out for w in sql_select - ) - elif isinstance(sql_select, BinaryExpression): - return ( - sql_select.right.element.get_final_froms()[0].name == base_table_out - ) - elif isinstance(sql_select, Subquery): - if isinstance(sql_select.original, CompoundSelect): - return all( - _base_table_in_select(s) for s in sql_select.original.selects - ) - else: - return any( - _base_table_in_select(f) - for f in sql_select.original.get_final_froms() - ) - else: - raise NotImplementedError(f"Unknown type {type(sql_select)}") - - if isinstance(sql, CompoundSelect): - assert all( - _base_table_in_select(s) for s in sql.selects - ), "Base table not used in all selects of compound select" - elif isinstance(sql, Select): - assert _base_table_in_select( - sql - ), f"Base table {base_table_out} not found in select" - else: - raise NotImplementedError(f"Unknown type {type(sql)}") - - def dict(self) -> dict[str, Any]: - """ - Get a dictionary representation of the population/intervention pair. 
- """ - base_criterion = self._base_criterion - population = self._population - intervention = self._intervention - return { - "name": self.name, - "url": self.url, - "base_criterion": { - "class_name": base_criterion.__class__.__name__, - "data": base_criterion.dict(), - }, - "population": { - "class_name": population.__class__.__name__, - "data": population.dict(), - }, - "intervention": { - "class_name": intervention.__class__.__name__, - "data": intervention.dict(), - }, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PopulationInterventionPair": + def get_instance_variables(self, immutable: bool = False) -> dict | tuple: """ - Create a population/intervention pair from a dictionary. + Get the instance variables of the object. """ + # Include the base class instance variables plus the ones we added: + base_vars = cast(dict, super().get_instance_variables(immutable=False)) + base_vars["name"] = self.name + base_vars["url"] = self.url + base_vars["base_criterion"] = self.base_criterion - base_criterion = criterion_factory(**data["base_criterion"]) - assert isinstance( - base_criterion, Criterion - ), "Base criterion must be a Criterion" - population = cast(CriterionCombination, criterion_factory(**data["population"])) - intervention = cast( - CriterionCombination, criterion_factory(**data["intervention"]) - ) - object = cls( - name=data["name"], - url=data["url"], - base_criterion=base_criterion, - ) - # The constructor initializes the population and intervention - # slots in a particular way, but we want to use whatever we - # have deserialized instead. This is a bit inefficient because - # we discard the values that were assigned to the two slots in - # the constructor. 
- object._population = population - object._intervention = intervention - return object + # Convert to tuple if needed + if immutable: + return tuple(sorted(base_vars.items())) + return base_vars diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index 4f4b88fc..a11ac7a5 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -1,6 +1,5 @@ -import itertools import re -from typing import Any, Dict, Iterator, Self +from typing import Any, Dict, Iterator, Self, cast import networkx as nx from sqlalchemy import ( @@ -15,17 +14,13 @@ select, ) -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.execution_graph import ExecutionGraph -from execution_engine.omop import cohort - -# ) -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, +from execution_engine.omop.cohort.population_intervention_pair import ( + PopulationInterventionPairExpr, ) +from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.factory import criterion_factory from execution_engine.omop.db.celida.tables import ResultInterval from execution_engine.omop.serializable import Serializable @@ -44,7 +39,7 @@ class Recommendation(Serializable): def __init__( self, - pi_pairs: list[cohort.PopulationInterventionPair], + expr: logic.BooleanFunction, base_criterion: Criterion, name: str, title: str, @@ -53,7 +48,7 @@ def __init__( description: str, package_version: str | None = None, ) -> None: - self._pi_pairs: list[cohort.PopulationInterventionPair] = pi_pairs + self._expr: logic.BooleanFunction = expr self._base_criterion: Criterion = 
base_criterion self._name: str = name self._title: str = title @@ -67,7 +62,7 @@ def __repr__(self) -> str: Get the string representation of the recommendation. """ pi_repr = "\n".join( - [(" " + line) for line in repr(self._pi_pairs).split("\n")] + [(" " + line) for line in repr(self._expr).split("\n")] ).strip() pi_repr = ( pi_repr[0] + "\n " + pi_repr[1:-2] + pi_repr[-2] + "\n " + pi_repr[-1] @@ -111,7 +106,7 @@ def version(self) -> str: """ Get the version of the recommendation. """ - return self._version # + return self._version @property def package_version(self) -> str | None: @@ -142,68 +137,82 @@ def execution_graph(self) -> ExecutionGraph: execution maps of the individual population/intervention pairs of the recommendation. """ - p_nodes = [] - pi_nodes = [] - pi_graphs = [] + # p_sink_nodes = [] + # pi_sink_nodes = [] + + common_graph = ExecutionGraph.from_expression( + self._expr, + base_criterion=self._base_criterion, + category=CohortCategory.POPULATION_INTERVENTION, + ) - for pi_pair in self._pi_pairs: - pi_graph = pi_pair.execution_graph() + # for pi_pair in self.population_intervention_pairs(): + # subgraph: ExecutionGraph = cast(ExecutionGraph, common_graph.subgraph(nx.ancestors(common_graph, pi_pair) | {pi_pair})) + # p_sink_nodes.append(subgraph.sink_node(CohortCategory.POPULATION)) + # pi_sink_nodes.append(subgraph.sink_node(CohortCategory.POPULATION_INTERVENTION)) - p_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION)) - pi_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION_INTERVENTION)) - pi_graphs.append(pi_graph) + p_sink_nodes = common_graph.sink_nodes(CohortCategory.POPULATION) p_combination_node = logic.NoDataPreservingOr( - *p_nodes, category=CohortCategory.POPULATION + *common_graph.sink_nodes(CohortCategory.POPULATION) ) - pi_combination_node = logic.NoDataPreservingAnd( - *pi_nodes, category=CohortCategory.POPULATION_INTERVENTION - ) - - common_graph = nx.compose_all(pi_graphs) + # pi_combination_node = 
logic.NoDataPreservingAnd( + # *pi_sink_nodes, + # ) common_graph.add_node( p_combination_node, store_result=True, category=CohortCategory.POPULATION ) - common_graph.add_node( - pi_combination_node, - store_result=True, - category=CohortCategory.POPULATION_INTERVENTION, - ) + # common_graph.add_node( + # pi_combination_node, + # store_result=True, + # category=CohortCategory.POPULATION_INTERVENTION, + # ) - common_graph.add_edges_from((src, p_combination_node) for src in p_nodes) - common_graph.add_edges_from((src, pi_combination_node) for src in pi_nodes) + common_graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes) + # common_graph.add_edges_from((src, pi_combination_node) for src in pi_sink_nodes) - return common_graph + import json - def criteria(self) -> CriterionCombination: - """ - Get the criteria of the recommendation. - """ - criteria = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator("OR"), - root_combination=True, - ) + with open("/home/glichtner/cyto.json", "w") as f: + json.dump({"elements": common_graph.to_cytoscape_dict()}, f, indent=4) - for pi_pair in self._pi_pairs: - criteria.add(pi_pair.criteria()) + if not nx.is_directed_acyclic_graph(common_graph): + raise ValueError("The recommendation execution graph is not a DAG.") - return criteria + return common_graph def flatten(self) -> list[Criterion]: """ Retrieve all criteria in a flat list """ - return list(itertools.chain(*[pi_pair.flatten() for pi_pair in self._pi_pairs])) - def population_intervention_pairs( - self, - ) -> Iterator[cohort.PopulationInterventionPair]: + def traverse(expr: logic.Expr) -> list[Criterion]: + if expr.is_Atom: + assert isinstance(expr, logic.Symbol), f"Expected Symbol, got {expr}" + return [expr.criterion] + + gathered = [] + for sub_expr in expr.args: + gathered.extend(traverse(sub_expr)) + return gathered + + return [self._base_criterion] + traverse(self._expr) + + def population_intervention_pairs(self) -> 
Iterator[PopulationInterventionPairExpr]: """ - Iterate over the population/intervention pairs. + Iterate over all PopulationInterventionPairExpr in the expression tree. """ - yield from self._pi_pairs + + def traverse(expr: logic.Expr) -> Iterator[PopulationInterventionPairExpr]: + if isinstance(expr, PopulationInterventionPairExpr): + yield expr + else: + for sub_expr in expr.args: + yield from traverse(sub_expr) + + yield from traverse(self._expr) def __str__(self) -> str: """ @@ -211,18 +220,6 @@ def __str__(self) -> str: """ return f"Recommendation(name='{self._name}', description='{self.description}')" - def __len__(self) -> int: - """ - Get the number of population/intervention pairs. - """ - return len(self._pi_pairs) - - def __getitem__(self, index: int) -> cohort.PopulationInterventionPair: - """ - Get the population/intervention pair at the given index. - """ - return self._pi_pairs[index] - @staticmethod def to_table(name: str) -> Table: """ @@ -276,10 +273,11 @@ def reset_state(self) -> None: """ self._id = None - for pi_pair in self._pi_pairs: + for pi_pair in self.population_intervention_pairs(): pi_pair._id = None - for criterion in pi_pair.flatten(): - criterion._id = None + + for criterion in self.flatten(): + criterion._id = None def dict(self) -> dict: """ @@ -287,7 +285,7 @@ def dict(self) -> dict: """ base_criterion = self._base_criterion return { - "population_intervention_pairs": [c.dict() for c in self._pi_pairs], + "expr": self._expr.dict(), "base_criterion": { "class_name": base_criterion.__class__.__name__, "data": base_criterion.dict(), @@ -311,10 +309,9 @@ def from_dict(cls, data: Dict[str, Any]) -> Self: ), "Base criterion must be a Criterion" return cls( - pi_pairs=[ - cohort.PopulationInterventionPair.from_dict(c) - for c in data["population_intervention_pairs"] - ], + expr=cast( + logic.BooleanFunction, logic.BooleanFunction.from_dict(data["expr"]) + ), base_criterion=base_criterion, name=data["recommendation_name"], 
title=data["recommendation_title"], diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index 16627385..70c16f0e 100644 --- a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -26,6 +26,7 @@ VisitOccurrence, ) from execution_engine.omop.serializable import Serializable +from execution_engine.util import logic from execution_engine.util.interval import IntervalType from execution_engine.util.sql import SelectInto, select_into from execution_engine.util.types import PersonIntervals, TimeRange @@ -89,7 +90,9 @@ def create_conditional_interval_column(condition: ColumnElement) -> ColumnElemen ) -class AbstractCriterion(Serializable, ABC, metaclass=SignatureReprABCMeta): +class AbstractCriterion( + logic.Symbol, Serializable, ABC, metaclass=SignatureReprABCMeta +): """ Abstract base class for Criterion and CriterionCombination. """ diff --git a/execution_engine/omop/criterion/combination/combination.py b/execution_engine/omop/criterion/combination/combination.py deleted file mode 100644 index 8ef990cc..00000000 --- a/execution_engine/omop/criterion/combination/combination.py +++ /dev/null @@ -1,296 +0,0 @@ -from abc import ABCMeta -from typing import Any, Dict, Iterator, Sequence, Union, cast - -from execution_engine.omop.criterion.abstract import AbstractCriterion, Criterion - -__all__ = ["CriterionCombination"] - - -def snake_to_camel(s: str) -> str: - return "".join(word.capitalize() for word in s.lower().split("_")) - - -class CriterionCombination(AbstractCriterion, metaclass=ABCMeta): - """ - Base class for a combination of criteria (temporal or logical). - """ - - class Operator: - """ - Operators for criterion combinations. - """ - - def __init__(self, operator: str, threshold: int | None = None): - self.operator = operator - self.threshold = threshold - - def __str__(self) -> str: - """ - Get the string representation of the operator. 
- """ - if self.operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - return f"{self.operator}(threshold={self.threshold})" - else: - return f"{self.operator}" - - def __repr__(self) -> str: - """ - Get the string representation of the operator. - """ - if self.operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - return f'{self.__class__.__qualname__}(operator="{self.operator}", threshold={self.threshold})' - else: - return f'{self.__class__.__qualname__}(operator="{self.operator}")' - - def __eq__(self, other: object) -> bool: - """ - Check if the operator is equal to another operator. - """ - if not isinstance(other, CriterionCombination.Operator): - return NotImplemented - return self.operator == other.operator and self.threshold == other.threshold - - def __init__( - self, - operator: Operator, - criteria: Sequence[Union[Criterion, "CriterionCombination"]] | None = None, - root_combination: bool = False, - ): - """ - Initialize the criterion combination. - """ - super().__init__() - self._operator = operator - - self._criteria: list[Union[Criterion, CriterionCombination]] - self._root = root_combination - - if criteria is None: - self._criteria = [] - else: - self._criteria = cast( - list[Union[Criterion, "CriterionCombination"]], criteria - ) - - def add(self, criterion: Union[Criterion, "CriterionCombination"]) -> None: - """ - Add a criterion to the combination. - """ - self._criteria.append(criterion) - - def add_all( - self, criteria: Sequence[Union[Criterion, "CriterionCombination"]] - ) -> None: - """ - Add multiple criteria to the combination. - """ - self._criteria.extend(criteria) - - def __str__(self) -> str: - """ - Get the name of the criterion combination. - """ - return f"{self.__class__.__name__}({self.operator})" - - @property - def operator(self) -> "CriterionCombination.Operator": - """ - Get the operator of the criterion combination (i.e. the type of combination, e.g. AND, OR, AT_LEAST, etc.). 
- """ - return self._operator - - def set_root(self, value: bool = True) -> None: - """ - Sets whether this criterion combination is at the root of a tree of criteria / combinations. - """ - self._root = value - - def is_root(self) -> bool: - """ - Returns whether this criterion combination is at the root of a tree of criteria / combinations. - """ - return self._root - - def __iter__(self) -> Iterator[Union[Criterion, "CriterionCombination"]]: - """ - Iterate over the criteria in the combination. - """ - for criterion in self._criteria: - yield criterion - - def __len__(self) -> int: - """ - Get the number of criteria in the combination. - """ - return len(self._criteria) - - def __getitem__(self, index: int) -> Union[Criterion, "CriterionCombination"]: - """ - Get the criterion at the specified index. - """ - return self._criteria[index] - - def description(self) -> str: - """ - Description of this combination. - """ - return str(self) - - def dict(self) -> dict[str, Any]: - """ - Get the dictionary representation of the criterion combination. - """ - return { - "operator": self._operator.operator, - "threshold": self._operator.threshold, - "criteria": [ - { - "class_name": criterion.__class__.__name__, - "data": criterion.dict(), - } - for criterion in self._criteria - ], - "root": self._root, - } - - def __invert__(self) -> AbstractCriterion: - """ - Invert the criterion combination. - """ - # Would be cycle if imported at top-level. - from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - ) - - if ( - isinstance(self, LogicalCriterionCombination) - and self.operator.operator == LogicalCriterionCombination.Operator.NOT - ): - return self._criteria[0] - else: - copy = self.__class__( - operator=self._operator, - criteria=self._criteria, - ) - return LogicalCriterionCombination.Not(copy) - - def invert(self) -> AbstractCriterion: - """ - Invert the criterion combination. 
- """ - return ~self - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "CriterionCombination": - """ - Create a criterion combination from a dictionary. - """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - operator = cls.Operator(data["operator"], data["threshold"]) - - combination = cls( - operator=operator, - root_combination=data["root"], - ) - - for criterion in data["criteria"]: - combination.add(criterion_factory(**criterion)) - - return combination - - def _build_repr( - self, - children: Sequence[tuple[str | None, Any]], - params: list[tuple[str, Any]], - level: int = 0, - children_param_name: str = "criteria", - children_is_sequence: bool = True, - ) -> str: - """ - Builds a multi-line string for this criterion combination, - properly indenting each level but avoiding double-indenting if - a child already handles indentation in its own _repr_pretty. - """ - indent = " " * (2 * level) - child_indent = " " * (2 * (level + 1)) - - op = snake_to_camel(self.operator.operator) - - lines: list[str] = [] - kw_lines: list[str] = [] - criteria_lines: list[str] = [] - - method_defined = hasattr(self, op) and callable(getattr(self, op)) - - if self._root: - params.append(("root_combination", self._root)) - - # if there is a specific method for this operator, use it - if method_defined: - lines.append(f"{indent}{self.__class__.__name__}.{op}(") - if self.operator.threshold is not None: - params.append(("threshold", self.operator.threshold)) - else: - lines.append(f"{indent}{self.__class__.__name__}(") - params.append(("operator", self.operator)) - - # Each child on its own line - for key, child in children: - if hasattr(child, "_repr_pretty"): - # The child already handles its own indentation and multi-line formatting. - child_repr = child._repr_pretty(level + 1) - # We'll put a comma on the last line that the child produces. 
- child_lines = child_repr.split("\n") - child_lines[-1] += "," # add trailing comma to the last line - if key is None: - criteria_lines.extend(child_lines) - else: - # If you have a key like "left=" or "right=", prepend that to the first line - # or handle it similarly. One approach is: - child_lines[0] = f"{child_indent}{key}={child_lines[0].lstrip()}" - criteria_lines.extend(child_lines) - else: - # Fallback to normal repr, which we indent at this level - child_repr = repr(child) - if key is None: - criteria_lines.append(child_indent + child_repr + ",") - else: - criteria_lines.append(f"{child_indent}{key}={child_repr},") - - for key, value in params: - kw_lines.append(f"{child_indent}{key}={repr(value)},") - - if method_defined: - lines.extend(criteria_lines) - lines.extend(kw_lines) - elif children_is_sequence: - lines.extend(kw_lines) - lines.append(f"{child_indent}{children_param_name}=[") - lines.extend(criteria_lines) - lines.append(f"{child_indent}],") - else: - assert len(criteria_lines) <= 1 - lines.extend(kw_lines) - if criteria_lines: - lines.append( - f"{child_indent}{children_param_name}={criteria_lines[0].strip()}" - ) - - lines.append(f"{indent})") - - return "\n".join(lines) - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - - return self._build_repr(children, params=[], level=level) - - def __repr__(self) -> str: - """ - Get the string representation of the criterion combination. 
- """ - return self._repr_pretty(0) diff --git a/execution_engine/omop/criterion/combination/logical.py b/execution_engine/omop/criterion/combination/logical.py deleted file mode 100644 index bb624a66..00000000 --- a/execution_engine/omop/criterion/combination/logical.py +++ /dev/null @@ -1,292 +0,0 @@ -from typing import Any, Dict, Iterator, Union - -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination - - -class LogicalCriterionCombination(CriterionCombination): - """ - A combination of criteria. - """ - - class Operator(CriterionCombination.Operator): - """Operators for criterion combinations.""" - - NOT = "NOT" - AND = "AND" - OR = "OR" - AT_LEAST = "AT_LEAST" - CAPPED_AT_LEAST = "CAPPED_AT_LEAST" - AT_MOST = "AT_MOST" - EXACTLY = "EXACTLY" - ALL_OR_NONE = "ALL_OR_NONE" - - def __init__(self, operator: str, threshold: int | None = None): - assert operator in [ - "NOT", - "AND", - "OR", - "AT_LEAST", - "CAPPED_AT_LEAST", - "AT_MOST", - "EXACTLY", - "ALL_OR_NONE", - ], f"Invalid operator {operator}" - - self.operator = operator - if operator in ["AT_LEAST", "CAPPED_AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - threshold is not None - ), f"Threshold must be set for operator {operator}" - self.threshold = threshold - - @classmethod - def Not( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create a NOT "combination" of a single criterion. - """ - return cls( - operator=cls.Operator(cls.Operator.NOT), - criteria=[criterion], - ) - - @classmethod - def And( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create an AND combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.AND), - criteria=criteria, - ) - - @classmethod - def Or( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create an OR combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.OR), - criteria=criteria, - ) - - @classmethod - def AtLeast( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def CappedAtLeast( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an CAPPED_AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.CAPPED_AT_LEAST, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def AtMost( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an AT_MOST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_MOST, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def Exactly( - cls, - *criteria: Union[Criterion, "LogicalCriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an EXACTLY combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.EXACTLY, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def AllOrNone( - cls, - *criteria: Union[Criterion, "LogicalCriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create an ALL_OR_NONE combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.ALL_OR_NONE), - criteria=criteria, - ) - - -class NonCommutativeLogicalCriterionCombination(LogicalCriterionCombination): - """ - A combination of criteria that is not commutative. - """ - - _left: Union[Criterion, CriterionCombination] - _right: Union[Criterion, CriterionCombination] - - class Operator(CriterionCombination.Operator): - """Operators for criterion combinations.""" - - CONDITIONAL_FILTER = "CONDITIONAL_FILTER" - - def __init__(self, operator: str, threshold: None = None): - assert operator in [ - "CONDITIONAL_FILTER", - ], f"Invalid operator {operator}" - assert threshold is None - self.operator = operator - self.threshold = threshold - - def __init__( - self, - operator: "NonCommutativeLogicalCriterionCombination.Operator", - left: Union[Criterion, CriterionCombination] | None = None, - right: Union[Criterion, CriterionCombination] | None = None, - root_combination: bool = False, - ): - """ - Initialize the criterion combination. - """ - super().__init__(operator=operator) - - self._criteria = [] - if left is not None: - self._left = left - if right is not None: - self._right = right - - self._root = root_combination - - @property - def left(self) -> Union[Criterion, CriterionCombination]: - """ - Get the left criterion. - """ - return self._left - - @property - def right(self) -> Union[Criterion, CriterionCombination]: - """ - Get the right criterion. - """ - return self._right - - def __str__(self) -> str: - """ - Get the string representation of the criterion combination. - """ - return f"{self.operator}({', '.join(str(c) for c in self._criteria)})" - - def __eq__(self, other: object) -> bool: - """ - Check if the criterion combination is equal to another criterion combination. 
- """ - if not isinstance(other, NonCommutativeLogicalCriterionCombination): - return NotImplemented - return ( - self.operator == other.operator - and self._left == other._left - and self._right == other._right - ) - - def __iter__(self) -> Iterator[Union[Criterion, "CriterionCombination"]]: - """ - Iterate over the criteria in the combination. - """ - yield self._left - yield self._right - - def dict(self) -> dict: - """ - Get the dictionary representation of the criterion combination. - """ - left = self._left - right = self._right - return { - "operator": self._operator.operator, - "left": {"class_name": left.__class__.__name__, "data": left.dict()}, - "right": {"class_name": right.__class__.__name__, "data": right.dict()}, - } - - def _repr_pretty(self, level: int = 0) -> str: - children = [ - ("left", self._left), - ("right", self._right), - ] - return self._build_repr(children, params=[], level=level) - - @classmethod - def from_dict( - cls, data: Dict[str, Any] - ) -> "NonCommutativeLogicalCriterionCombination": - """ - Create a criterion combination from a dictionary. - """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - return cls( - operator=cls.Operator(data["operator"]), - left=criterion_factory(**data["left"]), - right=criterion_factory(**data["right"]), - ) - - @classmethod - def ConditionalFilter( - cls, - left: Union[Criterion, "CriterionCombination"], - right: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create a CONDITIONAL_FILTER combination of criteria. - - A conditional filter returns `right` iff `left` is POSITIVE, otherwise NEGATIVE. 
- - | left | right | Result | - |----------|----------|----------| - | NEGATIVE | * | NEGATIVE | - | NO_DATA | * | NEGATIVE | - | POSITIVE | POSITIVE | POSITIVE | - | POSITIVE | NEGATIVE | NEGATIVE | - | POSITIVE | NO_DATA | NO_DATA | - """ - return cls( - operator=cls.Operator(cls.Operator.CONDITIONAL_FILTER), - left=left, - right=right, - ) diff --git a/execution_engine/omop/criterion/combination/temporal.py b/execution_engine/omop/criterion/combination/temporal.py index ca9f6930..9988ba1a 100644 --- a/execution_engine/omop/criterion/combination/temporal.py +++ b/execution_engine/omop/criterion/combination/temporal.py @@ -1,10 +1,4 @@ -from abc import ABC -from datetime import time from enum import StrEnum -from typing import Any, Dict, Iterator, Union - -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination class TimeIntervalType(StrEnum): @@ -25,487 +19,3 @@ def __repr__(self) -> str: Get the string representation of the time interval type. """ return f"{self.__class__.__name__}.{self.name}" - - -class TemporalIndicatorCombination(CriterionCombination, ABC): - """ - TemporalIndicatorCombination is an abstract base class for constructing temporal indicator - combinations used to evaluate patient data over time. It encapsulates the common logic such as - operator definitions (e.g., AT_LEAST, AT_MOST, EXACTLY) and the management of one or more criteria. - This class serves as the foundation for more specialized implementations that define how the time - windows for evaluation are determined. 
- """ - - class Operator(CriterionCombination.Operator): - """Operators for criterion combinations.""" - - AT_LEAST = "AT_LEAST" - AT_MOST = "AT_MOST" - EXACTLY = "EXACTLY" - - def __init__(self, operator: str, threshold: int | None = None): - assert operator in [ - "AT_LEAST", - "AT_MOST", - "EXACTLY", - ], f"Invalid operator {operator}" - - self.operator = operator - if operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - threshold is not None - ), f"Threshold must be set for operator {operator}" - self.threshold = threshold - - def __init__( - self, - operator: Operator, - criterion: Union[Criterion, "CriterionCombination"] | None = None, - ): - if criterion is not None: - if not isinstance(criterion, (Criterion, CriterionCombination)): - raise ValueError( - f"Invalid criterion - expected Criterion or CriterionCombination, got {type(criterion)}" - ) - criteria = [criterion] - else: - criteria = None - - super().__init__(operator=operator, criteria=criteria) - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - - return self._build_repr( - children, - params=[], - level=level, - children_param_name="criterion", - children_is_sequence=False, - ) - - -class FixedWindowTemporalIndicatorCombination(TemporalIndicatorCombination): - """ - FixedWindowTemporalIndicatorCombination implements a temporal indicator combination that relies on - fixed time window specifications. It supports two mutually exclusive methods for defining these windows: - either via a pre-defined TimeIntervalType (e.g., morning, afternoon, or night shifts) or through explicit - start_time and end_time values. This class is intended for scenarios where the same evaluation window - applies uniformly across all patients, and it enforces validation to ensure only one method of window - specification is used. 
- """ - - interval_type: TimeIntervalType | None = None - start_time: time | None = None - end_time: time | None = None - - def __init__( - self, - operator: TemporalIndicatorCombination.Operator, - criterion: Union[Criterion, "CriterionCombination"] | None = None, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ): - super().__init__(operator=operator, criterion=criterion) - - if interval_type: - if start_time is not None or end_time is not None: - raise ValueError( - "start_time/end_time cannot be used together with interval_type" - ) - # Validate the interval_type if needed - self.interval_type = interval_type - self.start_time = None - self.end_time = None - else: - # Must have start_time and end_time - if start_time is None or end_time is None: - raise ValueError( - "Either interval_type OR both start_time & end_time must be provided" - ) - if start_time >= end_time: - raise ValueError("start_time must be less than end_time") - - self.interval_type = interval_type - self.start_time = start_time - self.end_time = end_time - - def __str__(self) -> str: - """ - Get the string representation of the criterion combination. - """ - if self.interval_type: - return f"{super().__str__()} for {self.interval_type.value}" - elif self.start_time and self.end_time: - return f"{super().__str__()} from {self.start_time.strftime('%H:%M:%S')} to {self.end_time.strftime('%H:%M:%S')}" - else: - return super().__str__() - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - params = [ - ("interval_type", self.interval_type), - ("start_time", self.start_time), - ("end_time", self.end_time), - ] - return self._build_repr( - children, - params=params, - level=level, - children_param_name="criterion", - children_is_sequence=False, - ) - - def dict(self) -> Dict: - """ - Get the dictionary representation of the criterion combination. 
- """ - - d = super().dict() - d["start_time"] = self.start_time.isoformat() if self.start_time else None - d["end_time"] = self.end_time.isoformat() if self.end_time else None - d["interval_type"] = self.interval_type - - return d - - @classmethod - def from_dict( - cls, data: Dict[str, Any] - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a criterion combination from a dictionary. - """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - operator = cls.Operator(data["operator"], data["threshold"]) - - combination = cls( - operator=operator, - interval_type=data["interval_type"], - # start_time and end_time is in isoformat ! - start_time=( - time.fromisoformat(data["start_time"]) if data["start_time"] else None - ), - end_time=time.fromisoformat(data["end_time"]) if data["end_time"] else None, - ) - - for criterion in data["criteria"]: - combination.add(criterion_factory(**criterion)) - - return combination - - @classmethod - def Presence( - cls, - criterion: Union[Criterion, "CriterionCombination"], - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=1), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def MinCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=threshold), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def MaxCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - interval_criterion: Criterion | CriterionCombination | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an AT_MOST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_MOST, threshold=threshold), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def ExactCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an EXACTLY combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.EXACTLY, threshold=threshold), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def MorningShift( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a MorningShift combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.MORNING_SHIFT) - - @classmethod - def AfternoonShift( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a AfternoonShift combination of criteria. 
- """ - return cls.Presence(criterion, TimeIntervalType.AFTERNOON_SHIFT) - - @classmethod - def NightShift( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a NightShift combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.NIGHT_SHIFT) - - @classmethod - def NightShiftBeforeMidnight( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a NightShiftBeforeMidnight combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.NIGHT_SHIFT_BEFORE_MIDNIGHT) - - @classmethod - def NightShiftAfterMidnight( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a NightShiftAfterMidnight combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.NIGHT_SHIFT_AFTER_MIDNIGHT) - - @classmethod - def Day( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a Day combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.DAY) - - @classmethod - def AnyTime( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a AnyTime combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.ANY_TIME) - - -class PersonalWindowTemporalIndicatorCombination(TemporalIndicatorCombination): - """ - PersonalWindowTemporalIndicatorCombination implements a temporal indicator combination based on - person-specific time windows. Instead of using fixed start/end times or a global TimeIntervalType, this - class leverages an interval_criterion to dynamically generate evaluation windows tailored to each - patient. 
This design is ideal for situations where the timing of events (such as post-operative periods) - varies between patients, enabling more personalized temporal assessments. - """ - - _interval_criterion: Criterion | CriterionCombination - - def __init__( - self, - operator: TemporalIndicatorCombination.Operator, - criterion: Union[Criterion, "CriterionCombination"] | None, - interval_criterion: Criterion | CriterionCombination, - ): - super().__init__(operator=operator, criterion=criterion) - - if not isinstance(interval_criterion, (Criterion, CriterionCombination)): - raise ValueError( - f"Invalid criterion - expected Criterion or CriterionCombination, got {type(interval_criterion)}" - ) - - self._interval_criterion = interval_criterion - - def __iter__(self) -> Iterator[Union[Criterion, "CriterionCombination"]]: - """ - Iterate over the criteria in the combination - first criteria, then interval criterion if present. - """ - yield from super().__iter__() - yield self._interval_criterion - - def __str__(self) -> str: - """ - Get the string representation of the criterion combination. - """ - base_str = super().__str__() - return f"{base_str} [Personal Windows via: {self.interval_criterion}]" - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - params = [ - ("interval_criterion", self.interval_criterion), - ] - return self._build_repr( - children, - params=params, - level=level, - children_param_name="criterion", - children_is_sequence=False, - ) - - @property - def interval_criterion(self) -> Criterion | CriterionCombination: - """ - Get the interval criterion. - """ - return self._interval_criterion - - def dict(self) -> Dict: - """ - Get the dictionary representation of the criterion combination. 
- """ - - d = super().dict() - d["interval_criterion"] = { - "class_name": self.interval_criterion.__class__.__name__, - "data": self.interval_criterion.dict(), - } - - return d - - @classmethod - def from_dict( - cls, data: Dict[str, Any] - ) -> "PersonalWindowTemporalIndicatorCombination": - """ - Create a criterion combination from a dictionary. - """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - operator = cls.Operator(data["operator"], data["threshold"]) - - combination = cls( - operator=operator, - criterion=None, - interval_criterion=criterion_factory(**data["interval_criterion"]), - ) - - for criterion in data["criteria"]: - combination.add(criterion_factory(**criterion)) - - return combination - - @classmethod - def Presence( - cls, - criterion: Union[Criterion, "CriterionCombination"], - interval_criterion: Criterion | CriterionCombination, - ) -> "PersonalWindowTemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=1), - criterion=criterion, - interval_criterion=interval_criterion, - ) - - @classmethod - def MinCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_criterion: Criterion | CriterionCombination, - ) -> "TemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=threshold), - criterion=criterion, - interval_criterion=interval_criterion, - ) - - @classmethod - def MaxCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_criterion: Criterion | CriterionCombination, - ) -> "TemporalIndicatorCombination": - """ - Create an AT_MOST combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.AT_MOST, threshold=threshold), - criterion=criterion, - interval_criterion=interval_criterion, - ) - - @classmethod - def ExactCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_criterion: Criterion | CriterionCombination, - ) -> "TemporalIndicatorCombination": - """ - Create an EXACTLY combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.EXACTLY, threshold=threshold), - criterion=criterion, - interval_criterion=interval_criterion, - ) diff --git a/execution_engine/omop/criterion/factory.py b/execution_engine/omop/criterion/factory.py index 706e7669..f6f6bbfd 100644 --- a/execution_engine/omop/criterion/factory.py +++ b/execution_engine/omop/criterion/factory.py @@ -1,16 +1,5 @@ from typing import Type -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - PersonalWindowTemporalIndicatorCombination, - TemporalIndicatorCombination, -) -from execution_engine.omop.criterion.concept import ConceptCriterion from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence from execution_engine.omop.criterion.custom import TidalVolumePerIdealBodyWeight from execution_engine.omop.criterion.drug_exposure import DrugExposure @@ -27,11 +16,13 @@ __all__ = ["criterion_factory", "register_criterion_class"] -class_map: dict[str, Type[Criterion] | Type[CriterionCombination]] = { - "ConceptCriterion": ConceptCriterion, - "LogicalCriterionCombination": LogicalCriterionCombination, - "TemporalCriterionCombination": TemporalIndicatorCombination, - "NonCommutativeLogicalCriterionCombination": 
NonCommutativeLogicalCriterionCombination, +from execution_engine.util import logic + +class_map: dict[str, Type[logic.BaseExpr]] = { + # "ConceptCriterion": ConceptCriterion, # is an Abstract class + "LogicalCriterionCombination": logic.BooleanFunction, + "TemporalCriterionCombination": logic.BooleanFunction, + "NonCommutativeLogicalCriterionCombination": logic.BooleanFunction, "ConditionOccurrence": ConditionOccurrence, "DrugExposure": DrugExposure, "Measurement": Measurement, @@ -43,13 +34,13 @@ "TidalVolumePerIdealBodyWeight": TidalVolumePerIdealBodyWeight, "VisitDetail": VisitDetail, "PointInTimeCriterion": PointInTimeCriterion, - "PersonalWindowTemporalIndicatorCombination": PersonalWindowTemporalIndicatorCombination, + "PersonalWindowTemporalIndicatorCombination": logic.BooleanFunction, } def register_criterion_class( class_name: str, - criterion_class: Type[Criterion] | Type[CriterionCombination], + criterion_class: Type[logic.BaseExpr], ) -> None: """ Register a criterion class. @@ -60,7 +51,7 @@ def register_criterion_class( class_map[class_name] = criterion_class -def criterion_factory(class_name: str, data: dict) -> Criterion | CriterionCombination: +def criterion_factory(class_name: str, data: dict) -> logic.BaseExpr: """ Create a criterion from a dictionary representation. 
diff --git a/execution_engine/task/creator.py b/execution_engine/task/creator.py index 1da50cda..772fe314 100644 --- a/execution_engine/task/creator.py +++ b/execution_engine/task/creator.py @@ -2,7 +2,7 @@ import networkx as nx -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.task.task import Task @@ -26,6 +26,7 @@ def node_to_task(expr: logic.Expr, attr: dict) -> Task: criterion = cast(logic.Symbol, expr).criterion if expr.is_Atom else None store_result = attr.get("store_result", False) bind_params = attr.get("bind_params", {}) + bind_params["category"] = attr["category"] task = Task( expr=expr, diff --git a/execution_engine/task/runner.py b/execution_engine/task/runner.py index 9379fb9a..29510951 100644 --- a/execution_engine/task/runner.py +++ b/execution_engine/task/runner.py @@ -191,7 +191,6 @@ def run(self, bind_params: dict) -> None: try: while len(self.completed_tasks) < len(self.tasks): self.enqueue_ready_tasks() - logging.info(f"{len(self.completed_tasks)}/{len(self.tasks)} tasks") if self.queue.empty() and not any( task.status == TaskStatus.RUNNING for task in self.tasks @@ -317,7 +316,6 @@ def task_executor_worker() -> None: break self.enqueue_ready_tasks() - logging.info(f"{len(self.completed_tasks)}/{len(self.tasks)} tasks") if self.completed_tasks == self.enqueued_tasks and len( self.completed_tasks @@ -331,7 +329,12 @@ def task_executor_worker() -> None: with self.lock: # Update the set of completed tasks + n_completed = len(self.completed_tasks) self.completed_tasks = set(self.shared_results.keys()) + if len(self.completed_tasks) > n_completed: + logging.info( + f"Completed {len(self.completed_tasks)} of {len(self.tasks)} tasks" + ) except Exception as e: logging.error(f"An error occurred: {e}") diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 36493593..24d12980 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -6,7 
+6,7 @@ from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.combination.temporal import TimeIntervalType @@ -72,7 +72,7 @@ def category(self) -> CohortCategory: """ Returns the category of the task. """ - return self.expr.category + return self.bind_params["category"] def get_base_task(self) -> "Task": """ @@ -655,6 +655,6 @@ def __repr__(self) -> str: Returns a string representation of the Task object. """ if self.expr.is_Atom: - return f"Task(criterion={self.expr}, category={self.expr.category})" + return f"Task(criterion={self.expr}, category={self.category})" else: - return f"Task({self.expr}), category={self.expr.category})" + return f"Task({self.expr}), category={self.category})" diff --git a/execution_engine/util/cohort_logic.py b/execution_engine/util/logic.py similarity index 82% rename from execution_engine/util/cohort_logic.py rename to execution_engine/util/logic.py index d07e784c..d3c7b0c2 100644 --- a/execution_engine/util/cohort_logic.py +++ b/execution_engine/util/logic.py @@ -1,10 +1,18 @@ from abc import ABC, abstractmethod from datetime import time -from typing import Any, Callable, cast +from typing import Any, Callable, Dict, Type, cast -from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.combination.temporal import TimeIntervalType +from execution_engine.omop.serializable import Serializable + +_EXPR_CLASSES: dict[str, Type["BaseExpr"]] = {} + + +def register_expr_class(cls: Type["BaseExpr"]) -> Type["BaseExpr"]: + """Decorator to register expression classes by their class name.""" + _EXPR_CLASSES[cls.__name__] = cls + return cls class BaseExpr(ABC): @@ -17,7 
+25,7 @@ class BaseExpr(ABC): @classmethod def _recreate(cls, args: Any, kwargs: dict) -> "Expr": """ - Recreate an expression from its arguments and category. + Recreate an expression from its arguments. """ return cast(Expr, cls(*args, **kwargs)) @@ -70,11 +78,11 @@ def is_Not(self) -> bool: def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments. Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class, and arguments. """ return self._recreate, (self.args, self.get_instance_variables()) @@ -97,24 +105,80 @@ def get_instance_variables(self, immutable: bool = False) -> dict | tuple: else: return instance_vars + def dict(self) -> dict: + """ + Serialize this expression (and its children) into a dict. + """ + data: dict[str, Any] = {"type": self.__class__.__name__} + # Store the class name, so we know which subclass to instantiate on from_dict + + # Store instance variables (like thresholds, category, etc.) except for args + # (They come from get_instance_variables in your snippet) + instance_vars = self.get_instance_variables(immutable=False) + for key in instance_vars: + if isinstance(instance_vars[key], Serializable): + data[key] = instance_vars[key].dict() + else: + data[key] = instance_vars[key] + + # Also store the expression's children by recursing + if self.args: + serialized_args = [] + for arg in self.args: + if isinstance(arg, (BaseExpr, Serializable)): + serialized_args.append(arg.dict()) # Recursively serialize + else: + # If non-expression objects appear in .args, store them directly + serialized_args.append(arg) + data["args"] = serialized_args + + return data + + @classmethod + def from_dict(cls, data: Dict) -> "BaseExpr": + """ + Rebuild an expression (of the correct subclass) from the given dict. 
+ """ + expr_type = data["type"] + + # Find the actual subclass + if expr_type not in _EXPR_CLASSES: + raise ValueError(f"No registered expression class named '{expr_type}'") + sub_cls = _EXPR_CLASSES[expr_type] + + # Deserialize child expressions + children = [] + for arg_data in data.get("args", []): + if isinstance(arg_data, dict) and "type" in arg_data: + # It's a nested BaseExpr + child_expr = cls.from_dict(arg_data) + children.append(child_expr) + else: + # It's a literal (int, str, etc.) + children.append(arg_data) + + instance_vars = { + key: val for key, val in data.items() if key not in ("type", "args") + } + + return sub_cls._recreate(tuple(children), instance_vars) + class Expr(BaseExpr): """ Class for expressions that require a category. """ - category: CohortCategory + # todo: isn't this now a bit redundant with the BaseExpr class? (because we've removed category) - def __new__(cls, *args: Any, category: CohortCategory) -> "Expr": + def __new__(cls, *args: Any) -> "Expr": """ - Initialize an expression with given arguments and a mandatory category. + Initialize an expression with given arguments. :param args: Arguments for the expression. - :param category: Mandatory category of the expression. """ self = cast(Expr, super().__new__(cls, *args)) self.args = args - self.category = category return self @@ -122,7 +186,7 @@ def __repr__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}({', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}({', '.join(map(repr, self.args))})" def __eq__(self, other: Any) -> bool: """ @@ -162,24 +226,23 @@ def is_Not(self) -> bool: return isinstance(self, Not) +@register_expr_class class Symbol(BaseExpr): """ Class representing a symbolic variable. 
""" criterion: Criterion - category: CohortCategory - def __new__(cls, criterion: Criterion, category: CohortCategory) -> "Symbol": + def __new__(cls, criterion: Criterion) -> "Symbol": """ Initialize a symbol. :param criterion: The criterion of the symbol. """ self = cast(Symbol, super().__new__(cls)) - self.args = () self.criterion = criterion - self.category = category + self.args = () return self @@ -227,6 +290,7 @@ def is_Not(self) -> bool: return False +@register_expr_class class BooleanFunction(Expr): """ Base class for boolean functions like OR, AND, and NOT. @@ -285,6 +349,7 @@ def __repr__(self) -> str: return super().__repr__() +@register_expr_class class Or(BooleanFunction): """ Class representing a logical OR operation. @@ -302,6 +367,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: return super().__new__(cls, *args, **kwargs) +@register_expr_class class And(BooleanFunction): """ Class representing a logical AND operation. @@ -319,6 +385,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: return super().__new__(cls, *args, **kwargs) +@register_expr_class class Not(BooleanFunction): """ Class representing a logical NOT operation. @@ -340,6 +407,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Not": return cast(Not, super().__new__(cls, *args, **kwargs)) +@register_expr_class class Count(BooleanFunction, ABC): """ Class representing a logical COUNT operation. @@ -353,6 +421,7 @@ class Count(BooleanFunction, ABC): count_max: int | None = None +@register_expr_class class MinCount(Count): """ Class representing a logical MIN_COUNT operation. @@ -368,24 +437,25 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MinCount" def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. 
+ :return: Tuple of the class, and arguments. """ return ( self._recreate, - (self.args, {"category": self.category, "threshold": self.count_min}), + (self.args, {"threshold": self.count_min}), ) def __repr__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" +@register_expr_class class MaxCount(Count): """ Class representing a logical MAX_COUNT operation. @@ -401,24 +471,25 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MaxCount" def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments. Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class and arguments. """ return ( self._recreate, - (self.args, {"category": self.category, "threshold": self.count_max}), + (self.args, {"threshold": self.count_max}), ) def __repr__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}(threshold={self.count_max}; {', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}(threshold={self.count_max}; {', '.join(map(repr, self.args))})" +@register_expr_class class ExactCount(Count): """ Class representing a logical EXACT_COUNT operation. @@ -435,24 +506,25 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "ExactCoun def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments. Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class, and arguments. 
""" return ( self._recreate, - (self.args, {"category": self.category, "threshold": self.count_min}), + (self.args, {"threshold": self.count_min}), ) def __repr__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" +@register_expr_class class CappedCount(BooleanFunction, ABC): """ Base class representing a COUNT operation with an upper cap. @@ -473,6 +545,7 @@ class CappedCount(BooleanFunction, ABC): count_max: int | None = None +@register_expr_class class CappedMinCount(CappedCount): """ Class representing a MIN_COUNT operation with an implicit upper cap. @@ -501,30 +574,32 @@ def __new__( def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments. Required for pickling (e.g., when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class, and arguments. """ return ( self._recreate, - (self.args, {"category": self.category, "threshold": self.count_min}), + (self.args, {"threshold": self.count_min}), ) def __repr__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" +@register_expr_class class AllOrNone(BooleanFunction): """ Class representing a logical ALL_OR_NONE operation. """ +@register_expr_class class TemporalCount(BooleanFunction, ABC): """ Class representing a logical COUNT operation. 
@@ -542,6 +617,7 @@ class TemporalCount(BooleanFunction, ABC): interval_criterion: BaseExpr | None = None +@register_expr_class class TemporalMinCount(TemporalCount): """ Class representing a logical temporal MIN_COUNT operation. @@ -571,18 +647,17 @@ def __new__( def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments. Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class, and arguments. """ return ( self._recreate, ( self.args, { - "category": self.category, "threshold": self.count_min, "start_time": self.start_time, "end_time": self.end_time, @@ -606,9 +681,10 @@ def __repr__(self) -> str: else: interval = "None" - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))})" +@register_expr_class class TemporalMaxCount(TemporalCount): """ Class representing a logical MAX_COUNT operation. @@ -638,18 +714,17 @@ def __new__( def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments. Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class, and arguments. 
""" return ( self._recreate, ( self.args, { - "category": self.category, "threshold": self.count_max, "start_time": self.start_time, "end_time": self.end_time, @@ -673,9 +748,10 @@ def __repr__(self) -> str: else: interval = "None" - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_max}; {', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_max}; {', '.join(map(repr, self.args))})" +@register_expr_class class TemporalExactCount(TemporalCount): """ Class representing a logical EXACT_COUNT operation. @@ -706,18 +782,17 @@ def __new__( def __reduce__(self) -> tuple[Callable, tuple]: """ - Reduce the expression to its arguments and category. + Reduce the expression to its arguments. Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class, and arguments. """ return ( self._recreate, ( self.args, { - "category": self.category, "threshold": self.count_min, "start_time": self.start_time, "end_time": self.end_time, @@ -741,9 +816,10 @@ def __repr__(self) -> str: else: interval = "None" - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" + return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))})" +@register_expr_class class NonSimplifiableAnd(BooleanFunction): """ A NonSimplifiableAnd object represents a logical AND operation that cannot be simplified. @@ -783,6 +859,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableAnd": # todo: can we rename to more meaningful name? +@register_expr_class class NoDataPreservingAnd(BooleanFunction): """ A And object represents a logical AND operation. 
@@ -798,6 +875,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingAnd": return cast(NoDataPreservingAnd, super().__new__(cls, *args, **kwargs)) +@register_expr_class class NoDataPreservingOr(BooleanFunction): """ A Or object represents a logical OR operation. @@ -813,6 +891,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingOr": return cast(NoDataPreservingOr, super().__new__(cls, *args, **kwargs)) +@register_expr_class class LeftDependentToggle(BooleanFunction): """ A LeftDependentToggle object represents a logical AND operation if the left operand is positive, @@ -838,10 +917,22 @@ def right(self) -> Expr: return self.args[1] +@register_expr_class class ConditionalFilter(BooleanFunction): """ A ConditionalFilter object returns the right operand if the left operand is POSITIVE, and NEGATIVE otherwise + + + A conditional filter returns `right` iff `left` is POSITIVE, otherwise NEGATIVE. + + | left | right | Result | + |----------|----------|----------| + | NEGATIVE | * | NEGATIVE | + | NO_DATA | * | NEGATIVE | + | POSITIVE | POSITIVE | POSITIVE | + | POSITIVE | NEGATIVE | NEGATIVE | + | POSITIVE | NO_DATA | NO_DATA | """ def __new__( diff --git a/tests/execution_engine/converter/action/test_assessment.py b/tests/execution_engine/converter/action/test_assessment.py index 821dc736..ad2a09b8 100644 --- a/tests/execution_engine/converter/action/test_assessment.py +++ b/tests/execution_engine/converter/action/test_assessment.py @@ -33,7 +33,7 @@ class TestAssessmentAction: def test_assessment_action(self, code, timing, criterion_class): action = AssessmentAction(exclude=False, code=code, timing=timing) - criterion = action.to_criterion() + criterion = action.to_positive_expression() assert isinstance(criterion, criterion_class) assert criterion._concept == code diff --git a/tests/execution_engine/converter/action/test_drug_administration.py b/tests/execution_engine/converter/action/test_drug_administration.py index c58557ba..69bf83b7 
100644 --- a/tests/execution_engine/converter/action/test_drug_administration.py +++ b/tests/execution_engine/converter/action/test_drug_administration.py @@ -26,7 +26,7 @@ def test_single_dose(self): dosages=[dosage_def], ) - criterion = action.to_criterion() + criterion = action.to_expression() assert isinstance(criterion, DrugExposure) assert criterion._dose == dosage @@ -64,7 +64,7 @@ def test_multiple_doses(self): ], ) - comb = action.to_criterion() + comb = action.to_expression() criteria = list(comb) assert len(criteria) == 3 diff --git a/tests/execution_engine/converter/test_converter.py b/tests/execution_engine/converter/test_converter.py index e7d1b260..88820367 100644 --- a/tests/execution_engine/converter/test_converter.py +++ b/tests/execution_engine/converter/test_converter.py @@ -16,10 +16,7 @@ parse_value, select_value, ) -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) +from execution_engine.util import logic from execution_engine.util.value import ValueConcept, ValueNumber @@ -157,7 +154,7 @@ def from_fhir(cls, fhir_definition: Element) -> "CriterionConverter": def valid(cls, fhir_definition: Element) -> bool: return fhir_definition.id == "valid" - def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: + def to_positive_expression(self) -> logic.Symbol: raise NotImplementedError() def test_criterion_converter_factory_register(self): diff --git a/tests/execution_engine/omop/criterion/test_criterion.py b/tests/execution_engine/omop/criterion/test_criterion.py index fee7a47b..475184cd 100644 --- a/tests/execution_engine/omop/criterion/test_criterion.py +++ b/tests/execution_engine/omop/criterion/test_criterion.py @@ -25,7 +25,7 @@ task, ) from execution_engine.task.process import get_processing_module -from execution_engine.util import cohort_logic +from execution_engine.util import logic from 
execution_engine.util.db import add_result_insert from execution_engine.util.interval import IntervalType from execution_engine.util.types import TimeRange @@ -213,14 +213,14 @@ def register_population_intervention_pair(cls, id_, name, db_session): ).scalar() if not exists: - new_criterion = celida_tables.PopulationInterventionPair( + pi_pair = celida_tables.PopulationInterventionPair( pi_pair_id=id_, recommendation_id=-1, pi_pair_url="https://example.com", pi_pair_name=name, pi_pair_hash=hash(name), ) - db_session.add(new_criterion) + db_session.add(pi_pair) db_session.commit() @pytest.fixture @@ -303,11 +303,11 @@ def insert_criterion_combination( p_sink_node = graph.sink_node(CohortCategory.POPULATION) pi_sink_node = graph.sink_node(CohortCategory.POPULATION_INTERVENTION) - p_combination_node = cohort_logic.NoDataPreservingAnd( + p_combination_node = logic.NoDataPreservingAnd( p_sink_node, category=CohortCategory.POPULATION ) - pi_combination_node = cohort_logic.NoDataPreservingAnd( + pi_combination_node = logic.NoDataPreservingAnd( pi_sink_node, category=CohortCategory.POPULATION_INTERVENTION ) diff --git a/tests/execution_engine/util/test_cohort_logic.py b/tests/execution_engine/util/test_cohort_logic.py index 1d8b9d54..0c3bbabd 100644 --- a/tests/execution_engine/util/test_cohort_logic.py +++ b/tests/execution_engine/util/test_cohort_logic.py @@ -5,7 +5,7 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.combination.temporal import TimeIntervalType -from execution_engine.util.cohort_logic import ( +from execution_engine.util.logic import ( AllOrNone, And, BooleanFunction, From 9b47d5dee053aec2d5ddb9a5118984ed882d4f1d Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 10 Mar 2025 19:10:02 +0100 Subject: [PATCH 06/16] refactor: Criterion as Symbol baseclass --- .../converter/action/body_positioning.py | 8 +-- .../converter/action/drug_administration.py | 28 ++++----- .../converter/action/procedure.py | 2 +- 
.../characteristic/codeable_concept.py | 10 ++-- .../converter/characteristic/value.py | 10 ++-- .../converter/goal/assessment_scale.py | 8 +-- .../converter/goal/laboratory_value.py | 8 +-- .../converter/goal/ventilator_management.py | 14 ++--- execution_engine/execution_graph/graph.py | 30 +++++----- .../omop/cohort/recommendation.py | 8 +-- execution_engine/omop/criterion/abstract.py | 39 ++++++++++++- execution_engine/task/creator.py | 4 -- execution_engine/task/task.py | 20 +++---- execution_engine/util/logic.py | 58 ++++--------------- scripts/execute.py | 6 +- 15 files changed, 110 insertions(+), 143 deletions(-) diff --git a/execution_engine/converter/action/body_positioning.py b/execution_engine/converter/action/body_positioning.py index f0c5ae46..0fd80ef8 100644 --- a/execution_engine/converter/action/body_positioning.py +++ b/execution_engine/converter/action/body_positioning.py @@ -48,9 +48,7 @@ def from_fhir(cls, action_def: RecommendationPlan.Action) -> Self: def _to_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" - return logic.Symbol( - criterion=ProcedureOccurrence( - concept=self._code, - timing=self._timing, - ) + return ProcedureOccurrence( + concept=self._code, + timing=self._timing, ) diff --git a/execution_engine/converter/action/drug_administration.py b/execution_engine/converter/action/drug_administration.py index c556c528..8e072b0e 100644 --- a/execution_engine/converter/action/drug_administration.py +++ b/execution_engine/converter/action/drug_administration.py @@ -282,21 +282,17 @@ def _to_expression(self) -> logic.BaseExpr: if not self._dosages: # no dosages, just return the drug exposure - return logic.Symbol( - criterion=DrugExposure( - ingredient_concept=self._ingredient_concept, - dose=None, - route=None, - ) + return DrugExposure( + ingredient_concept=self._ingredient_concept, + dose=None, + route=None, ) for dosage in self._dosages: - drug_action = logic.Symbol( - criterion=DrugExposure( - 
ingredient_concept=self._ingredient_concept, - dose=dosage["dose"], - route=dosage.get("route", None), - ) + drug_action = DrugExposure( + ingredient_concept=self._ingredient_concept, + dose=dosage["dose"], + route=dosage.get("route", None), ) extensions = dosage.get("extensions", None) @@ -315,11 +311,9 @@ def _to_expression(self) -> logic.BaseExpr: f"Extension type {extension['type']} not supported yet" ) - ext_criterion = logic.Symbol( - criterion=PointInTimeCriterion( - concept=extension["code"], - value=extension["value"], - ) + ext_criterion = PointInTimeCriterion( + concept=extension["code"], + value=extension["value"], ) # A Conditional Filter returns `right` iff left is POSITIVE, otherwise it returns NEGATIVE diff --git a/execution_engine/converter/action/procedure.py b/execution_engine/converter/action/procedure.py index c8176f13..545bf700 100644 --- a/execution_engine/converter/action/procedure.py +++ b/execution_engine/converter/action/procedure.py @@ -87,4 +87,4 @@ def _to_expression(self) -> logic.Symbol: f"Concept domain {self._code.domain_id} is not supported for {self.__class__.__name__}]" ) - return logic.Symbol(criterion=criterion) + return criterion diff --git a/execution_engine/converter/characteristic/codeable_concept.py b/execution_engine/converter/characteristic/codeable_concept.py index 57d6bef4..18c96eaa 100644 --- a/execution_engine/converter/characteristic/codeable_concept.py +++ b/execution_engine/converter/characteristic/codeable_concept.py @@ -80,10 +80,8 @@ def from_fhir( def to_positive_expression(self) -> logic.BaseExpr: """Converts this characteristic to a Criterion.""" - return logic.Symbol( - criterion=self._criterion_class( - concept=self.value, - value=None, - static=self._concept_value_static, - ) + return self._criterion_class( + concept=self.value, + value=None, + static=self._concept_value_static, ) diff --git a/execution_engine/converter/characteristic/value.py b/execution_engine/converter/characteristic/value.py index 
4c84cf58..5512eaa5 100644 --- a/execution_engine/converter/characteristic/value.py +++ b/execution_engine/converter/characteristic/value.py @@ -35,10 +35,8 @@ def from_fhir( def to_positive_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" - return logic.Symbol( - criterion=self._criterion_class( - concept=self.type, - value=self.value, - static=self._concept_value_static, - ) + return self._criterion_class( + concept=self.type, + value=self.value, + static=self._concept_value_static, ) diff --git a/execution_engine/converter/goal/assessment_scale.py b/execution_engine/converter/goal/assessment_scale.py index ab23c898..28e05bff 100644 --- a/execution_engine/converter/goal/assessment_scale.py +++ b/execution_engine/converter/goal/assessment_scale.py @@ -51,9 +51,7 @@ def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. """ - return logic.Symbol( - criterion=Measurement( - concept=self._code, - value=self._value, - ) + return Measurement( + concept=self._code, + value=self._value, ) diff --git a/execution_engine/converter/goal/laboratory_value.py b/execution_engine/converter/goal/laboratory_value.py index f2e1bcc3..ec10b8f9 100644 --- a/execution_engine/converter/goal/laboratory_value.py +++ b/execution_engine/converter/goal/laboratory_value.py @@ -50,9 +50,7 @@ def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. 
""" - return logic.Symbol( - Measurement( - concept=self._code, - value=self._value, - ) + return Measurement( + concept=self._code, + value=self._value, ) diff --git a/execution_engine/converter/goal/ventilator_management.py b/execution_engine/converter/goal/ventilator_management.py index e7a24001..da8630d6 100644 --- a/execution_engine/converter/goal/ventilator_management.py +++ b/execution_engine/converter/goal/ventilator_management.py @@ -62,16 +62,12 @@ def to_positive_expression(self) -> logic.Symbol: """ if self._code in CUSTOM_GOALS: cls = CUSTOM_GOALS[self._code] - return logic.Symbol( - cls( - concept=self._code, - value=self._value, - ) - ) - - return logic.Symbol( - Measurement( + return cls( concept=self._code, value=self._value, ) + + return Measurement( + concept=self._code, + value=self._value, ) diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index 53893d46..8561b904 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -76,9 +76,8 @@ def from_expression( expr = copy.deepcopy(expr) graph = cls() - base_node = logic.Symbol( - criterion=base_criterion, - ) + base_node = base_criterion + graph.add_node( base_node, category=CohortCategory.BASE, @@ -203,7 +202,7 @@ def to_cytoscape_dict(self) -> dict: "id": id(node), "label": str(node), "class": ( - node.criterion.__class__.__name__ + node.__class__.__name__ if isinstance(node, logic.Symbol) else node.__class__.__name__ ), @@ -226,23 +225,24 @@ def to_cytoscape_dict(self) -> dict: if isinstance(node, logic.Symbol): - node_data["data"]["criterion_id"] = node.criterion._id + assert isinstance( + node, Criterion + ), f"Expected Criterion, got {type(node)}" + + node_data["data"]["criterion_id"] = node.id def criterion_attr(attr: str) -> str | None: - if ( - hasattr(node.criterion, attr) - and getattr(node.criterion, attr) is not None - ): - return str(getattr(node.criterion, attr)) + if hasattr(node, attr) 
and getattr(node, attr) is not None: + return str(getattr(node, attr)) return None try: - if node.criterion.concept is not None: + if node.concept is not None: node_data["data"].update( { "concept": ( - node.criterion.concept.model_dump() - if node.criterion.concept is not None + node.concept.model_dump() + if node.concept is not None else None ), "value": criterion_attr("value"), @@ -267,7 +267,7 @@ def criterion_attr(attr: str) -> str | None: if self.nodes[node]["category"] == CohortCategory.BASE: node_data["data"]["base_criterion"] = str( - node.criterion + node ) # Ensure this is serializable nodes.append(node_data) @@ -311,7 +311,7 @@ def plot(self) -> None: }[self.nodes[node]["category"]] if node.is_Atom: - label = node.criterion.description() + label = node.description() else: label = node.__class__.__name__ symbols = { diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index a11ac7a5..6509fdad 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -188,10 +188,10 @@ def flatten(self) -> list[Criterion]: Retrieve all criteria in a flat list """ - def traverse(expr: logic.Expr) -> list[Criterion]: + def traverse(expr: logic.BaseExpr) -> list[Criterion]: if expr.is_Atom: - assert isinstance(expr, logic.Symbol), f"Expected Symbol, got {expr}" - return [expr.criterion] + assert isinstance(expr, Criterion), f"Expected Symbol, got {expr}" + return [expr] gathered = [] for sub_expr in expr.args: @@ -205,7 +205,7 @@ def population_intervention_pairs(self) -> Iterator[PopulationInterventionPairEx Iterate over all PopulationInterventionPairExpr in the expression tree. 
""" - def traverse(expr: logic.Expr) -> Iterator[PopulationInterventionPairExpr]: + def traverse(expr: logic.BaseExpr) -> Iterator[PopulationInterventionPairExpr]: if isinstance(expr, PopulationInterventionPairExpr): yield expr else: diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index 70c16f0e..94ea2957 100644 --- a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -1,6 +1,6 @@ import copy from abc import ABC, abstractmethod -from typing import Any, Dict, Self, Type, TypedDict, cast +from typing import Any, Callable, Dict, Self, Type, TypedDict, cast import sqlalchemy from sqlalchemy import CTE, Alias, ColumnElement, Date, Integer @@ -91,7 +91,7 @@ def create_conditional_interval_column(condition: ColumnElement) -> ColumnElemen class AbstractCriterion( - logic.Symbol, Serializable, ABC, metaclass=SignatureReprABCMeta + Serializable, logic.Symbol, ABC, metaclass=SignatureReprABCMeta ): """ Abstract base class for Criterion and CriterionCombination. @@ -137,6 +137,41 @@ def __str__(self) -> str: """ return self.description() + def get_instance_variables(self, immutable: bool = False) -> dict | tuple: + """ + Get the instance variables of the criterion. + """ + + if immutable: + return tuple( + sorted(self._init_args.items()) # type: ignore[attr-defined] # this is set due to metaclass + ) + return self._init_args # type: ignore[attr-defined] # this is set due to metaclass + + @classmethod + def _recreate( # type: ignore[override] + cls, kwargs: Any, id: int | None + ) -> "AbstractCriterion": + """ + Recreate an expression from its arguments. + """ + expr = cast(AbstractCriterion, cls(**kwargs)) + + if id is not None: + expr.set_id(id) + + return expr + + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, and arguments. 
+ """ + return self._recreate, (self.get_instance_variables(), self._id) + class Criterion(AbstractCriterion): """A criterion in a recommendation.""" diff --git a/execution_engine/task/creator.py b/execution_engine/task/creator.py index 772fe314..466912e8 100644 --- a/execution_engine/task/creator.py +++ b/execution_engine/task/creator.py @@ -1,5 +1,3 @@ -from typing import cast - import networkx as nx import execution_engine.util.logic as logic @@ -23,14 +21,12 @@ def create_tasks_and_dependencies(graph: nx.DiGraph) -> list[Task]: """ def node_to_task(expr: logic.Expr, attr: dict) -> Task: - criterion = cast(logic.Symbol, expr).criterion if expr.is_Atom else None store_result = attr.get("store_result", False) bind_params = attr.get("bind_params", {}) bind_params["category"] = attr["category"] task = Task( expr=expr, - criterion=criterion, bind_params=bind_params, store_result=store_result, ) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 24d12980..7f10dd1f 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -55,13 +55,11 @@ class Task: def __init__( self, - expr: logic.Expr, - criterion: Criterion | None, + expr: logic.BaseExpr, bind_params: dict | None, store_result: bool = False, ) -> None: self.expr = expr - self.criterion = criterion self.dependencies: list[Task] = [] self.status = TaskStatus.PENDING self.bind_params = bind_params if bind_params is not None else {} @@ -126,13 +124,12 @@ def run( if len(self.dependencies) == 0 or self.expr.is_Atom: # atomic expressions (i.e. 
criterion) - assert ( - self.criterion is not None - ), "criterion shall not be None for atomic expression" - logging.debug(f"Running criterion - '{self.name()}'") + + assert isinstance(self.expr, Criterion), "Dependency is not a Criterion" + result = self.handle_criterion( - self.criterion, bind_params, base_data, observation_window + self.expr, bind_params, base_data, observation_window ) logging.debug(f"Storing results - '{self.name()}'") @@ -566,7 +563,7 @@ def store_result_in_db( return pi_pair_id = bind_params.get("pi_pair_id", None) - criterion_id = self.criterion.id if self.expr.is_Atom else None # type: ignore # when expr.is_Atom, criterion is not None + criterion_id = self.expr.id if self.expr.is_Atom else None # type: ignore # when expr.is_Atom, criterion is not None if self.expr.is_Atom: assert pi_pair_id is None, "pi_pair_id shall be None for criterion" @@ -654,7 +651,4 @@ def __repr__(self) -> str: """ Returns a string representation of the Task object. """ - if self.expr.is_Atom: - return f"Task(criterion={self.expr}, category={self.category})" - else: - return f"Task({self.expr}), category={self.category})" + return f"Task({self.expr}), category={self.category})" diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index d3c7b0c2..d84ed760 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -2,7 +2,6 @@ from datetime import time from typing import Any, Callable, Dict, Type, cast -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.combination.temporal import TimeIntervalType from execution_engine.omop.serializable import Serializable @@ -29,21 +28,6 @@ def _recreate(cls, args: Any, kwargs: dict) -> "Expr": """ return cast(Expr, cls(*args, **kwargs)) - def __new__(cls, *args: Any, **kwargs: Any) -> "BaseExpr": - """ - Initialize a new instance of the class. 
- """ - new_self = super().__new__(cls) - - # we must not allow the __init__ function because of possible infinite recursion when using the __new__ function - # (see https://pdarragh.github.io/blog/2017/05/22/oddities-in-pythons-new-method/) - if "__init__" in cls.__dict__: - raise AttributeError( - f"__init__ is not allowed in subclass {cls.__name__} of BaseExpr" - ) - - return new_self - @abstractmethod def __eq__(self, other: Any) -> bool: """ @@ -177,7 +161,15 @@ def __new__(cls, *args: Any) -> "Expr": :param args: Arguments for the expression. """ - self = cast(Expr, super().__new__(cls, *args)) + self = cast(Expr, super().__new__(cls)) + + # we must not allow the __init__ function because of possible infinite recursion when using the __new__ function + # (see https://pdarragh.github.io/blog/2017/05/22/oddities-in-pythons-new-method/) + if "__init__" in cls.__dict__: + raise AttributeError( + f"__init__ is not allowed in subclass {cls.__name__} of BaseExpr" + ) + self.args = args return self @@ -232,45 +224,15 @@ class Symbol(BaseExpr): Class representing a symbolic variable. """ - criterion: Criterion - - def __new__(cls, criterion: Criterion) -> "Symbol": + def __new__(cls, *args: Any, **kwargs: Any) -> "Symbol": """ Initialize a symbol. - - :param criterion: The criterion of the symbol. """ self = cast(Symbol, super().__new__(cls)) - self.criterion = criterion self.args = () return self - def __eq__(self, other: Any) -> bool: - """ - Check if this symbol is equal to another symbol. - - :param other: The other symbol. - :return: True if equal, False otherwise. - """ - return isinstance(other, Symbol) and self.criterion == other.criterion - - def __hash__(self) -> int: - """ - Get the hash of this symbol. - - :return: Hash of the symbol. - """ - return hash(self.criterion) - - def __repr__(self) -> str: - """ - Represent the symbol. - - :return: Name of the symbol. 
- """ - return self.criterion.description() - @property def is_Atom(self) -> bool: """ diff --git a/scripts/execute.py b/scripts/execute.py index 724fbb80..5be3b871 100644 --- a/scripts/execute.py +++ b/scripts/execute.py @@ -92,9 +92,9 @@ recommendation_package_version = "v1.5.2" urls = [ - "covid19-inpatient-therapy/recommendation/no-therapeutic-anticoagulation", - "sepsis/recommendation/ventilation-plan-ards-tidal-volume", - "covid19-inpatient-therapy/recommendation/ventilation-plan-ards-tidal-volume", + # "sepsis/recommendation/ventilation-plan-ards-tidal-volume", + # "covid19-inpatient-therapy/recommendation/no-therapeutic-anticoagulation", + # "covid19-inpatient-therapy/recommendation/ventilation-plan-ards-tidal-volume", "covid19-inpatient-therapy/recommendation/covid19-ventilation-plan-peep", "covid19-inpatient-therapy/recommendation/prophylactic-anticoagulation", "covid19-inpatient-therapy/recommendation/therapeutic-anticoagulation", From 71c79783cf2f7bf04c586d0223b4a01d3b81d8e5 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Thu, 13 Mar 2025 21:06:35 +0100 Subject: [PATCH 07/16] refactor: serializable handling --- apps/graph/static/script.js | 6 +- apps/rest_api/app/routers/recommendation.py | 2 +- apps/viz-backend/app/main.py | 2 + .../converter/action/procedure.py | 4 +- .../converter/recommendation_factory.py | 5 +- .../converter/time_from_event/abstract.py | 8 +- execution_engine/execution_engine.py | 10 +- execution_engine/execution_graph/graph.py | 75 +- execution_engine/omop/cohort/graph_builder.py | 129 ++++ .../cohort/population_intervention_pair.py | 40 +- .../omop/cohort/recommendation.py | 123 +-- execution_engine/omop/concepts.py | 24 +- execution_engine/omop/criterion/abstract.py | 149 ++-- .../omop/criterion/combination/temporal.py | 21 - execution_engine/omop/criterion/concept.py | 50 +- .../omop/criterion/custom/tidal_volume.py | 4 +- .../omop/criterion/drug_exposure.py | 32 - execution_engine/omop/criterion/factory.py | 67 -- 
execution_engine/omop/criterion/meta.py | 150 ---- execution_engine/omop/criterion/noop.py | 41 + .../omop/criterion/point_in_time.py | 22 +- .../omop/criterion/procedure_occurrence.py | 45 -- .../omop/criterion/visit_detail.py | 21 +- .../omop/criterion/visit_occurrence.py | 15 - execution_engine/omop/serializable.py | 96 --- execution_engine/omop/vocabulary.py | 2 +- execution_engine/task/creator.py | 2 +- execution_engine/task/task.py | 102 ++- execution_engine/util/__init__.py | 13 + execution_engine/util/enum.py | 34 +- execution_engine/util/logic.py | 710 ++++++++---------- execution_engine/util/serializable.py | 534 +++++++++++++ execution_engine/util/temporal_logic_util.py | 156 ++++ execution_engine/util/types.py | 4 + execution_engine/util/value/time.py | 5 + execution_engine/util/value/value.py | 7 + .../execution_engine}/__init__.py | 0 tests/execution_engine/omop/__init__.py | 0 .../omop/criterion/__init__.py | 0 .../{test_cohort_logic.py => test_logic.py} | 2 +- 40 files changed, 1459 insertions(+), 1253 deletions(-) create mode 100644 execution_engine/omop/cohort/graph_builder.py delete mode 100644 execution_engine/omop/criterion/combination/temporal.py delete mode 100644 execution_engine/omop/criterion/factory.py delete mode 100644 execution_engine/omop/criterion/meta.py create mode 100644 execution_engine/omop/criterion/noop.py delete mode 100644 execution_engine/omop/serializable.py create mode 100644 execution_engine/util/serializable.py create mode 100644 execution_engine/util/temporal_logic_util.py rename {execution_engine/omop/criterion/combination => tests/execution_engine}/__init__.py (100%) create mode 100644 tests/execution_engine/omop/__init__.py create mode 100644 tests/execution_engine/omop/criterion/__init__.py rename tests/execution_engine/util/{test_cohort_logic.py => test_logic.py} (99%) diff --git a/apps/graph/static/script.js b/apps/graph/static/script.js index 7ce35453..a6995c98 100644 --- a/apps/graph/static/script.js +++ 
b/apps/graph/static/script.js @@ -129,15 +129,15 @@ async function loadGraph(recommendationId) { return nodeColors[ele.data('category')] || '#666'; // Assign color based on 'type', with a default }, 'shape': function(ele) { - return nodeShapes[ele.data('type')] || 'star'; // Assign color based on 'type', with a default + return nodeShapes[ele.data("is_atom") ? "Symbol" : ele.data('type')] || 'star'; // Assign color based on 'type', with a default }, 'text-valign': 'center', 'color': '#000000', 'width': function(ele) { - return ele.data('type') === 'Symbol' ? '120px': '40px'; + return ele.data('is_atom') ? '120px': '40px'; }, 'height': function(ele) { - return ele.data('type') === 'Symbol' ? '80px': '40px'; + return ele.data('is_atom') ? '80px': '40px'; }, 'font-size': '10px', 'text-wrap': 'wrap', diff --git a/apps/rest_api/app/routers/recommendation.py b/apps/rest_api/app/routers/recommendation.py index 25c7bb07..4aa48051 100644 --- a/apps/rest_api/app/routers/recommendation.py +++ b/apps/rest_api/app/routers/recommendation.py @@ -46,7 +46,7 @@ async def recommendation_criteria( data = [] - for c in recommendation.flatten(): + for c in recommendation.atoms(): data.append( { "description": c.description(), diff --git a/apps/viz-backend/app/main.py b/apps/viz-backend/app/main.py index 6b84ea4b..6c74641d 100644 --- a/apps/viz-backend/app/main.py +++ b/apps/viz-backend/app/main.py @@ -89,6 +89,8 @@ def get_execution_graph(recommendation_id: int, db: Session = Depends(get_db)) - if not result: raise HTTPException(status_code=404, detail="Recommendation not found") + print(result) + # Decode the bytes to a string and parse it as JSON execution_graph = json.loads(result.recommendation_execution_graph.decode("utf-8")) diff --git a/execution_engine/converter/action/procedure.py b/execution_engine/converter/action/procedure.py index 545bf700..bf392d2f 100644 --- a/execution_engine/converter/action/procedure.py +++ b/execution_engine/converter/action/procedure.py @@ -71,7 
+71,7 @@ def _to_expression(self) -> logic.Symbol: # as Observation and Measurement normally expect a value. criterion = Measurement( concept=self._code, - override_value_required=False, + value_required=False, timing=self._timing, ) case "Observation": @@ -79,7 +79,7 @@ def _to_expression(self) -> logic.Symbol: # as Observation and Measurement normally expect a value. criterion = Observation( concept=self._code, - override_value_required=False, + value_required=False, timing=self._timing, ) case _: diff --git a/execution_engine/converter/recommendation_factory.py b/execution_engine/converter/recommendation_factory.py index e80a9ce1..b42b3c14 100644 --- a/execution_engine/converter/recommendation_factory.py +++ b/execution_engine/converter/recommendation_factory.py @@ -63,7 +63,7 @@ def parse_recommendation_from_url( pi_pairs: list[PopulationInterventionPairExpr] = [] base_criterion = PatientsActiveDuringPeriod() - + base_criterion.dict() for rec_plan in rec.plans(): # parse population and create criteria @@ -72,6 +72,9 @@ def parse_recommendation_from_url( # parse intervention and create criteria actions = parser.parse_actions(rec_plan.actions, rec_plan) + # population_expr is assigned a NoDataPreservingAnd to ensure creation of negative intervals + # todo: not sure we really need this - we can just always create negative intervals when store_results=True + # in the graph pi_pair = PopulationInterventionPairExpr( population_expr=population_criteria, intervention_expr=actions, diff --git a/execution_engine/converter/time_from_event/abstract.py b/execution_engine/converter/time_from_event/abstract.py index 5163b48f..fe0bfa15 100644 --- a/execution_engine/converter/time_from_event/abstract.py +++ b/execution_engine/converter/time_from_event/abstract.py @@ -55,9 +55,9 @@ def valid(cls, fhir: Element) -> bool: raise NotImplementedError("must be implemented by class") @abstractmethod - def to_temporal_combination(self, combo: logic.BaseExpr) -> logic.TemporalCount: + 
def to_temporal_combination(self, expr: logic.BaseExpr) -> logic.Expr: """ - Wraps Criterion/CriterionCombinaion with a TemporalIndicatorCombination + Wraps Criterion/CriterionCombination with a TemporalIndicatorCombination """ raise NotImplementedError("must be implemented by class") @@ -122,8 +122,8 @@ def valid(cls, fhir: Element) -> bool: return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code @abstractmethod - def to_temporal_combination(self, combo: logic.BaseExpr) -> logic.TemporalCount: + def to_temporal_combination(self, expr: logic.BaseExpr) -> logic.Expr: """ - Wraps Criterion/CriterionCombinaion with a TemporalIndicatorCombination + Wraps expression with a TemporalIndicatorCombination """ raise NotImplementedError("must be implemented by class") diff --git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 919c6dac..1432e8c8 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -17,8 +17,8 @@ from execution_engine.omop.cohort import PopulationInterventionPairExpr from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.db.celida import tables as result_db -from execution_engine.omop.serializable import Serializable from execution_engine.task import runner +from execution_engine.util.serializable import Serializable class ExecutionEngine: @@ -208,7 +208,7 @@ def load_recommendation_from_database( pi_pair, rec_db.recommendation_id ) - for criterion in recommendation.flatten(): + for criterion in recommendation.atoms(): self.register_criterion(criterion) # All objects in the deserialized object graph must have an id. 
@@ -219,7 +219,7 @@ def load_recommendation_from_database( for pi_pair in recommendation.population_intervention_pairs(): assert pi_pair.id is not None - for criterion in recommendation.flatten(): + for criterion in recommendation.atoms(): assert criterion.id is not None return recommendation @@ -294,7 +294,7 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None pi_pair, recommendation_id=recommendation.id ) - for criterion in recommendation.flatten(): + for criterion in recommendation.atoms(): self.register_criterion(criterion) assert recommendation.id is not None @@ -304,7 +304,7 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None for pi_pair in recommendation.population_intervention_pairs(): assert pi_pair.id is not None - for criterion in recommendation.flatten(): + for criterion in recommendation.atoms(): assert criterion.id is not None # Update the recommendation in the database with the final diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index 8561b904..31dd668d 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -1,5 +1,4 @@ -import copy -from typing import Any +from typing import Any, cast import networkx as nx @@ -40,27 +39,6 @@ def is_sink(self, expr: logic.Expr) -> bool: """ return self.out_degree(expr) == 0 - @classmethod - def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: - """ - Filter (=AND-combine) all symbols by the applied filter function - - Used to filter all intervention criteria (symbols) by the population output in order to exclude - all intervention events outside the population intervals, which may otherwise interfere with corrected - determination of temporal combination, i.e. the presence of an intervention event during some time window. 
- """ - - if isinstance(node, logic.Symbol): - return logic.LeftDependentToggle(left=filter_, right=node) - - if hasattr(node, "args") and isinstance(node.args, tuple): - converted_args = [cls.filter_symbols(a, filter_) for a in node.args] - - if any(a is not b for a, b in zip(node.args, converted_args)): - node.args = tuple(converted_args) - - return node - @classmethod def from_expression( cls, expr: logic.Expr, base_criterion: Criterion, category: CohortCategory @@ -71,9 +49,7 @@ def from_expression( from execution_engine.omop.cohort import PopulationInterventionPairExpr - # we might make changes to the expression (e.g. filtering), so we must preserve - # the original expression from the caller - expr = copy.deepcopy(expr) + expr_hash = hash(expr) graph = cls() base_node = base_criterion @@ -89,42 +65,33 @@ def traverse( parent: logic.Expr | None = None, category: CohortCategory = category, ) -> None: + + graph.add_node(expr, category=category, store_result=False) + + if parent is not None: + assert expr in graph.nodes + assert parent in graph.nodes + graph.add_edge(expr, parent) + if isinstance(expr, PopulationInterventionPairExpr): # special case for PopulationInterventionPairExpr: # we need explicitly set the category of the population and intervention nodes p, i = expr.left, expr.right - # filter all intervention criteria by the output of the population - this is performed to filter out - # intervention events that outside of the population intervals (i.e. the time windows during which - # patients are part of the population) as otherwise events outside of the population time may be picked up - # by Temporal criteria that determine the presence of some event or condition during a specific time window. - - # the following command changes expr, i.e. 
we must not add expr before this command to the graph - - i = cls.filter_symbols(i, filter_=p) - traverse(i, parent=expr, category=CohortCategory.INTERVENTION) traverse(p, parent=expr, category=CohortCategory.POPULATION) - graph.add_node(expr, category=category, store_result=False) - - if parent is not None: - graph.add_edge(expr, parent) - - subgraph = graph.subgraph(nx.ancestors(graph, expr) | {expr}) - - subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr._id)) - - # children are already traversed - return - - graph.add_node(expr, category=category, store_result=False) + # create a subgraph for the pair in order to determine the sink nodes (i.e. the nodes that have no + # outgoing edges) for POPULATION and POPULATION_INTERVENTION and mark them for storing their result + # in the database + subgraph = cast( + ExecutionGraph, graph.subgraph(nx.ancestors(graph, expr) | {expr}) + ) + subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr.id)) - if parent is not None: - graph.add_edge(expr, parent) - - if expr.is_Atom: + elif expr.is_Atom: + assert expr in graph.nodes graph.nodes[expr]["store_result"] = True graph.add_edge(base_node, expr) else: @@ -133,6 +100,9 @@ def traverse( traverse(expr, category=category) + if hash(expr) != expr_hash: + raise ValueError("Expression has been modified during traversal") + return graph def add_node( @@ -219,6 +189,7 @@ def to_cytoscape_dict(self) -> dict: self.nodes[node]["store_result"] ), # Convert to string if necessary "is_sink": self.is_sink(node), + "is_atom": node.is_Atom, "bind_params": self.nodes[node]["bind_params"], } } diff --git a/execution_engine/omop/cohort/graph_builder.py b/execution_engine/omop/cohort/graph_builder.py new file mode 100644 index 00000000..bf4d92ae --- /dev/null +++ b/execution_engine/omop/cohort/graph_builder.py @@ -0,0 +1,129 @@ +import copy + +import execution_engine.util.logic as logic +from execution_engine.constants import CohortCategory +from 
execution_engine.execution_graph import ExecutionGraph +from execution_engine.omop.cohort.population_intervention_pair import ( + PopulationInterventionPairExpr, +) +from execution_engine.omop.criterion.abstract import Criterion + + +class RecommendationGraphBuilder: + """ + A builder class for constructing ExecutionGraph objects based on + population/intervention expressions. It provides utility methods to filter + intervention criteria by population constraints, and then converts the + filtered expression into an ExecutionGraph ready for execution and storage. + """ + + @classmethod + def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: + """ + Filter (=AND-combine) all symbols by the applied filter function + + Used to filter all intervention criteria (symbols) by the population output in order to exclude + all intervention events outside the population intervals, which may otherwise interfere with corrected + determination of temporal combination, i.e. the presence of an intervention event during some time window. + + :param node: The expression node to be filtered. + :type node: logic.Expr + :param filter_: The filter expression to AND-combine with symbols in the node. + :type filter_: logic.Expr + :return: A new expression in which all symbols are constrained by the filter expression. + :rtype: logic.Expr + """ + + if isinstance(node, logic.Symbol): + return logic.LeftDependentToggle(left=filter_, right=node) + elif isinstance(node, logic.Expr): + converted_args = [cls.filter_symbols(a, filter_) for a in node.args] + + if any(a is not b for a, b in zip(node.args, converted_args)): + node.update_args(*converted_args) + + return node + + @classmethod + def filter_intervention_criteria_by_population(cls, expr: logic.Expr) -> logic.Expr: + """ + Filter all intervention criteria in a given expression by the population part of the expression. + + :param expr: The expression that may contain population and intervention parts. 
+ :type expr: logic.Expr + :return: A new expression where all intervention symbols are constrained by the population intervals. + :rtype: logic.Expr + """ + + from execution_engine.omop.cohort import PopulationInterventionPairExpr + + # we might make changes to the expression (e.g. filtering), so we must preserve + # the original expression from the caller + expr = copy.deepcopy(expr) + + def traverse( + expr: logic.Expr, + ) -> None: + if isinstance(expr, PopulationInterventionPairExpr): + p, i = expr.left, expr.right + + # filter all intervention criteria by the output of the population - this is performed to filter out + # intervention events that outside of the population intervals (i.e. the time windows during which + # patients are part of the population) as otherwise events outside of the population time may be picked up + # by Temporal criteria that determine the presence of some event or condition during a specific time window. + i = cls.filter_symbols(i, filter_=p) + + expr.update_args(p, i) + + traverse(i) + traverse(p) + + elif not expr.is_Atom: + for child in expr.args: + traverse(child) + + traverse(expr) + + return expr + + @classmethod + def build(cls, expr: logic.Expr, base_criterion: Criterion) -> ExecutionGraph: + """ + Build an ExecutionGraph for a population/intervention expression. + + If the expression is a PopulationInterventionPairExpr, it is wrapped in a + NonSimplifiableAnd to ensure a top-level result entry is generated in the database. + Then the expression is filtered and converted into an ExecutionGraph with the + appropriate sink nodes and edges. + + :param expr: The population/intervention expression to build the graph from. + :type expr: logic.Expr + :param base_criterion: The base criterion used to label the execution graph. + :type base_criterion: Criterion + :return: The constructed ExecutionGraph for the given expression. 
+ :rtype: ExecutionGraph + """ + if isinstance(expr, PopulationInterventionPairExpr): + expr = logic.NonSimplifiableAnd(expr) + + # Make sure the expr is filtered + expr_filtered = cls.filter_intervention_criteria_by_population(expr) + + graph = ExecutionGraph.from_expression( + expr_filtered, + base_criterion=base_criterion, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + p_sink_nodes = graph.sink_nodes(CohortCategory.POPULATION) + graph.set_sink_nodes_store( + bind_params={}, desired_category=CohortCategory.POPULATION_INTERVENTION + ) + + p_combination_node = logic.NoDataPreservingOr(*p_sink_nodes) + graph.add_node( + p_combination_node, store_result=True, category=CohortCategory.POPULATION + ) + graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes) + + return graph diff --git a/execution_engine/omop/cohort/population_intervention_pair.py b/execution_engine/omop/cohort/population_intervention_pair.py index bdcbb535..ad2095bb 100644 --- a/execution_engine/omop/cohort/population_intervention_pair.py +++ b/execution_engine/omop/cohort/population_intervention_pair.py @@ -2,10 +2,9 @@ import execution_engine.util.logic as logic from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.serializable import Serializable -class PopulationInterventionPairExpr(logic.LeftDependentToggle, Serializable): +class PopulationInterventionPairExpr(logic.LeftDependentToggle): """ A logical expression that ties together a population expression and an intervention expression, plus any extra info like name, url, base_criterion, etc. @@ -71,29 +70,20 @@ def base_criterion(self) -> Criterion: """ return self._base_criterion - def __repr__(self) -> str: + def dict(self, include_id: bool = False) -> dict: """ - Return a string representation of the object. + Get a dictionary representation of the object. 
""" - return ( - f"{self.__class__.__name__}(" - f"name={self.name!r}, " - f"url={self.url!r}, " - f"base_criterion={self.base_criterion!r}, " - f"left={self.left!r}, right={self.right!r}, " + data = super().dict(include_id=include_id) + del data["data"]["left"] + del data["data"]["right"] + data["data"].update( + { + "population_expr": self.left.dict(include_id=include_id), + "intervention_expr": self.right.dict(include_id=include_id), + "name": self.name, + "url": self.url, + "base_criterion": self.base_criterion.dict(include_id=include_id), + } ) - - def get_instance_variables(self, immutable: bool = False) -> dict | tuple: - """ - Get the instance variables of the object. - """ - # Include the base class instance variables plus the ones we added: - base_vars = cast(dict, super().get_instance_variables(immutable=False)) - base_vars["name"] = self.name - base_vars["url"] = self.url - base_vars["base_criterion"] = self.base_criterion - - # Convert to tuple if needed - if immutable: - return tuple(sorted(base_vars.items())) - return base_vars + return data diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index 6509fdad..ee757143 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -1,7 +1,6 @@ import re -from typing import Any, Dict, Iterator, Self, cast +from typing import Iterator -import networkx as nx from sqlalchemy import ( Column, Date, @@ -17,16 +16,16 @@ import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.execution_graph import ExecutionGraph +from execution_engine.omop.cohort.graph_builder import RecommendationGraphBuilder from execution_engine.omop.cohort.population_intervention_pair import ( PopulationInterventionPairExpr, ) from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.factory import criterion_factory from 
execution_engine.omop.db.celida.tables import ResultInterval -from execution_engine.omop.serializable import Serializable +from execution_engine.util.serializable import SerializableDataClass -class Recommendation(Serializable): +class Recommendation(SerializableDataClass): """ A recommendation OMOP as a collection of separate population/intervention pairs. @@ -61,15 +60,9 @@ def __repr__(self) -> str: """ Get the string representation of the recommendation. """ - pi_repr = "\n".join( - [(" " + line) for line in repr(self._expr).split("\n")] - ).strip() - pi_repr = ( - pi_repr[0] + "\n " + pi_repr[1:-2] + pi_repr[-2] + "\n " + pi_repr[-1] - ) return ( f"{self.__class__.__name__}(\n" - f" pi_pairs={pi_repr},\n" + f" expr={repr(self._expr)},\n" f" base_criterion={repr(self._base_criterion)},\n" f" name={repr(self._name)},\n" f" title={repr(self._title)},\n" @@ -137,68 +130,14 @@ def execution_graph(self) -> ExecutionGraph: execution maps of the individual population/intervention pairs of the recommendation. 
""" - # p_sink_nodes = [] - # pi_sink_nodes = [] - - common_graph = ExecutionGraph.from_expression( - self._expr, - base_criterion=self._base_criterion, - category=CohortCategory.POPULATION_INTERVENTION, - ) - - # for pi_pair in self.population_intervention_pairs(): - # subgraph: ExecutionGraph = cast(ExecutionGraph, common_graph.subgraph(nx.ancestors(common_graph, pi_pair) | {pi_pair})) - # p_sink_nodes.append(subgraph.sink_node(CohortCategory.POPULATION)) - # pi_sink_nodes.append(subgraph.sink_node(CohortCategory.POPULATION_INTERVENTION)) - - p_sink_nodes = common_graph.sink_nodes(CohortCategory.POPULATION) - - p_combination_node = logic.NoDataPreservingOr( - *common_graph.sink_nodes(CohortCategory.POPULATION) - ) - # pi_combination_node = logic.NoDataPreservingAnd( - # *pi_sink_nodes, - # ) - - common_graph.add_node( - p_combination_node, store_result=True, category=CohortCategory.POPULATION - ) - - # common_graph.add_node( - # pi_combination_node, - # store_result=True, - # category=CohortCategory.POPULATION_INTERVENTION, - # ) + return RecommendationGraphBuilder.build(self._expr, self._base_criterion) - common_graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes) - # common_graph.add_edges_from((src, pi_combination_node) for src in pi_sink_nodes) - - import json - - with open("/home/glichtner/cyto.json", "w") as f: - json.dump({"elements": common_graph.to_cytoscape_dict()}, f, indent=4) - - if not nx.is_directed_acyclic_graph(common_graph): - raise ValueError("The recommendation execution graph is not a DAG.") - - return common_graph - - def flatten(self) -> list[Criterion]: + def atoms(self) -> Iterator[Criterion]: """ Retrieve all criteria in a flat list """ - - def traverse(expr: logic.BaseExpr) -> list[Criterion]: - if expr.is_Atom: - assert isinstance(expr, Criterion), f"Expected Symbol, got {expr}" - return [expr] - - gathered = [] - for sub_expr in expr.args: - gathered.extend(traverse(sub_expr)) - return gathered - - return 
[self._base_criterion] + traverse(self._expr) + yield self._base_criterion + yield from self._expr.atoms() def population_intervention_pairs(self) -> Iterator[PopulationInterventionPairExpr]: """ @@ -276,47 +215,5 @@ def reset_state(self) -> None: for pi_pair in self.population_intervention_pairs(): pi_pair._id = None - for criterion in self.flatten(): + for criterion in self.atoms(): criterion._id = None - - def dict(self) -> dict: - """ - Get the combination as a dictionary. - """ - base_criterion = self._base_criterion - return { - "expr": self._expr.dict(), - "base_criterion": { - "class_name": base_criterion.__class__.__name__, - "data": base_criterion.dict(), - }, - "recommendation_name": self._name, - "recommendation_title": self._title, - "recommendation_url": self._url, - "recommendation_version": self._version, - "recommendation_package_version": self._package_version, - "recommendation_description": self._description, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create a combination from a dictionary. 
- """ - base_criterion = criterion_factory(**data["base_criterion"]) - assert isinstance( - base_criterion, Criterion - ), "Base criterion must be a Criterion" - - return cls( - expr=cast( - logic.BooleanFunction, logic.BooleanFunction.from_dict(data["expr"]) - ), - base_criterion=base_criterion, - name=data["recommendation_name"], - title=data["recommendation_title"], - url=data["recommendation_url"], - version=data["recommendation_version"], - description=data["recommendation_description"], - package_version=data["recommendation_package_version"], - ) diff --git a/execution_engine/omop/concepts.py b/execution_engine/omop/concepts.py index 57c9ec37..96d135bd 100644 --- a/execution_engine/omop/concepts.py +++ b/execution_engine/omop/concepts.py @@ -1,7 +1,10 @@ import pandas as pd from pydantic import BaseModel +from execution_engine.util import serializable + +@serializable.register_class class Concept(BaseModel, frozen=True): # type: ignore """Represents an OMOP Standard Vocabulary concept.""" @@ -35,23 +38,14 @@ def is_custom(self) -> bool: return self.concept_id < 0 -class CustomConcept(Concept): +@serializable.register_class +class CustomConcept(Concept, frozen=True): """Represents a custom concept.""" - def __init__( - self, name: str, concept_code: str, domain_id: str, vocabulary_id: str - ) -> None: - """Creates a custom concept.""" - super().__init__( - concept_id=-1, - concept_name=name, - concept_code=concept_code, - domain_id=domain_id, - vocabulary_id=vocabulary_id, - concept_class_id="Custom", - standard_concept=None, - invalid_reason=None, - ) + concept_id: int = -1 + concept_class_id: str = "Custom" + standard_concept: str | None = None + invalid_reason: str | None = None @property def id(self) -> int: diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index 94ea2957..5b9a42d1 100644 --- a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -1,6 +1,5 @@ 
-import copy -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Self, Type, TypedDict, cast +from abc import abstractmethod +from typing import Type, TypedDict, cast import sqlalchemy from sqlalchemy import CTE, Alias, ColumnElement, Date, Integer @@ -11,7 +10,6 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.meta import SignatureReprABCMeta from execution_engine.omop.db.base import DateTimeWithTimeZone from execution_engine.omop.db.celida.tables import IntervalTypeEnum, ResultInterval from execution_engine.omop.db.omop.tables import ( @@ -25,13 +23,13 @@ VisitDetail, VisitOccurrence, ) -from execution_engine.omop.serializable import Serializable from execution_engine.util import logic from execution_engine.util.interval import IntervalType +from execution_engine.util.serializable import SerializableDataClassABC from execution_engine.util.sql import SelectInto, select_into from execution_engine.util.types import PersonIntervals, TimeRange -__all__ = ["AbstractCriterion", "Criterion"] +__all__ = ["Criterion"] Domain = TypedDict( "Domain", @@ -90,90 +88,7 @@ def create_conditional_interval_column(condition: ColumnElement) -> ColumnElemen ) -class AbstractCriterion( - Serializable, logic.Symbol, ABC, metaclass=SignatureReprABCMeta -): - """ - Abstract base class for Criterion and CriterionCombination. - """ - - _base: bool = False - - def is_base(self) -> bool: - """ - Check if this criterion is the base criterion. - """ - return self._base - - def set_base(self, value: bool = True) -> None: - """ - Set the base criterion. - """ - self._base = value - - @property - def type(self) -> str: - """ - Get the type of the criterion. - """ - return self.__class__.__name__ - - def copy(self) -> "AbstractCriterion": - """ - Copy the criterion. 
- """ - return copy.copy(self) - - @abstractmethod - def description(self) -> str: - """ - Return a description of the criterion. - """ - raise NotImplementedError() - - def __str__(self) -> str: - """ - Get the name of the criterion. - """ - return self.description() - - def get_instance_variables(self, immutable: bool = False) -> dict | tuple: - """ - Get the instance variables of the criterion. - """ - - if immutable: - return tuple( - sorted(self._init_args.items()) # type: ignore[attr-defined] # this is set due to metaclass - ) - return self._init_args # type: ignore[attr-defined] # this is set due to metaclass - - @classmethod - def _recreate( # type: ignore[override] - cls, kwargs: Any, id: int | None - ) -> "AbstractCriterion": - """ - Recreate an expression from its arguments. - """ - expr = cast(AbstractCriterion, cls(**kwargs)) - - if id is not None: - expr.set_id(id) - - return expr - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, and arguments. - """ - return self._recreate, (self.get_instance_variables(), self._id) - - -class Criterion(AbstractCriterion): +class Criterion(SerializableDataClassABC, logic.Symbol): """A criterion in a recommendation.""" _OMOP_TABLE: Type[Base] @@ -202,6 +117,11 @@ class Criterion(AbstractCriterion): during the recommendation period). """ + _base: bool = False + """ + Specifies whether this criterion is the base criterion (i.e. the criterion that selects the initial cohort). + """ + DOMAINS: dict[str, Domain] = { "condition": { "table": ConditionOccurrence, @@ -247,6 +167,34 @@ class Criterion(AbstractCriterion): Flag to indicate whether the filter_datetime function has been called. """ + def __init__(self) -> None: + super().__init__() + + def is_base(self) -> bool: + """ + Check if this criterion is the base criterion. 
+ """ + return self._base + + def set_base(self, value: bool = True) -> None: + """ + Set the base criterion. + """ + self._base = value + + @abstractmethod + def description(self) -> str: + """ + Return a description of the criterion. + """ + raise NotImplementedError() + + def __str__(self) -> str: + """ + Get the name of the criterion. + """ + return self.description() + def _set_omop_variables_from_domain(self, domain_id: str) -> None: """ Set the OMOP table and column prefix based on the domain ID. @@ -264,6 +212,19 @@ def _set_omop_variables_from_domain(self, domain_id: str) -> None: self._static = cast(bool, domain["static"]) self._table = cast(Base, domain["table"]).__table__.alias(self.table_alias) + # @property + # def type(self) -> str: + # """ + # Get the type of the criterion. + # """ + # return self.__class__.__name__ + # + # def copy(self) -> "AbstractCriterion": + # """ + # Copy the criterion. + # """ + # return copy.copy(self) + @property def table_alias(self) -> str: """ @@ -652,14 +613,6 @@ def sql_insert_into_result_table(query: Select) -> SelectInto: return query - @classmethod - @abstractmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create a criterion from a JSON object. - """ - raise NotImplementedError() - def cte_interval_starts( self, query: Select, diff --git a/execution_engine/omop/criterion/combination/temporal.py b/execution_engine/omop/criterion/combination/temporal.py deleted file mode 100644 index 9988ba1a..00000000 --- a/execution_engine/omop/criterion/combination/temporal.py +++ /dev/null @@ -1,21 +0,0 @@ -from enum import StrEnum - - -class TimeIntervalType(StrEnum): - """ - Types of time intervals to aggregate criteria over. 
- """ - - MORNING_SHIFT = "morning_shift" - AFTERNOON_SHIFT = "afternoon_shift" - NIGHT_SHIFT = "night_shift" - NIGHT_SHIFT_BEFORE_MIDNIGHT = "night_shift_before_midnight" - NIGHT_SHIFT_AFTER_MIDNIGHT = "night_shift_after_midnight" - DAY = "day" - ANY_TIME = "any_time" - - def __repr__(self) -> str: - """ - Get the string representation of the time interval type. - """ - return f"{self.__class__.__name__}.{self.name}" diff --git a/execution_engine/omop/criterion/concept.py b/execution_engine/omop/criterion/concept.py index c8ec3faf..d5d7b39e 100644 --- a/execution_engine/omop/criterion/concept.py +++ b/execution_engine/omop/criterion/concept.py @@ -1,14 +1,10 @@ -from typing import Any, Dict, cast - from sqlalchemy.sql import Select from execution_engine.constants import OMOPConcepts from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.meta import SignatureReprMeta from execution_engine.util.types import Timing from execution_engine.util.value import Value -from execution_engine.util.value.factory import value_factory __all__ = ["ConceptCriterion"] @@ -25,7 +21,7 @@ # TODO: Only use weight etc from the current encounter/visit! -class ConceptCriterion(Criterion, metaclass=SignatureReprMeta): +class ConceptCriterion(Criterion): """ Abstract class for a criterion based on an OMOP concept and optional value. 
@@ -45,7 +41,7 @@ def __init__( value: Value | None = None, static: bool | None = None, timing: Timing | None = None, - override_value_required: bool | None = None, + value_required: bool | None = None, ): super().__init__() @@ -67,10 +63,8 @@ def __init__( # if static is None, then the criterion is static if the concept is in the STATIC_CLINICAL_CONCEPTS list self._static = concept.concept_id in STATIC_CLINICAL_CONCEPTS - if override_value_required is not None and isinstance( - override_value_required, bool - ): - self._value_required = override_value_required + if value_required is not None and isinstance(value_required, bool): + self._value_required = value_required @property def concept(self) -> Concept: @@ -132,39 +126,3 @@ def description(self) -> str: desc += "]" return desc - - def dict(self) -> dict[str, Any]: - """ - Get a JSON representation of the criterion. - """ - return { - "concept": self._concept.model_dump(), - "value": ( - self._value.model_dump(include_meta=True) - if self._value is not None - else None - ), - "static": self._static, - "timing": ( - self._timing.model_dump(include_meta=True) - if self._timing is not None - else None - ), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ConceptCriterion": - """ - Create a criterion from a JSON representation. 
- """ - - return cls( - concept=Concept(**data["concept"]), - value=( - cast(Value, value_factory(**data["value"])) - if data["value"] is not None - else None - ), - static=data["static"], - timing=Timing(**data["timing"]) if data["timing"] is not None else None, - ) diff --git a/execution_engine/omop/criterion/custom/tidal_volume.py b/execution_engine/omop/criterion/custom/tidal_volume.py index d0ef3801..39480698 100644 --- a/execution_engine/omop/criterion/custom/tidal_volume.py +++ b/execution_engine/omop/criterion/custom/tidal_volume.py @@ -54,7 +54,7 @@ def __init__( value: Value | None = None, static: bool | None = None, timing: Timing | None = None, - override_value_required: bool | None = None, + value_required: bool | None = None, forward_fill: bool = True, ): super().__init__( @@ -62,7 +62,7 @@ def __init__( value=value, static=static, timing=timing, - override_value_required=override_value_required, + value_required=value_required, ) self._table = self._cte() diff --git a/execution_engine/omop/criterion/drug_exposure.py b/execution_engine/omop/criterion/drug_exposure.py index fd08b032..90c1668c 100644 --- a/execution_engine/omop/criterion/drug_exposure.py +++ b/execution_engine/omop/criterion/drug_exposure.py @@ -1,5 +1,4 @@ import logging -from typing import Any, Dict from sqlalchemy import Column, and_, case, func, select, true from sqlalchemy.sql import Select @@ -17,7 +16,6 @@ from execution_engine.util.interval import IntervalType from execution_engine.util.sql import SelectInto from execution_engine.util.types import Dosage -from execution_engine.util.value.factory import value_factory __all__ = ["DrugExposure"] @@ -349,33 +347,3 @@ def description(self) -> str: parts.append(f"route={route.concept_name}") return f"{self.__class__.__name__}[" + ", ".join(parts) + "]" - - def dict(self) -> dict[str, Any]: - """ - Return a dictionary representation of the criterion. 
- """ - return { - "ingredient_concept": self._ingredient_concept.model_dump(), - "dose": ( - self._dose.model_dump(include_meta=True) - if self._dose is not None - else None - ), - "route": self._route.model_dump() if self._route is not None else None, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "DrugExposure": - """ - Create a drug exposure criterion from a dictionary representation. - """ - - dose = value_factory(**data["dose"]) if data["dose"] is not None else None - - assert dose is None or isinstance(dose, Dosage), "Dose must be a Dosage or None" - - return cls( - ingredient_concept=Concept(**data["ingredient_concept"]), - dose=dose, - route=Concept(**data["route"]) if data["route"] is not None else None, - ) diff --git a/execution_engine/omop/criterion/factory.py b/execution_engine/omop/criterion/factory.py deleted file mode 100644 index f6f6bbfd..00000000 --- a/execution_engine/omop/criterion/factory.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Type - -from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence -from execution_engine.omop.criterion.custom import TidalVolumePerIdealBodyWeight -from execution_engine.omop.criterion.drug_exposure import DrugExposure -from execution_engine.omop.criterion.measurement import Measurement -from execution_engine.omop.criterion.observation import Observation -from execution_engine.omop.criterion.point_in_time import PointInTimeCriterion -from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence -from execution_engine.omop.criterion.visit_detail import VisitDetail -from execution_engine.omop.criterion.visit_occurrence import ( - ActivePatients, - PatientsActiveDuringPeriod, - VisitOccurrence, -) - -__all__ = ["criterion_factory", "register_criterion_class"] - -from execution_engine.util import logic - -class_map: dict[str, Type[logic.BaseExpr]] = { - # "ConceptCriterion": ConceptCriterion, # is an Abstract class - 
"LogicalCriterionCombination": logic.BooleanFunction, - "TemporalCriterionCombination": logic.BooleanFunction, - "NonCommutativeLogicalCriterionCombination": logic.BooleanFunction, - "ConditionOccurrence": ConditionOccurrence, - "DrugExposure": DrugExposure, - "Measurement": Measurement, - "Observation": Observation, - "ProcedureOccurrence": ProcedureOccurrence, - "VisitOccurrence": VisitOccurrence, - "ActivePatients": ActivePatients, - "PatientsActiveDuringPeriod": PatientsActiveDuringPeriod, - "TidalVolumePerIdealBodyWeight": TidalVolumePerIdealBodyWeight, - "VisitDetail": VisitDetail, - "PointInTimeCriterion": PointInTimeCriterion, - "PersonalWindowTemporalIndicatorCombination": logic.BooleanFunction, -} - - -def register_criterion_class( - class_name: str, - criterion_class: Type[logic.BaseExpr], -) -> None: - """ - Register a criterion class. - - :param class_name: The name of the criterion class. - :param criterion_class: The criterion class. - """ - class_map[class_name] = criterion_class - - -def criterion_factory(class_name: str, data: dict) -> logic.BaseExpr: - """ - Create a criterion from a dictionary representation. - - :param class_name: The name of the criterion class. - :param data: The dictionary representation of the criterion. - :return: The criterion. - :raises ValueError: If the class name is not recognized. - """ - - if class_name not in class_map: - raise ValueError(f"Unknown criterion class {class_name}") - - return class_map[class_name].from_dict(data) diff --git a/execution_engine/omop/criterion/meta.py b/execution_engine/omop/criterion/meta.py deleted file mode 100644 index 9012e45f..00000000 --- a/execution_engine/omop/criterion/meta.py +++ /dev/null @@ -1,150 +0,0 @@ -import abc -import inspect -from typing import Any, Type, TypeVar - -T = TypeVar("T", bound="SignatureReprMeta") - - -class SignatureReprMeta(type): - """ - A metaclass that automatically captures constructor arguments and generates a __repr__ method. 
- - This metaclass wraps the `__init__` method of a class to store the arguments passed at instantiation, - allowing `__repr__` to dynamically generate an informative string representation, omitting default values. - - Features: - - Captures constructor arguments at instantiation time (`self._init_args`). - - Automatically generates a `__repr__` if the class does not define one. - - Ensures that `__repr__` only displays arguments that differ from the default values. - - Prevents redundant re-capturing when a parent class’s `__init__` is invoked via `super()`. - - Retains the original `__init__` function signature for accurate introspection. - - Usage: - ```python - class MyClass(metaclass=SignatureReprMeta): - def __init__(self, x, y=10, z=None): - self.x = x - self.y = y - self.z = z - - obj = MyClass(5, z="test") - print(obj) # Output: MyClass(x=5, z='test') (y is omitted since it uses the default 10) - ``` - - """ - - def __new__( - mcs: Type[T], name: str, bases: tuple[type, ...], namespace: dict[str, Any] - ) -> T: - """ - Wrap the __init__ method and attach a default __repr__ if not defined. - """ - original_init = namespace.get("__init__") - - if not original_init: - # Try to find an __init__ in the bases - for base in reversed(bases): - if base.__init__ is not object.__init__: - original_init = base.__init__ - break - - # We'll define a new __init__ only if there's an actual one to wrap - if original_init: - - def __init__(self: object, *args: Any, **kwargs: Any) -> None: - if type(self) is cls: - assert original_init is not None, "No valid __init__ found!" - sig = inspect.signature(original_init) - bound = sig.bind(self, *args, **kwargs) - bound.apply_defaults() - all_args = dict(bound.arguments) - all_args.pop("self", None) - self._init_args = all_args # type: ignore[attr-defined] - - original_init(self, *args, **kwargs) - - # The rest is basically the same - # But we must create the class first so we can set new_init.__signature__ = ... 
- cls = super().__new__(mcs, name, bases, namespace) - - # If we actually did define new_init, attach it - if original_init: - # Replace/attach the new __init__ to cls - setattr(cls, "__init__", __init__) - - # Manually override the function's signature - init_sig = inspect.signature(original_init) - __init__.__signature__ = init_sig # type: ignore[attr-defined] - - original_repr = namespace.get("__repr__") - - if not original_repr: - # Try to find an __repr__ in the bases - for base in reversed(bases): - if base.__repr__ is not object.__repr__: - original_repr = base.__repr__ - break - - # If no user-defined __repr__, attach a default - if not original_repr or getattr( - original_repr, "__signature_repr_generated__", False - ): - - def __repr__(self: object) -> str: - # Case 1: No real __init__ found at all => just return ClassName() - if original_init is None: - return f"{name}()" - - # Case 2: If the class or its parents do define an __init__, we check - # for _init_args. If the class isn't wrapping init, your class would - # have to set _init_args itself (or you'll just get a normal object repr). - if not hasattr(self, "_init_args"): - return super(type(self), self).__repr__() - - # Build param=value only if they differ from default - sig = inspect.signature(original_init) - parts = [] - for param_name, param in sig.parameters.items(): - if param_name == "self": - continue - default = param.default - if ( - param_name in self._init_args - and self._init_args[param_name] != default - ): - parts.append( - f"{param_name}={repr(self._init_args[param_name])}" - ) - return f"{name}({', '.join(parts)})" - - # Tag it so children know it's auto-generated - __repr__.__signature_repr_generated__ = True # type: ignore[attr-defined] - - setattr(cls, "__repr__", __repr__) - - return cls - - -class SignatureReprABCMeta(SignatureReprMeta, abc.ABCMeta): - """ - A metaclass combining `SignatureReprMeta` and `ABCMeta`. 
- - This metaclass extends `SignatureReprMeta`, allowing abstract base classes (`ABC`) to inherit - automatic argument capturing and dynamic `__repr__` generation. - - Usage: - ```python - class AbstractExample(metaclass=SignatureReprABCMeta): - @abc.abstractmethod - def some_method(self): - pass - - class ConcreteExample(AbstractExample): - def __init__(self, value, flag=True): - self.value = value - self.flag = flag - - obj = ConcreteExample(42) - print(obj) # Output: ConcreteExample(value=42) (flag is omitted since it uses the default True) - ``` - """ diff --git a/execution_engine/omop/criterion/noop.py b/execution_engine/omop/criterion/noop.py new file mode 100644 index 00000000..2b01ed8c --- /dev/null +++ b/execution_engine/omop/criterion/noop.py @@ -0,0 +1,41 @@ +from sqlalchemy import Select, select + +from execution_engine.omop.criterion.abstract import ( + Criterion, + column_interval_type, + observation_end_datetime, + observation_start_datetime, +) +from execution_engine.util.interval import IntervalType + + +class NoopCriterion(Criterion): + """ + Select patients who are post-surgical in the timeframe between the day of the surgery and 6 days after the surgery. + """ + + _static = True + + def _create_query(self) -> Select: + """ + Get the SQL Select query for data required by this criterion. + """ + subquery = self.base_query().subquery() + + query = select( + subquery.c.person_id, + column_interval_type(IntervalType.POSITIVE), + observation_start_datetime.label("interval_start"), + observation_end_datetime.label("interval_end"), + ) + + query = self._filter_base_persons(query, c_person_id=subquery.c.person_id) + query = self._filter_datetime(query) + + return query + + def description(self) -> str: + """ + Get a description of the criterion. 
+ """ + return self.__class__.__name__ diff --git a/execution_engine/omop/criterion/point_in_time.py b/execution_engine/omop/criterion/point_in_time.py index 2fdb6a0d..d79342a0 100644 --- a/execution_engine/omop/criterion/point_in_time.py +++ b/execution_engine/omop/criterion/point_in_time.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, cast - from sqlalchemy import CTE, ColumnElement, Select, select from execution_engine.omop.concepts import Concept @@ -27,7 +25,7 @@ def __init__( value: Value | None = None, static: bool | None = None, timing: Timing | None = None, - override_value_required: bool | None = None, + value_required: bool | None = None, forward_fill: bool = True, ): super().__init__( @@ -35,7 +33,7 @@ def __init__( value=value, static=static, timing=timing, - override_value_required=override_value_required, + value_required=value_required, ) self._forward_fill = forward_fill @@ -103,22 +101,6 @@ def _create_query(self) -> Select: return query - def dict(self) -> dict[str, Any]: - """ - Get a JSON representation of the criterion. - """ - from_super = super().dict() - return from_super | {"forward_fill": self._forward_fill} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PointInTimeCriterion": - """ - Create a criterion from a JSON representation. 
- """ - object = cast("PointInTimeCriterion", super().from_dict(data)) - object._forward_fill = data.get("forward_fill", True) # Backward compat - return object - def process_data( self, data: PersonIntervals, diff --git a/execution_engine/omop/criterion/procedure_occurrence.py b/execution_engine/omop/criterion/procedure_occurrence.py index 0084afe5..c2f853f9 100644 --- a/execution_engine/omop/criterion/procedure_occurrence.py +++ b/execution_engine/omop/criterion/procedure_occurrence.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, cast - from sqlalchemy import ColumnElement, case, func, select from sqlalchemy.sql import Select @@ -13,7 +11,6 @@ from execution_engine.util.interval import IntervalType from execution_engine.util.types import Timing from execution_engine.util.value import ValueNumber -from execution_engine.util.value.factory import value_factory from execution_engine.util.value.time import ValueCount __all__ = ["ProcedureOccurrence"] @@ -160,45 +157,3 @@ def description(self) -> str: parts.append(f"dose={str(self._timing)}") return f"{self.__class__.__name__}[" + ", ".join(parts) + "]" - - def dict(self) -> dict[str, Any]: - """ - Return a dictionary representation of the criterion. - """ - assert self._concept is not None, "Concept must be set" - - return { - "concept": self._concept.model_dump(), - "value": ( - self._value.model_dump(include_meta=True) - if self._value is not None - else None - ), - "timing": ( - self._timing.model_dump(include_meta=True) - if self._timing is not None - else None - ), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ProcedureOccurrence": - """ - Create a procedure occurrence criterion from a dictionary representation. 
- """ - - value = value_factory(**data["value"]) if data["value"] is not None else None - timing = value_factory(**data["timing"]) if data["timing"] is not None else None - - assert ( - isinstance(value, ValueNumber) or value is None - ), "value must be a ValueNumber" - assert ( - isinstance(timing, ValueNumber | Timing) or timing is None - ), "timing must be a ValueNumber" - - return cls( - concept=Concept(**data["concept"]), - value=value, - timing=cast(Timing, timing), - ) diff --git a/execution_engine/omop/criterion/visit_detail.py b/execution_engine/omop/criterion/visit_detail.py index de5f74d2..66bacdde 100644 --- a/execution_engine/omop/criterion/visit_detail.py +++ b/execution_engine/omop/criterion/visit_detail.py @@ -1,4 +1,6 @@ -from typing import Any +from execution_engine.omop.concepts import Concept +from execution_engine.util.types import Timing +from execution_engine.util.value import Value __all__ = ["VisitDetail"] @@ -13,8 +15,21 @@ class VisitDetail(ContinuousCriterion): visit details may be transfers between units of a hospital or a change of bed. 
We d """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + concept: Concept, + value: Value | None = None, + static: bool | None = None, + timing: Timing | None = None, + value_required: bool | None = None, + ): + super().__init__( + concept=concept, + value=value, + static=static, + timing=timing, + value_required=value_required, + ) # visit concepts are mapped automatically to the visit_occurrence table, so map visit_detail explicitly here self._set_omop_variables_from_domain("visit_detail") diff --git a/execution_engine/omop/criterion/visit_occurrence.py b/execution_engine/omop/criterion/visit_occurrence.py index 34b830ca..f6bab235 100644 --- a/execution_engine/omop/criterion/visit_occurrence.py +++ b/execution_engine/omop/criterion/visit_occurrence.py @@ -1,5 +1,3 @@ -from typing import Any, Dict - from sqlalchemy import select from sqlalchemy.sql import Select @@ -83,19 +81,6 @@ def description(self) -> str: """ return f"{self.__class__.__name__}[]" - def dict(self) -> dict[str, Any]: - """ - Get a JSON representation of the criterion. - """ - return {} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActivePatients": - """ - Create a criterion from a JSON representation. - """ - return cls() - class PatientsActiveDuringPeriod(ActivePatients): """ diff --git a/execution_engine/omop/serializable.py b/execution_engine/omop/serializable.py deleted file mode 100644 index 4376c5d7..00000000 --- a/execution_engine/omop/serializable.py +++ /dev/null @@ -1,96 +0,0 @@ -import json -from abc import ABC, abstractmethod -from typing import Any, Dict, Self - - -class Serializable(ABC): - """ - Base class for serializable objects. - """ - - _id: int | None = None - """ - The id is used in the database tables. - """ - - def set_id(self, value: int) -> None: - """ - Assigns the database ID to the object. This can only be done once. 
- - This ID corresponds to the primary key in the database and is set - when the object is persisted. - - :param value: The database ID assigned to the object. - :raises ValueError: If the ID has already been set. - """ - if self._id is not None: - raise ValueError("Database ID has already been set!") - self._id = value - - @property - def id(self) -> int: - """ - Retrieves the database ID of the object. - - This ID is only available after the object has been stored in the database. - - :return: The database ID, or None if the object has not been stored yet. - """ - if self._id is None: - raise ValueError("Database ID has not been set yet!") - return self._id - - def is_persisted(self) -> bool: - """ - Returns True if the object has been stored in the database. - """ - return self._id is not None - - @abstractmethod - def dict(self) -> dict: - """ - Get a dictionary representation of the object. - """ - raise NotImplementedError() - - @classmethod - @abstractmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create an object from a dictionary. - """ - raise NotImplementedError() - - def json(self) -> bytes: - """ - Get a JSON representation of the object. - - The json excludes the id, as this is auto-inserted by the database - and not known during the creation of the object. - """ - - s_json = self.dict() - - return json.dumps(s_json, sort_keys=True).encode() - - @classmethod - def from_json(cls, data: str | bytes) -> Self: - """ - Create a combination from a JSON string. - """ - return cls.from_dict(json.loads(data)) - - def __eq__(self, other: Any) -> bool: - """ - Check if two objects are equal. - """ - if not isinstance(other, self.__class__): - return False - - return self.dict() == other.dict() - - def __hash__(self) -> int: - """ - Get the hash of the object. 
- """ - return hash(self.__class__.__name__.encode() + self.json()) diff --git a/execution_engine/omop/vocabulary.py b/execution_engine/omop/vocabulary.py index d55156f9..5d0c17c8 100644 --- a/execution_engine/omop/vocabulary.py +++ b/execution_engine/omop/vocabulary.py @@ -201,7 +201,7 @@ class CODEXCELIDA(AbstractVocabulary): vocab_id = "CODEX_CELIDA" map = { "tvpibw": CustomConcept( - name="Tidal volume / ideal body weight (ARDSnet)", + concept_name="Tidal volume / ideal body weight (ARDSnet)", concept_code="tvpibw", domain_id="Measurement", vocabulary_id=vocab_id, diff --git a/execution_engine/task/creator.py b/execution_engine/task/creator.py index 466912e8..cd445827 100644 --- a/execution_engine/task/creator.py +++ b/execution_engine/task/creator.py @@ -22,7 +22,7 @@ def create_tasks_and_dependencies(graph: nx.DiGraph) -> list[Task]: def node_to_task(expr: logic.Expr, attr: dict) -> Task: store_result = attr.get("store_result", False) - bind_params = attr.get("bind_params", {}) + bind_params = attr.get("bind_params", {}).copy() bind_params["category"] = attr["category"] task = Task( diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 7f10dd1f..a8bc153c 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -9,11 +9,11 @@ import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.temporal import TimeIntervalType from execution_engine.omop.db.celida.tables import ResultInterval from execution_engine.omop.sqlclient import OMOPSQLClient from execution_engine.settings import get_config from execution_engine.task.process import Interval, get_processing_module +from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType from execution_engine.util.types import PersonIntervals, TimeRange @@ -93,6 +93,32 @@ def 
find_base_task(task: Task) -> Task: return find_base_task(self) + def select_predecessor_result( + self, expr: logic.BaseExpr, data: list[PersonIntervals] + ) -> PersonIntervals: + """ + Select the result results of the predecessor task from the given expression. + + This is required in expressions where order is important, e.g. in BinaryNonCommutativeOperator. + As the nx.DiGraph (and by inheritance, ExecutionGraph) does not store the order of the predecessors, + we need to find the predecessor task by its expression and select the result from the data. + + :param expr: The expression of the predecessor task. + :param data: The input data. + :return: The result of the predecessor task. + """ + if len(self.dependencies) == 0: + raise ValueError("Task has no dependencies.") + + idx = next((i for i, t in enumerate(self.dependencies) if t.expr == expr), None) + + if idx is None: + raise ValueError( + f"Task with expression '{str(expr)}' not found in dependencies." + ) + + return data[idx] + def run( self, data: list[PersonIntervals], @@ -159,20 +185,16 @@ def run( ), ): result = self.handle_binary_logical_operator(data) - elif isinstance( - self.expr, (logic.LeftDependentToggle, logic.ConditionalFilter) - ): + elif isinstance(self.expr, (logic.BinaryNonCommutativeOperator)): result = self.handle_left_dependent_toggle( - left=data[0], - right=data[1], + left=self.select_predecessor_result(self.expr.left, data), + right=self.select_predecessor_result(self.expr.right, data), base_data=base_data, observation_window=observation_window, ) - elif isinstance(self.expr, logic.NoDataPreservingAnd): - result = self.handle_no_data_preserving_operator( - data, base_data, observation_window - ) - elif isinstance(self.expr, logic.NoDataPreservingOr): + elif isinstance( + self.expr, (logic.NoDataPreservingAnd, logic.NoDataPreservingOr) + ): result = self.handle_no_data_preserving_operator( data, base_data, observation_window ) @@ -182,6 +204,13 @@ def run( raise 
ValueError(f"Unsupported expression type: {type(self.expr)}") if self.store_result: + if ( + not isinstance(self.expr, logic.NoDataPreservingAnd) + and not self.expr.is_Atom + ): + result = self.insert_negative_intervals( + result, base_data, observation_window + ) logging.debug(f"Storing results - '{self.name()}'") self.store_result_in_db(result, base_data, bind_params) @@ -352,20 +381,15 @@ def handle_no_data_preserving_operator( result = process.intersect_intervals(data) elif isinstance(self.expr, logic.NoDataPreservingOr): result = process.union_intervals(data) + else: + raise ValueError(f"Unsupported expression type: {type(self.expr)}") # todo: the only difference between this function and handle_binary_logical_operator is the following lines # - can we merge? - result_negative = process.complementary_intervals( - result, - reference=base_data, - observation_window=observation_window, - interval_type=IntervalType.NEGATIVE, + return self.insert_negative_intervals( + data=result, base_data=base_data, observation_window=observation_window ) - result = process.concat_intervals([result, result_negative]) - - return result - def handle_left_dependent_toggle( self, left: PersonIntervals, @@ -494,7 +518,11 @@ def get_start_end_from_interval_type( assert isinstance(self.expr, logic.TemporalCount), "Invalid expression type" if self.expr.interval_criterion is not None: + # last element is the indicator windows + assert ( + len(data) >= 2 + ), "TemporalCount with indicator criterion requires at least two inputs" data, indicator_personal_windows = data[:-1], data[-1] result = process.find_overlapping_personal_windows( @@ -535,6 +563,35 @@ def get_start_end_from_interval_type( return result + def insert_negative_intervals( + self, + data: PersonIntervals, + base_data: PersonIntervals, + observation_window: TimeRange, + ) -> PersonIntervals: + """ + Inserts negative intervals into the result. + + Usually, negative intervals are implicit. 
This functions fills all gaps between other intervals with negative + intervals. + + :param data: The input data. + :param base_data: The result of the base criterion. + :param observation_window: The observation window. + :return: A DataFrame with the merged intervals. + """ + + data_negative = process.complementary_intervals( + data, + reference=base_data, + observation_window=observation_window, + interval_type=IntervalType.NEGATIVE, + ) + + result = process.concat_intervals([data, data_negative]) + + return result + def store_result_in_db( self, result: PersonIntervals, @@ -623,7 +680,10 @@ def name(self) -> str: Uniqueness is guaranteed by prepending the base64-encoded hash of the Task object. """ - return f"[{self.id()}] {str(self)}" + if self.expr.is_Atom: + return f"[{self.id()}] {str(self)}" + else: + return f"[{self.id()}] {self.expr.__class__.__name__}()" def id(self) -> str: """ diff --git a/execution_engine/util/__init__.py b/execution_engine/util/__init__.py index 67cf1f92..92a8cbae 100644 --- a/execution_engine/util/__init__.py +++ b/execution_engine/util/__init__.py @@ -1,7 +1,20 @@ +import datetime from abc import ABCMeta from typing import Any +def datetime_converter(obj: Any) -> Any: + """ + Convert datetime objects to ISO format strings. + + Used in json.dumps() to serialize datetime objects. + """ + if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): + return obj.isoformat() + + return obj + + class AbstractPrivateMethods(ABCMeta): """ A metaclass that prevents overriding of methods decorated with @typing.final. diff --git a/execution_engine/util/enum.py b/execution_engine/util/enum.py index a9653fa5..574ed87c 100644 --- a/execution_engine/util/enum.py +++ b/execution_engine/util/enum.py @@ -23,9 +23,15 @@ class TimeUnit(StrEnum): MONTH = "mo" YEAR = "a" + def __repr__(self) -> str: + """ + Get the string representation of the category. 
+ """ + return f"{self.__class__.__name__}.{self.name}" + def __str__(self) -> str: """ - Returns the string representation of the TimeUnit. + Get the string representation of the category. """ return self.name @@ -58,3 +64,29 @@ def to_sql_interval_length_seconds(self) -> ColumnElement: return func.cast(func.extract("EPOCH", self.to_sql_interval()), NUMERIC).label( "duration_seconds" ) + + +class TimeIntervalType(StrEnum): + """ + Types of time intervals to aggregate criteria over. + """ + + MORNING_SHIFT = "morning_shift" + AFTERNOON_SHIFT = "afternoon_shift" + NIGHT_SHIFT = "night_shift" + NIGHT_SHIFT_BEFORE_MIDNIGHT = "night_shift_before_midnight" + NIGHT_SHIFT_AFTER_MIDNIGHT = "night_shift_after_midnight" + DAY = "day" + ANY_TIME = "any_time" + + def __repr__(self) -> str: + """ + Get the string representation of the category. + """ + return f"{self.__class__.__name__}.{self.name}" + + def __str__(self) -> str: + """ + Get the string representation of the category. + """ + return self.name diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index d84ed760..c140ea1e 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -1,47 +1,29 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from datetime import time -from typing import Any, Callable, Dict, Type, cast +from typing import Any, Dict, Iterator, Self, cast -from execution_engine.omop.criterion.combination.temporal import TimeIntervalType -from execution_engine.omop.serializable import Serializable +from execution_engine.util.enum import TimeIntervalType +from execution_engine.util.serializable import Serializable, SerializableABC -_EXPR_CLASSES: dict[str, Type["BaseExpr"]] = {} +def arg_to_dict(arg: Any, include_id: bool) -> dict: + """ + Convert an argument to a dictionary representation. 
-def register_expr_class(cls: Type["BaseExpr"]) -> Type["BaseExpr"]: - """Decorator to register expression classes by their class name.""" - _EXPR_CLASSES[cls.__name__] = cls - return cls + :param arg: The argument to convert. + :param include_id: Whether to include the ID in the dictionary. + :return: Dictionary representation of the argument. + """ + return arg.dict(include_id=include_id) if isinstance(arg, BaseExpr) else arg -class BaseExpr(ABC): +class BaseExpr(Serializable): """ Base class for expressions and symbols, defining common properties. """ args: tuple - @classmethod - def _recreate(cls, args: Any, kwargs: dict) -> "Expr": - """ - Recreate an expression from its arguments. - """ - return cast(Expr, cls(*args, **kwargs)) - - @abstractmethod - def __eq__(self, other: Any) -> bool: - """ - Check if this expression is equal to another expression. - """ - raise NotImplementedError("__eq__ must be implemented by subclasses") - - @abstractmethod - def __hash__(self) -> int: - """ - Get the hash of this expression. - """ - raise NotImplementedError("__hash__ must be implemented by subclasses") - @property def is_Atom(self) -> bool: """ @@ -60,15 +42,13 @@ def is_Not(self) -> bool: """ raise NotImplementedError("is_Not must be implemented by subclasses") - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments. - Required for pickling (e.g. when using multiprocessing). +class Expr(BaseExpr): + """ + Class for expressions that are not Symbols + """ - :return: Tuple of the class, and arguments. - """ - return self._recreate, (self.args, self.get_instance_variables()) + # todo: isn't this now a bit redundant with the BaseExpr class? 
(because we've removed category) def get_instance_variables(self, immutable: bool = False) -> dict | tuple: """ @@ -89,71 +69,26 @@ def get_instance_variables(self, immutable: bool = False) -> dict | tuple: else: return instance_vars - def dict(self) -> dict: - """ - Serialize this expression (and its children) into a dict. - """ - data: dict[str, Any] = {"type": self.__class__.__name__} - # Store the class name, so we know which subclass to instantiate on from_dict - - # Store instance variables (like thresholds, category, etc.) except for args - # (They come from get_instance_variables in your snippet) - instance_vars = self.get_instance_variables(immutable=False) - for key in instance_vars: - if isinstance(instance_vars[key], Serializable): - data[key] = instance_vars[key].dict() - else: - data[key] = instance_vars[key] - - # Also store the expression's children by recursing - if self.args: - serialized_args = [] - for arg in self.args: - if isinstance(arg, (BaseExpr, Serializable)): - serialized_args.append(arg.dict()) # Recursively serialize - else: - # If non-expression objects appear in .args, store them directly - serialized_args.append(arg) - data["args"] = serialized_args - - return data - - @classmethod - def from_dict(cls, data: Dict) -> "BaseExpr": - """ - Rebuild an expression (of the correct subclass) from the given dict. + def __setattr__(self, name: str, value: Any) -> None: """ - expr_type = data["type"] - - # Find the actual subclass - if expr_type not in _EXPR_CLASSES: - raise ValueError(f"No registered expression class named '{expr_type}'") - sub_cls = _EXPR_CLASSES[expr_type] - - # Deserialize child expressions - children = [] - for arg_data in data.get("args", []): - if isinstance(arg_data, dict) and "type" in arg_data: - # It's a nested BaseExpr - child_expr = cls.from_dict(arg_data) - children.append(child_expr) - else: - # It's a literal (int, str, etc.) 
- children.append(arg_data) - - instance_vars = { - key: val for key, val in data.items() if key not in ("type", "args") - } - - return sub_cls._recreate(tuple(children), instance_vars) + Set an attribute on the object. + This is overridden to prevent setting attributes on the object. + """ + if name in self.__dict__ and name not in ["args", "_init_args"]: + raise AttributeError( + f"Cannot update attributes on {self.__class__.__name__}" + ) + super().__setattr__(name, value) -class Expr(BaseExpr): - """ - Class for expressions that require a category. - """ + @abstractmethod + def update_args(self, *args: Any) -> None: + """ + Update the arguments of the expression. - # todo: isn't this now a bit redundant with the BaseExpr class? (because we've removed category) + :param args: The new arguments. + """ + raise NotImplementedError("update_args must be implemented by subclasses") def __new__(cls, *args: Any) -> "Expr": """ @@ -174,20 +109,11 @@ def __new__(cls, *args: Any) -> "Expr": return self - def __repr__(self) -> str: + def __str__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}({', '.join(map(repr, self.args))})" - - def __eq__(self, other: Any) -> bool: - """ - Check if this expression is equal to another expression. - - :param other: The other expression. - :return: True if equal, False otherwise. - """ - return isinstance(other, self.__class__) and hash(self) == hash(other) + return f"{self.__class__.__name__}({', '.join(map(str, self.args))})" def __hash__(self) -> int: """ @@ -217,8 +143,38 @@ def is_Not(self) -> bool: """ return isinstance(self, Not) + def atoms(self) -> Iterator["Symbol"]: + """ + Get all symbols in the expression. 
+ """ + + def traverse(expr: BaseExpr) -> Iterator[Symbol]: + if expr.is_Atom: + assert isinstance(expr, Symbol), f"Expected Symbol, got {expr}" + yield expr + + for sub_expr in expr.args: + yield from traverse(sub_expr) + + yield from traverse(self) + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data: Dict[str, Any] = { + "type": self.__class__.__name__, + "data": { + "args": [arg_to_dict(arg, include_id=include_id) for arg in self.args], + }, + } + + if include_id and self._id is not None: + data["data"]["_id"] = self._id + + return data + -@register_expr_class class Symbol(BaseExpr): """ Class representing a symbolic variable. @@ -251,8 +207,26 @@ def is_Not(self) -> bool: """ return False + def get_instance_variables(self, immutable: bool = False) -> Dict[str, Any] | tuple: + """ + Return all instance variables of the object. + + If immutable is True, return as an immutable tuple of key-value pairs. + If immutable is False, return as a mutable dictionary. + """ + instance_vars = { + key: value + for key, value in vars(self).items() + if not key.startswith("_") # Exclude private or special attributes + and key != "args" + } + + if immutable: + return tuple(sorted(instance_vars.items())) + else: + return instance_vars + -@register_expr_class class BooleanFunction(Expr): """ Base class for boolean functions like OR, AND, and NOT. @@ -260,29 +234,6 @@ class BooleanFunction(Expr): _repr_join_str: str | None = None - def __eq__(self, other: Any) -> bool: - """ - Check if this operator is equal to another operator. - - :param other: The other operator. - :return: True if equal, False otherwise. 
- """ - return ( - isinstance(other, self.__class__) - and self.args == other.args - and self.get_instance_variables(immutable=True) - == other.get_instance_variables(immutable=True) - ) - - # Needs to be defined again (although it is the same as in Expr) because we define __eq__ here - def __hash__(self) -> int: - """ - Get the hash of this operator. - - :return: Hash of the operator. - """ - return super().__hash__() - @property def is_Atom(self) -> bool: """ @@ -301,18 +252,51 @@ def is_Not(self) -> bool: """ return isinstance(self, Not) - def __repr__(self) -> str: + +class UnaryOperator(BooleanFunction): + """ + Base class for unary operators. + """ + + def __new__(cls, *args: Any, **kwargs: Any) -> "UnaryOperator": """ - Represent the BooleanFunction in a readable format. + Create a new UnaryOperator object. """ - if self._repr_join_str is not None: - return "(" + f" {self._repr_join_str} ".join(map(repr, self.args)) + ")" - else: - return super().__repr__() + if len(args) > 1: + raise ValueError(f"{cls.__name__} can only have one argument") + + return cast(UnaryOperator, super().__new__(cls, *args, **kwargs)) + + def update_args(self, *args: Any) -> None: + """ + Update the arguments of the expression. + + :param args: The new arguments. + """ + self.args = args -@register_expr_class -class Or(BooleanFunction): +class CommutativeOperator(BooleanFunction, SerializableABC): + """ + Base class for commutative operators. + """ + + def update_args(self, *args: Any) -> None: + """ + Update the arguments of the expression. + + :param args: The new arguments. + """ + self.args = args + + def __new__(cls, *args: Any, **kwargs: Any) -> "CommutativeOperator": + """ + Create a new CommutativeOperator object. + """ + return cast(CommutativeOperator, super().__new__(cls, *args, **kwargs)) + + +class Or(CommutativeOperator): """ Class representing a logical OR operation. 
""" @@ -329,8 +313,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: return super().__new__(cls, *args, **kwargs) -@register_expr_class -class And(BooleanFunction): +class And(CommutativeOperator): """ Class representing a logical AND operation. """ @@ -347,13 +330,12 @@ def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: return super().__new__(cls, *args, **kwargs) -@register_expr_class -class Not(BooleanFunction): +class Not(UnaryOperator): """ Class representing a logical NOT operation. """ - def __repr__(self) -> str: + def __str__(self) -> str: """ Represent the NOT operation as a string. """ @@ -369,8 +351,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Not": return cast(Not, super().__new__(cls, *args, **kwargs)) -@register_expr_class -class Count(BooleanFunction, ABC): +class Count(CommutativeOperator, SerializableABC): """ Class representing a logical COUNT operation. @@ -383,7 +364,6 @@ class Count(BooleanFunction, ABC): count_max: int | None = None -@register_expr_class class MinCount(Count): """ Class representing a logical MIN_COUNT operation. @@ -397,27 +377,23 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MinCount" self.count_min = threshold return self - def __reduce__(self) -> tuple[Callable, tuple]: + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. """ - Reduce the expression to its arguments - Required for pickling (e.g. when using multiprocessing). + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) - :return: Tuple of the class, and arguments. - """ - return ( - self._recreate, - (self.args, {"threshold": self.count_min}), - ) + return data - def __repr__(self) -> str: + def __str__(self) -> str: """ Represent the expression in a readable format. 
""" return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" -@register_expr_class class MaxCount(Count): """ Class representing a logical MAX_COUNT operation. @@ -431,27 +407,21 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MaxCount" self.count_max = threshold return self - def __reduce__(self) -> tuple[Callable, tuple]: + def dict(self, include_id: bool = False) -> dict: """ - Reduce the expression to its arguments. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class and arguments. + Get a dictionary representation of the object. """ - return ( - self._recreate, - (self.args, {"threshold": self.count_max}), - ) + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_max}) + return data - def __repr__(self) -> str: + def __str__(self) -> str: """ Represent the expression in a readable format. """ return f"{self.__class__.__name__}(threshold={self.count_max}; {', '.join(map(repr, self.args))})" -@register_expr_class class ExactCount(Count): """ Class representing a logical EXACT_COUNT operation. @@ -466,28 +436,22 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "ExactCoun self.count_max = threshold return self - def __reduce__(self) -> tuple[Callable, tuple]: + def dict(self, include_id: bool = False) -> dict: """ - Reduce the expression to its arguments. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, and arguments. + Get a dictionary representation of the object. """ - return ( - self._recreate, - (self.args, {"threshold": self.count_min}), - ) + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) + return data - def __repr__(self) -> str: + def __str__(self) -> str: """ Represent the expression in a readable format. 
""" return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" -@register_expr_class -class CappedCount(BooleanFunction, ABC): +class CappedCount(CommutativeOperator, SerializableABC): """ Base class representing a COUNT operation with an upper cap. @@ -507,7 +471,6 @@ class CappedCount(BooleanFunction, ABC): count_max: int | None = None -@register_expr_class class CappedMinCount(CappedCount): """ Class representing a MIN_COUNT operation with an implicit upper cap. @@ -534,35 +497,28 @@ def __new__( self.count_min = threshold return self - def __reduce__(self) -> tuple[Callable, tuple]: + def dict(self, include_id: bool = False) -> dict: """ - Reduce the expression to its arguments. - - Required for pickling (e.g., when using multiprocessing). - - :return: Tuple of the class, and arguments. + Get a dictionary representation of the object. """ - return ( - self._recreate, - (self.args, {"threshold": self.count_min}), - ) + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) + return data - def __repr__(self) -> str: + def __str__(self) -> str: """ Represent the expression in a readable format. """ return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" -@register_expr_class -class AllOrNone(BooleanFunction): +class AllOrNone(CommutativeOperator): """ Class representing a logical ALL_OR_NONE operation. """ -@register_expr_class -class TemporalCount(BooleanFunction, ABC): +class TemporalCount(CommutativeOperator, SerializableABC): """ Class representing a logical COUNT operation. @@ -578,58 +534,108 @@ class TemporalCount(BooleanFunction, ABC): interval_type: TimeIntervalType | None = None interval_criterion: BaseExpr | None = None - -@register_expr_class -class TemporalMinCount(TemporalCount): - """ - Class representing a logical temporal MIN_COUNT operation. 
- """ - def __new__( cls, *args: Any, threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, **kwargs: Any, - ) -> "TemporalMinCount": + ) -> Self: """ - Create a new MinCount object. + Create a new TemporalCount object. """ - self = cast(TemporalMinCount, super().__new__(cls, *args, **kwargs)) + TemporalCount._validate_time_inputs( + start_time, end_time, interval_type, interval_criterion + ) + + if interval_criterion: + # we need to add the interval_criterion to the list of arguments of this criterion in order to have + # it properly processed + args += (interval_criterion,) + + self = cast(Self, super().__new__(cls, *args, **kwargs)) + self.count_min = threshold - self.start_time = start_time - self.end_time = end_time + self.start_time = ( + time.fromisoformat(start_time) # type: ignore[arg-type] + if isinstance(start_time, str) + else start_time + ) + self.end_time = ( + time.fromisoformat(end_time) # type: ignore[arg-type] + if isinstance(end_time, str) + else end_time + ) self.interval_type = interval_type self.interval_criterion = interval_criterion return self - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments. - - Required for pickling (e.g. when using multiprocessing). 
+ @classmethod + def _validate_time_inputs( + self, + start_time: time | None, + end_time: time | None, + interval_type: TimeIntervalType | None, + interval_criterion: BaseExpr | None, + ) -> None: + + if interval_type: + if start_time is not None or end_time is not None: + raise ValueError( + "start_time/end_time cannot be used together with interval_type" + ) + if interval_criterion is not None: + raise ValueError( + "interval_criterion cannot be used together with interval_type" + ) + # Validate the interval_type if needed + self.interval_type = interval_type + self.start_time = None + self.end_time = None + + elif start_time or end_time: + # Must have start_time and end_time + if start_time is None or end_time is None: + raise ValueError( + "Either interval_type or interval_criterion or both start_time & end_time must be provided" + ) + if interval_criterion is not None: + raise ValueError( + "interval_criterion cannot be used together with start_time/end_time" + ) + if start_time >= end_time: + raise ValueError("start_time must be less than end_time") + + elif interval_criterion and not isinstance(interval_criterion, (BaseExpr)): + raise ValueError( + f"Invalid criterion - expected Criterion or CriterionCombination, got {type(interval_criterion)}" + ) - :return: Tuple of the class, and arguments. - """ - return ( - self._recreate, - ( - self.args, - { - "threshold": self.count_min, - "start_time": self.start_time, - "end_time": self.end_time, - "interval_type": self.interval_type, - "interval_criterion": self.interval_criterion, - }, - ), + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. 
+ """ + data = super().dict(include_id=include_id) + data["data"].update( + { + "threshold": self.count_min, + "start_time": self.start_time.isoformat() if self.start_time else None, + "end_time": self.end_time.isoformat() if self.end_time else None, + "interval_type": self.interval_type, + "interval_criterion": ( + self.interval_criterion.dict(include_id=include_id) + if self.interval_criterion + else None + ), + } ) + return data - def __repr__(self) -> str: + def __str__(self) -> str: """ Represent the expression in a readable format. """ @@ -646,143 +652,25 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))})" -@register_expr_class -class TemporalMaxCount(TemporalCount): +class TemporalMinCount(TemporalCount): """ - Class representing a logical MAX_COUNT operation. + Class representing a logical temporal MIN_COUNT operation. """ - def __new__( - cls, - *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, - **kwargs: Any, - ) -> "TemporalMaxCount": - """ - Create a new MaxCount object. - """ - self = cast(TemporalMaxCount, super().__new__(cls, *args, **kwargs)) - self.count_max = threshold - self.start_time = start_time - self.end_time = end_time - self.interval_type = interval_type - self.interval_criterion = interval_criterion - - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments. - - Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, and arguments. 
- """ - return ( - self._recreate, - ( - self.args, - { - "threshold": self.count_max, - "start_time": self.start_time, - "end_time": self.end_time, - "interval_type": self.interval_type, - "interval_criterion": self.interval_criterion, - }, - ), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - - if self.start_time is not None and self.end_time is not None: - interval = f"{self.start_time} - {self.end_time}" - elif self.interval_type is not None: - interval = self.interval_type.name - elif self.interval_criterion is not None: - interval = repr(self.interval_criterion) - else: - interval = "None" - - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_max}; {', '.join(map(repr, self.args))})" +class TemporalMaxCount(TemporalCount): + """ + Class representing a logical MAX_COUNT operation. + """ -@register_expr_class class TemporalExactCount(TemporalCount): """ Class representing a logical EXACT_COUNT operation. """ - def __new__( - cls, - *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, - **kwargs: Any, - ) -> "TemporalExactCount": - """ - Create a new ExactCount object. - """ - self = cast(TemporalExactCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - self.count_max = threshold - self.start_time = start_time - self.end_time = end_time - self.interval_type = interval_type - self.interval_criterion = interval_criterion - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, and arguments. 
- """ - return ( - self._recreate, - ( - self.args, - { - "threshold": self.count_min, - "start_time": self.start_time, - "end_time": self.end_time, - "interval_type": self.interval_type, - "interval_criterion": self.interval_criterion, - }, - ), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - - if self.start_time is not None and self.end_time is not None: - interval = f"{self.start_time} - {self.end_time}" - elif self.interval_type is not None: - interval = self.interval_type.name - elif self.interval_criterion is not None: - interval = repr(self.interval_criterion) - else: - interval = "None" - - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))})" - - -@register_expr_class -class NonSimplifiableAnd(BooleanFunction): +class NonSimplifiableAnd(CommutativeOperator): """ A NonSimplifiableAnd object represents a logical AND operation that cannot be simplified. @@ -796,23 +684,6 @@ class NonSimplifiableAnd(BooleanFunction): that. """ - def __eq__(self, other: Any) -> bool: - """ - Check if this operator is equal to another operator. - - This always yields false to prevent combination of two NonSimplifiableAnd operators when merging a graph. - """ - return False - - def __hash__(self) -> int: - """ - Get the hash of this operator. - - Required because __eq__ yields always False -- and we need distinct hashes for distinct objects, as this - operator should not be merged. - """ - return id(self) - def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableAnd": """ Create a new NonSimplifiableAnd object. @@ -821,8 +692,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableAnd": # todo: can we rename to more meaningful name? -@register_expr_class -class NoDataPreservingAnd(BooleanFunction): +class NoDataPreservingAnd(CommutativeOperator): """ A And object represents a logical AND operation. 
@@ -837,8 +707,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingAnd": return cast(NoDataPreservingAnd, super().__new__(cls, *args, **kwargs)) -@register_expr_class -class NoDataPreservingOr(BooleanFunction): +class NoDataPreservingOr(CommutativeOperator): """ A Or object represents a logical OR operation. @@ -853,20 +722,35 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingOr": return cast(NoDataPreservingOr, super().__new__(cls, *args, **kwargs)) -@register_expr_class -class LeftDependentToggle(BooleanFunction): +class BinaryNonCommutativeOperator(BooleanFunction, SerializableABC): """ - A LeftDependentToggle object represents a logical AND operation if the left operand is positive, - otherwise it returns NOT_APPLICABLE. + Base class for binary non-commutative operators. + + This class should not be instantiated directly but used as a base for specific + binary non-commutative operators like LeftDependentToggle. """ + def update_args(self, *args: Any) -> None: + """ + Update the arguments of the expression. + + :param args: The new arguments. + """ + if len(args) != 2: + raise ValueError( + f"{self.__class__.__name__} requires exactly two arguments" + ) + self.args = args + def __new__( cls, left: BaseExpr, right: BaseExpr, **kwargs: Any - ) -> "LeftDependentToggle": + ) -> "BinaryNonCommutativeOperator": """ - Initialize a LeftDependentToggle object. + Create a new BinaryNonCommutativeOperator object. """ - return cast(LeftDependentToggle, super().__new__(cls, left, right, **kwargs)) + return cast( + BinaryNonCommutativeOperator, super().__new__(cls, left, right, **kwargs) + ) @property def left(self) -> Expr: @@ -878,9 +762,29 @@ def right(self) -> Expr: """Returns the right operand""" return self.args[1] + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. 
+ """ + data = super().dict(include_id=include_id) + del data["data"]["args"] + data["data"].update( + { + "left": arg_to_dict(self.left, include_id=include_id), + "right": arg_to_dict(self.right, include_id=include_id), + } + ) + return data + -@register_expr_class -class ConditionalFilter(BooleanFunction): +class LeftDependentToggle(BinaryNonCommutativeOperator): + """ + A LeftDependentToggle object represents a logical AND operation if the left operand is positive, + otherwise it returns NOT_APPLICABLE. + """ + + +class ConditionalFilter(BinaryNonCommutativeOperator): """ A ConditionalFilter object returns the right operand if the left operand is POSITIVE, and NEGATIVE otherwise @@ -896,21 +800,3 @@ class ConditionalFilter(BooleanFunction): | POSITIVE | NEGATIVE | NEGATIVE | | POSITIVE | NO_DATA | NO_DATA | """ - - def __new__( - cls, left: BaseExpr, right: BaseExpr, **kwargs: Any - ) -> "ConditionalFilter": - """ - Initialize a ConditionalFilter object. - """ - return cast(ConditionalFilter, super().__new__(cls, left, right, **kwargs)) - - @property - def left(self) -> Expr: - """Returns the left operand""" - return self.args[0] - - @property - def right(self) -> Expr: - """Returns the right operand""" - return self.args[1] diff --git a/execution_engine/util/serializable.py b/execution_engine/util/serializable.py new file mode 100644 index 00000000..239c8030 --- /dev/null +++ b/execution_engine/util/serializable.py @@ -0,0 +1,534 @@ +import abc +import inspect +import json +from typing import Any, Callable, Dict, Self, Tuple, final + +from pydantic import BaseModel + +from execution_engine.util import datetime_converter + +__class_registry: dict[str, type] = {} +""" +Registry of classes that can be serialized. +""" + + +def register_class(cls: type) -> type: + """ + Register a class for serialization. + + This function may be used as a decorator to register a class in the serialization registry. + + :param cls: The class to register. 
+ :return: The same class (for decorator chaining). + :raises ValueError: If the class name is already registered. + """ + if cls.__name__ in __class_registry: + raise ValueError(f"Class {cls.__name__} is already registered.") + + __class_registry[cls.__name__] = cls + return cls + + +def resolve_class(name: str) -> type: + """ + Resolve a registered class by name. + + :param name: The name of the class to retrieve. + :return: The corresponding class object. + :raises ValueError: If the class is not found in the registry. + """ + cls = __class_registry.get(name) + + if cls is None: + raise ValueError(f"Class {name} is not registered.") + + return cls + + +def is_class_registered(name: str) -> bool: + """ + Check if a class is registered under the given name. + + :param name: The name of the class to check. + :return: True if the class is registered, False otherwise. + """ + return name in __class_registry + + +class RegisteredPostInitMeta(type): + """ + Metaclass that automatically registers a class for serialization and calls + a custom __post_init__ method (if defined) once after regular object initialization. + """ + + def __call__(cls, *args: Any, do_post_init: bool = True, **kwargs: Any) -> Self: + """ + Create and return a new instance of the class, then call __post_init__ if defined. + + This overrides the default object construction process to perform any + custom post-initialization logic defined in __post_init__. If __post_init__ + exists, it is called exactly once, after the instance is created. + + :param args: Positional arguments used during object creation. + :param kwargs: Keyword arguments used during object creation. + :return: The newly created instance of the class. 
+ """ + instance = super().__call__(*args, **kwargs) + + if ( + do_post_init + and hasattr(instance, "__post_init__") + and not getattr(instance, "_post_initialized", False) + ): + instance.__post_init__() + instance._post_initialized = True + + return instance + + def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Self: + """ + Create and return a new class, registering it in the serialization registry. + + This method intercepts the class creation process itself. It registers the + newly defined class to __class_registry, allowing instances to be properly + deserialized later. + + :param name: The name of the newly created class. + :param bases: A tuple of base classes. + :param attrs: A dictionary of attributes/methods for the new class. + :return: The newly created class object. + :raises ValueError: If the class is already registered. + """ + new_class = super().__new__(mcs, name, bases, attrs) + register_class(new_class) + return new_class + + +class Serializable(metaclass=RegisteredPostInitMeta): + """ + Base class for making objects serializable. + + Stores construction arguments, manages an optional database ID, and provides serialization + and deserialization to and from dictionaries or JSON. + + Note that Serializable classes are immutable. This means that once an object is created, + its attributes cannot be changed. This is enforced by overriding the __setattr__ method in + the __post_init__ method. + The rationale behind this is to provide a fixed hash value for the object, which is + calculated only once during the object's lifetime. This is important for caching and + serialization purposes. + """ + + _id: int | None = None + """ + The id is used in the database tables. + """ + + _hash: int + """ + The hash of the object. This is calculated based on the class name and the JSON representation + of the object. It is used to ensure that the object is immutable. 
+ """ + + def __post_init__(self) -> None: + """ + Create a new instance of the object. + """ + self.rehash() + + def immutable_setattr(self: Self, key: str, value: Any) -> None: + raise AttributeError( + f"Cannot set attribute {key} on immutable object {self.__class__.__name__}" + ) + + self.__setattr__ = immutable_setattr # type: ignore[assignment] + + def set_id(self, value: int, overwrite: bool = False) -> None: + """ + Assigns the database ID to the object. This can only be done once. + + This ID corresponds to the primary key in the database and is set + when the object is persisted. + + :param value: The database ID assigned to the object. + :param overwrite: If True, allows overwriting an existing ID. + :raises ValueError: If the ID has already been set. + """ + if self._id is not None and not overwrite: + raise ValueError("Database ID has already been set!") + self._id = value + + @property + def id(self) -> int: + """ + Retrieves the database ID of the object. + + This ID is only available after the object has been stored in the database. + + :return: The database ID, or None if the object has not been stored yet. + """ + if self._id is None: + raise ValueError("Database ID has not been set yet!") + return self._id + + def is_persisted(self) -> bool: + """ + Returns True if the object has been stored in the database. + """ + return self._id is not None + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. 
+ """ + data = {} + + def serialize(val: Any) -> Any: + if isinstance(val, Serializable): + return val.dict(include_id=include_id) + elif isinstance(val, BaseModel): + return { + "type": val.__class__.__name__, + "data": val.model_dump(), + } + elif isinstance(val, (list, tuple)): + return type(val)(serialize(item) for item in val) + return val + + for key, val in self.get_instance_variables().items(): # type: ignore[union-attr] + data[key] = serialize(val) + + if include_id and self._id is not None: + data["_id"] = self._id + + return {"type": self.__class__.__name__, "data": data} + + def get_instance_variables(self, immutable: bool = False) -> Dict[str, Any] | tuple: + """ + Get the instance variables of the object. + + This is only required if the subclass doesn't provide an own implementation of dict(). + """ + raise NotImplementedError( + "Method get_instance_variables must be implemented in subclasses." + ) + + @final + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Self: + """ + Create an object from a dictionary produced by dict(). + + Expected format: + { + "type": , + "data": { + "_id": ... , (optional) + "args": [...], (optional, positional arguments) + ... any additional keys ... + } + } + """ + class_name = data["type"] + content = data["data"] + + # Look up the target class from our registry + sub_cls = resolve_class(class_name) + + # Extract _id if present + db_id = content.pop("_id", None) + + pos_args = content.pop("args", []) + var_kwargs = content + + def deserialize_item(item: Any) -> Any: + if isinstance(item, dict) and "type" in item: + return cls.from_dict(item) + return item + + pos_args = [deserialize_item(arg) for arg in pos_args] + var_kwargs = {k: deserialize_item(v) for k, v in var_kwargs.items()} + + obj = sub_cls(*pos_args, **var_kwargs) + + if db_id is not None: + obj.set_id(db_id) + + return obj + + def json(self) -> bytes: + """ + Get a JSON representation of the object. 
+ + The json excludes the id, as this is auto-inserted by the database + and not known during the creation of the object. + """ + s_json = self.dict() + + return json.dumps(s_json, default=datetime_converter, sort_keys=True).encode() + + @classmethod + def from_json(cls, data: str | bytes) -> Self: + """ + Create a combination from a JSON string. + """ + return cls.from_dict(json.loads(data)) + + def __eq__(self, other: Any) -> bool: + """ + Check if two objects are equal. + """ + if not isinstance(other, self.__class__): + return False + + return self._hash == other._hash + + def rehash(self) -> None: + """ + Recalculate the hash of the object. + """ + self._hash = hash(self.__class__.__name__.encode() + self.json()) + + def __hash__(self) -> int: + """ + Get the hash of the object. + """ + return self._hash + + def __reduce__(self) -> Tuple[Callable, tuple]: + """ + Support pickling of the object. + """ + return self.__class__.from_dict, (self.dict(include_id=True),) + + @final + def __repr__(self) -> str: + """ + Get a string representation of the object. + """ + rep_dict = self.dict(include_id=False) + data = rep_dict["data"].copy() + + return self.build_repr(rep_dict["type"], data, indent=0) + + def build_repr(self, cls: str, data: Dict, indent: int) -> str: + """ + Build a string representation of the object. 
+ """ + indent_unit = " " + + def pformat_value(val: Any, indent: int) -> str: + current_indent = indent_unit * indent + next_indent = indent_unit * (indent + 1) + # Check for a serialized object (dict with "type" and "data") + if isinstance(val, dict) and "type" in val and "data" in val: + + if is_class_registered(val["type"]) and issubclass( + resolve_class(val["type"]), BaseModel + ): + flat_data = ", ".join( + f"{k}={repr(v)}" for k, v in val["data"].items() + ) + return f"{val['type']}({flat_data})" + + return self.build_repr(val["type"], val["data"], indent + 1) + elif isinstance(val, list): + if not val: + return "[]" + items = [pformat_value(item, indent + 1) for item in val] + return ( + "[\n" + + ",\n".join(next_indent + item for item in items) + + "\n" + + current_indent + + "]" + ) + elif isinstance(val, tuple): + if not val: + return "()" + # Special case for single-element tuples + if len(val) == 1: + item = pformat_value(val[0], indent + 1) + return "(\n" + next_indent + item + ",\n" + current_indent + ")" + else: + items = [pformat_value(item, indent + 1) for item in val] + return ( + "(\n" + + ",\n".join(next_indent + item for item in items) + + "\n" + + current_indent + + ")" + ) + elif isinstance(val, dict): + if not val: + return "{}" + items = [] + for k, v in val.items(): + formatted_v = pformat_value(v, indent + 1) + items.append(f"{next_indent}{repr(k)}: {formatted_v}") + return "{\n" + ",\n".join(items) + "\n" + current_indent + "}" + else: + return repr(val) + + # Extract positional arguments (if any) + pos_args = data.pop("args", []) + pos_args_str = [pformat_value(arg, indent=indent + 1) for arg in pos_args] + + # Format remaining keyword arguments + kw_args_str = [ + f"{key}={pformat_value(value, indent=indent+1)}" + for key, value in data.items() + ] + + # Combine both sets of arguments + all_args = pos_args_str + kw_args_str + outer_indent = indent_unit * indent + inner_indent = indent_unit * (indent + 1) + + if all_args: + 
args_joined = ",\n".join(inner_indent + arg for arg in all_args) + return f"{cls}(\n{args_joined}\n{outer_indent})" + else: + return f"{cls}()" + + def __str__(self) -> str: + """ + Get a string representation of the object. + """ + data = {} + + for key, val in self.get_instance_variables().items(): # type: ignore[union-attr] + data[key] = str(val) + + return f"{self.__class__.__name__}({data})" + + +def get_constructor_vars(cls: type) -> set[str]: + """ + Get the variables that are passed to the constructor of a class. + + :param cls: The class. + :return: A set of variable names. + """ + self_var = "self" + original_init = cls.__init__ + if cls.__init__ == object.__init__ and cls.__new__ != object.__new__: + original_init = cls.__new__ + self_var = "cls" + + sig = inspect.signature(original_init) + + return {arg for arg in sig.parameters if arg != self_var} + + +class SerializableDataClassMeta(RegisteredPostInitMeta): + """ + Base class for making objects serializable. Stores construction + arguments, manages an optional database ID, and provides serialization + and deserialization to and from dictionaries or JSON. + """ + + def __call__(cls, *args: Any, **kwargs: Any) -> Self: + """ + Creates and returns a new instance of the class, ensuring that arguments + are assigned to protected attributes. Uses the function signature to bind + positional and keyword arguments and raises a ValueError if protected + attributes are missing. + + :param args: Positional arguments for object creation. + :param kwargs: Keyword arguments for object creation. + :return: Newly created instance of the class. 
+ """ + instance = super().__call__(*args, do_post_init=False, **kwargs) + + self_var = "self" + original_init = cls.__init__ + if cls.__init__ == object.__init__ and cls.__new__ != object.__new__: + original_init = cls.__new__ + self_var = "cls" + + sig = inspect.signature(original_init) + if self_var == "cls": + # For __new__, the first argument should be the class (instance's class) + bound = sig.bind(cls, *args, **kwargs) + else: + # For __init__, the first argument is the instance + bound = sig.bind(instance, *args, **kwargs) + + bound.arguments.pop(self_var, None) # Remove 'self' or 'cls' + + protected_instance_vars = set( + arg + for arg in vars(instance) + if arg.startswith("_") and not arg.startswith("__") + ) + + # Check if the set of arguments passed to __init__ is a subset of the protected instance variables + if not set(f"_{arg}" for arg in bound.arguments) <= protected_instance_vars: + raise AttributeError( + "All arguments passed to __init__ must be assigned to the instance as protected attributes" + " in a SerializableDataClass." + ) + + if hasattr(instance, "__post_init__") and not getattr( + instance, "_post_initialized", False + ): + instance.__post_init__() + instance._post_initialized = True # type: ignore[attr-defined] + + return instance + + +class SerializableABCMeta(RegisteredPostInitMeta, abc.ABCMeta): + """ + Metaclass that combines SerializableDataClassMeta logic with ABCMeta + to allow abstract methods in serializable classes. + """ + + +class SerializableABC(metaclass=SerializableABCMeta): + """ + Abstract base class for serializable objects. This class allows + the use of abstract methods in serializable classes. + """ + + +class SerializableDataClassABCMeta(SerializableDataClassMeta, abc.ABCMeta): + """ + Metaclass that combines SerializableDataClassMeta logic with ABCMeta + to allow abstract methods in serializable data classes. 
+ """ + + +class SerializableDataClass(Serializable, metaclass=SerializableDataClassMeta): + """ + Serializable data class that ensures arguments passed to __init__ + are protected attributes. Automatically registers subclasses and + supports dict() and JSON exports. + """ + + def get_instance_variables(self, immutable: bool = False) -> dict | tuple: + """ + Get a dictionary representation of the criterion. + """ + + if immutable: + raise NotImplementedError( + "get_instance_variables() must be implemented in subclasses." + ) + + return { + var: getattr(self, f"_{var}") + for var in get_constructor_vars(self.__class__) + } + + +class SerializableDataClassABC( + SerializableDataClass, metaclass=SerializableDataClassABCMeta +): + """ + Abstract variant of a serializable data class. Requires subclasses + to implement abstract methods, while keeping the serialization + features and protected attribute checks. + """ diff --git a/execution_engine/util/temporal_logic_util.py b/execution_engine/util/temporal_logic_util.py new file mode 100644 index 00000000..42c06e51 --- /dev/null +++ b/execution_engine/util/temporal_logic_util.py @@ -0,0 +1,156 @@ +from datetime import time + +from execution_engine.util import logic +from execution_engine.util.enum import TimeIntervalType + + +def Presence( + criterion: logic.BaseExpr, + *, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalMinCount: + """ + Create a presence combination of criteria. 
+ """ + return logic.TemporalMinCount( + criterion, + threshold=1, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def MinCount( + criterion: logic.BaseExpr, + *, + threshold: int, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalMinCount: + """ + Create a minimum count combination of criteria. + """ + return logic.TemporalMinCount( + criterion, + threshold=threshold, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def MaxCount( + criterion: logic.BaseExpr, + *, + threshold: int, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalMaxCount: + """ + Create a maximum count combination of criteria. + """ + return logic.TemporalMaxCount( + criterion, + threshold=threshold, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def ExactCount( + criterion: logic.BaseExpr, + *, + threshold: int, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalExactCount: + """ + Create an exact count combination of criteria. + """ + return logic.TemporalExactCount( + criterion, + threshold=threshold, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def MorningShift( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a morning shift combination of criteria. 
+ """ + return Presence(criterion, interval_type=TimeIntervalType.MORNING_SHIFT) + + +def AfternoonShift( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create an afternoon shift combination of criteria. + """ + return Presence(criterion, interval_type=TimeIntervalType.AFTERNOON_SHIFT) + + +def NightShift( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a night shift combination of criteria. + """ + return Presence(criterion, interval_type=TimeIntervalType.NIGHT_SHIFT) + + +def NightShiftBeforeMidnight( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a night shift before midnight combination of criteria. + """ + return Presence( + criterion, interval_type=TimeIntervalType.NIGHT_SHIFT_BEFORE_MIDNIGHT + ) + + +def NightShiftAfterMidnight( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a night shift after midnight combination of criteria. + """ + return Presence( + criterion, interval_type=TimeIntervalType.NIGHT_SHIFT_AFTER_MIDNIGHT + ) + + +def Day( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a day combination of criteria. + """ + return Presence(criterion, interval_type=TimeIntervalType.DAY) + + +def AnyTime(criterion: logic.BaseExpr) -> logic.TemporalMinCount: + """ + Any time overlap + """ + return Presence(criterion, interval_type=TimeIntervalType.ANY_TIME) diff --git a/execution_engine/util/types.py b/execution_engine/util/types.py index cf099700..7ff9982a 100644 --- a/execution_engine/util/types.py +++ b/execution_engine/util/types.py @@ -5,6 +5,7 @@ import pytz from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from execution_engine.util import serializable from execution_engine.util.enum import TimeUnit from execution_engine.util.interval import ( DateTimeInterval, @@ -17,6 +18,7 @@ PersonIntervals = dict[int, Any] +@serializable.register_class class TimeRange(BaseModel): """ A time range. 
@@ -91,6 +93,7 @@ def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, datetime]: } +@serializable.register_class class Timing(BaseModel): """ The timing of a criterion. @@ -212,6 +215,7 @@ def model_dump( return data +@serializable.register_class class Dosage(Timing): """ A dosage consisting of a dose in addition to the Timing fields. diff --git a/execution_engine/util/value/time.py b/execution_engine/util/value/time.py index da2de370..8d905187 100644 --- a/execution_engine/util/value/time.py +++ b/execution_engine/util/value/time.py @@ -5,6 +5,7 @@ from sqlalchemy import TableClause from sqlalchemy.sql.functions import concat, func +from execution_engine.util import serializable from execution_engine.util.enum import TimeUnit from execution_engine.util.value import ucum_to_postgres from execution_engine.util.value.value import ( @@ -16,6 +17,7 @@ ) +@serializable.register_class class ValuePeriod(ValueNumeric[NonNegativeInt, TimeUnit]): """ A non-negative integer value with a unit of type TimeUnit, where value_min and value_max are not allowed. @@ -52,6 +54,7 @@ def to_sql_interval(self) -> ColumnElement: ) +@serializable.register_class class ValueCount(ValueNumeric[NonNegativeInt, None]): """ A non-negative integer value without a unit. @@ -71,6 +74,7 @@ def supports_units(self) -> bool: return False +@serializable.register_class class ValueDuration(ValueNumeric[float, TimeUnit]): """ A float value with a unit of type TimeUnit. @@ -104,6 +108,7 @@ def to_sql( return super().to_sql(table, c / c_duration_seconds, with_unit) +@serializable.register_class class ValueFrequency(Value): """ A non-negative integer value with no unit, with a Period. 
diff --git a/execution_engine/util/value/value.py b/execution_engine/util/value/value.py index 72c892c7..7b795716 100644 --- a/execution_engine/util/value/value.py +++ b/execution_engine/util/value/value.py @@ -21,6 +21,8 @@ "check_int", ] +from execution_engine.util import serializable + ValueT = TypeVar("ValueT") UnitT = TypeVar("UnitT") ValueNumericClassT = TypeVar("ValueNumericClassT", bound="ValueNumeric") @@ -90,6 +92,7 @@ def get_precision(value: float | int) -> int: ) # the number of digits after the decimal point is the precision +@serializable.register_class class Value(BaseModel, ABC): """A value in a criterion.""" @@ -145,6 +148,7 @@ def supports_units(self) -> bool: return hasattr(self, "unit") +@serializable.register_class class ValueNumeric(Value, Generic[ValueT, UnitT]): """ A value of type number. @@ -294,6 +298,7 @@ def eps(number: ValueT) -> float: return and_(*clauses) +@serializable.register_class class ValueNumber(ValueNumeric[float, Concept]): """ A float value with a unit of type Concept. @@ -307,6 +312,7 @@ def _get_unit_clause(self, table: TableClause | None) -> ColumnClause: return c_unit == self.unit.concept_id +@serializable.register_class class ValueScalar(ValueNumeric[float, None]): """ A numeric value without a unit. @@ -323,6 +329,7 @@ def supports_units(self) -> bool: return False +@serializable.register_class class ValueConcept(Value): """ A value of type concept. 
diff --git a/execution_engine/omop/criterion/combination/__init__.py b/tests/execution_engine/__init__.py similarity index 100% rename from execution_engine/omop/criterion/combination/__init__.py rename to tests/execution_engine/__init__.py diff --git a/tests/execution_engine/omop/__init__.py b/tests/execution_engine/omop/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/execution_engine/omop/criterion/__init__.py b/tests/execution_engine/omop/criterion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/execution_engine/util/test_cohort_logic.py b/tests/execution_engine/util/test_logic.py similarity index 99% rename from tests/execution_engine/util/test_cohort_logic.py rename to tests/execution_engine/util/test_logic.py index 0c3bbabd..86d34a07 100644 --- a/tests/execution_engine/util/test_cohort_logic.py +++ b/tests/execution_engine/util/test_logic.py @@ -4,7 +4,7 @@ import pytest from execution_engine.constants import CohortCategory -from execution_engine.omop.criterion.combination.temporal import TimeIntervalType +from execution_engine.util.enum import TimeIntervalType from execution_engine.util.logic import ( AllOrNone, And, From a76a9846528e8479064c4716b71f1807f13a75e5 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 14 Mar 2025 14:21:46 +0100 Subject: [PATCH 08/16] tests: adopted to exclude criterion combination --- tests/_fixtures/omop_fixture.py | 4 +- .../action/test_drug_administration.py | 4 +- .../omop/cohort/test_cohort_recommendation.py | 57 ++- .../test_population_intervention_pair.py | 17 +- .../omop/criterion/combination/__init__.py | 56 --- .../combination/test_logical_combination.py | 242 +++-------- .../combination/test_temporal_combination.py | 405 +++++++----------- .../criterion/custom/test_tidal_volume.py | 2 +- .../omop/criterion/test_criterion.py | 80 ++-- .../omop/criterion/test_drug_exposure.py | 6 +- tests/execution_engine/omop/test_concepts.py | 6 +- 
tests/execution_engine/util/test_logic.py | 121 +++--- tests/execution_engine/util/test_types.py | 5 +- tests/mocks/criterion.py | 21 - 14 files changed, 370 insertions(+), 656 deletions(-) diff --git a/tests/_fixtures/omop_fixture.py b/tests/_fixtures/omop_fixture.py index 174dece3..c682bc55 100644 --- a/tests/_fixtures/omop_fixture.py +++ b/tests/_fixtures/omop_fixture.py @@ -177,7 +177,7 @@ def celida_recommendation( recommendation_id=recommendation_id, pi_pair_url="https://example.com", pi_pair_name="my_pair", - pi_pair_hash=hash("my_pair"), + pi_pair_hash=str(hash("my_pair")), ) db_session.add(pi_pair) db_session.commit() @@ -185,7 +185,7 @@ def celida_recommendation( criterion = Criterion( criterion_id=criterion_id, criterion_description="my_description", - criterion_hash=hash("my_criterion"), + criterion_hash=str(hash("my_criterion")), ) db_session.add(criterion) db_session.commit() diff --git a/tests/execution_engine/converter/action/test_drug_administration.py b/tests/execution_engine/converter/action/test_drug_administration.py index 69bf83b7..2b101a85 100644 --- a/tests/execution_engine/converter/action/test_drug_administration.py +++ b/tests/execution_engine/converter/action/test_drug_administration.py @@ -64,8 +64,8 @@ def test_multiple_doses(self): ], ) - comb = action.to_expression() - criteria = list(comb) + expr = action.to_expression() + criteria = list(expr.args) assert len(criteria) == 3 assert all(isinstance(c, DrugExposure) for c in criteria) diff --git a/tests/execution_engine/omop/cohort/test_cohort_recommendation.py b/tests/execution_engine/omop/cohort/test_cohort_recommendation.py index 750e285e..37f8ab0a 100644 --- a/tests/execution_engine/omop/cohort/test_cohort_recommendation.py +++ b/tests/execution_engine/omop/cohort/test_cohort_recommendation.py @@ -1,9 +1,10 @@ import pytest -from execution_engine.omop.cohort import PopulationInterventionPair +from execution_engine.omop.cohort import PopulationInterventionPairExpr from 
execution_engine.omop.cohort.recommendation import Recommendation from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.visit_occurrence import ActivePatients +from execution_engine.util import logic from tests.mocks.criterion import MockCriterion @@ -11,12 +12,9 @@ class TestRecommendation: def test_serialization(self): # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) original = Recommendation( - pi_pairs=[], + expr=logic.BooleanFunction(), base_criterion=MockCriterion("c"), name="foo", title="bar", @@ -47,7 +45,7 @@ def test_serialization_with_active_patients(self, concept): # specifically because there used to be a problem with the # serialization of that combination. original = Recommendation( - pi_pairs=[], + expr=logic.BooleanFunction(), base_criterion=ActivePatients(), name="foo", title="bar", @@ -60,11 +58,12 @@ def test_serialization_with_active_patients(self, concept): deserialized = Recommendation.from_json(json) assert original == deserialized - def test_database_serialization(self, concept): + def test_database_serialization( + self, + concept, + db_session, # use_db_session fixture for a proper database clean up after this test + ): from execution_engine.execution_engine import ExecutionEngine - from execution_engine.omop.criterion.factory import register_criterion_class - - register_criterion_class("MockCriterion", MockCriterion) e = ExecutionEngine(verbose=False) @@ -74,18 +73,18 @@ def test_database_serialization(self, concept): population_criterion = MockCriterion("p") intervention_criterion = MockCriterion("i)") - pi_pair = PopulationInterventionPair( + pi_pair = PopulationInterventionPairExpr( + population_expr=population_criterion, + intervention_expr=intervention_criterion, name="foo", url="foo", base_criterion=base_criterion, - population=population_criterion, - intervention=intervention_criterion, ) 
pi_pairs = [pi_pair] recommendation = Recommendation( - pi_pairs=pi_pairs, + expr=pi_pair, base_criterion=ActivePatients(), name="foo", title="bar", @@ -105,33 +104,33 @@ def test_database_serialization(self, concept): with pytest.raises(ValueError, match=r"Database ID has not been set yet!"): assert pi_pair.id is None - for criterion in pi_pair.flatten(): - with pytest.raises( - ValueError, match=r"Database ID has not been set yet!" - ): - assert criterion.id is None + for criterion in recommendation.atoms(): + with pytest.raises(ValueError, match=r"Database ID has not been set yet!"): + assert criterion.id is None e.register_recommendation(recommendation) assert recommendation.id is not None assert recommendation.base_criterion.id is not None + for pi_pair in recommendation.population_intervention_pairs(): assert pi_pair.id is not None - for criterion in pi_pair.flatten(): - assert criterion.id is not None + + for criterion in recommendation.atoms(): + assert criterion.id is not None rec_loaded = e.load_recommendation_from_database(url) assert recommendation == rec_loaded assert rec_loaded.id == recommendation.id - assert len(rec_loaded._pi_pairs) == len(pi_pairs) + assert len(list(rec_loaded.population_intervention_pairs())) == 1 - for pi_pair_loaded, pi_pair in zip(rec_loaded._pi_pairs, pi_pairs): + for pi_pair_loaded, pi_pair in zip( + list(rec_loaded.population_intervention_pairs()), pi_pairs + ): assert pi_pair_loaded.id == pi_pair.id - assert len(pi_pair_loaded.flatten()) == len(pi_pair.flatten()) - - for criterion_loaded, criterion in zip( - pi_pair_loaded.flatten(), pi_pair.flatten() - ): - assert criterion.id == criterion_loaded.id + for criterion_loaded, criterion in zip( + rec_loaded.atoms(), recommendation.atoms() + ): + assert criterion.id == criterion_loaded.id diff --git a/tests/execution_engine/omop/cohort/test_population_intervention_pair.py b/tests/execution_engine/omop/cohort/test_population_intervention_pair.py index 12910d41..2993cbfc 100644 
--- a/tests/execution_engine/omop/cohort/test_population_intervention_pair.py +++ b/tests/execution_engine/omop/cohort/test_population_intervention_pair.py @@ -1,6 +1,7 @@ from execution_engine.omop.cohort.population_intervention_pair import ( - PopulationInterventionPair, + PopulationInterventionPairExpr, ) +from execution_engine.util import logic from tests.mocks.criterion import MockCriterion @@ -8,14 +9,14 @@ class TestPopulationInterventionPair: def test_serialization(self): # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) - - original = PopulationInterventionPair( - name="foo", url="bar", base_criterion=MockCriterion("c") + original = PopulationInterventionPairExpr( + population_expr=logic.NonSimplifiableAnd(MockCriterion("population")), + intervention_expr=logic.NonSimplifiableAnd(MockCriterion("intervention")), + name="foo", + url="bar", + base_criterion=MockCriterion("base"), ) json = original.json() - deserialized = PopulationInterventionPair.from_json(json) + deserialized = PopulationInterventionPairExpr.from_json(json) assert original == deserialized diff --git a/tests/execution_engine/omop/criterion/combination/__init__.py b/tests/execution_engine/omop/criterion/combination/__init__.py index 5829cf14..e69de29b 100644 --- a/tests/execution_engine/omop/criterion/combination/__init__.py +++ b/tests/execution_engine/omop/criterion/combination/__init__.py @@ -1,56 +0,0 @@ -from typing import Any, Dict, Self - -from sqlalchemy import Select, select - -from execution_engine.omop.criterion.abstract import ( - Criterion, - column_interval_type, - observation_end_datetime, - observation_start_datetime, -) -from execution_engine.util.interval import IntervalType - - -class NoopCriterion(Criterion): - """ - Select patients who are post-surgical in the timeframe between the day of the surgery and 6 days after the surgery. 
- """ - - _static = True - - def _create_query(self) -> Select: - """ - Get the SQL Select query for data required by this criterion. - """ - subquery = self.base_query().subquery() - - query = select( - subquery.c.person_id, - column_interval_type(IntervalType.POSITIVE), - observation_start_datetime.label("interval_start"), - observation_end_datetime.label("interval_end"), - ) - - query = self._filter_base_persons(query, c_person_id=subquery.c.person_id) - query = self._filter_datetime(query) - - return query - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create an object from a dictionary. - """ - return cls() - - def description(self) -> str: - """ - Get a description of the criterion. - """ - return self.__class__.__name__ - - def dict(self) -> dict: - """ - Get a dictionary representation of the object. - """ - return {} diff --git a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py index 62f9faee..6c83d4b2 100644 --- a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py @@ -6,15 +6,13 @@ import sympy from execution_engine.constants import CohortCategory -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence from execution_engine.omop.criterion.drug_exposure import DrugExposure from execution_engine.omop.criterion.measurement import Measurement +from execution_engine.omop.criterion.noop import NoopCriterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.task.process import get_processing_module +from execution_engine.util import logic from execution_engine.util.types import Dosage, 
TimeRange from execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( @@ -28,7 +26,6 @@ concept_unit_ml, ) from tests._testdata import concepts -from tests.execution_engine.omop.criterion.combination import NoopCriterion from tests.execution_engine.omop.criterion.test_criterion import TestCriterion, date_set from tests.functions import ( create_condition, @@ -51,196 +48,83 @@ def intervals_to_df(result, by=None): return df -class TestCriterionCombination: +class TestExpr: """ - Test class for testing criterion combinations (without database). + Test class for testing Expr """ @pytest.fixture def mock_criteria(self): return [MockCriterion(f"c{i}") for i in range(1, 6)] - def test_criterion_combination_init(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - assert combination.operator == operator - assert len(combination) == 0 - - def test_criterion_combination_add(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - for criterion in mock_criteria: - combination.add(criterion) - - assert len(combination) == len(mock_criteria) - - for idx, criterion in enumerate(combination): - assert criterion == mock_criteria[idx] + def test_expr_dict(self, mock_criteria): - def test_criterion_combination_dict(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) + expr = logic.And(*mock_criteria) - for criterion in mock_criteria: - combination.add(criterion) + combination_dict = expr.dict() - combination_dict = combination.dict() assert combination_dict == { - "operator": "AND", - "threshold": None, - "criteria": [ - {"class_name": 
"MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], - "root": False, + "type": "And", + "data": {"args": [criterion.dict() for criterion in mock_criteria]}, } - def test_criterion_combination_from_dict(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination_data = { - "operator": "AND", - "threshold": None, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], - "root": False, + def test_expr_from_dict(self, mock_criteria): + expr_data = { + "type": "And", + "data": {"args": [criterion.dict() for criterion in mock_criteria]}, } - # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) + expr = logic.Expr.from_dict(expr_data) - combination = LogicalCriterionCombination.from_dict(combination_data) + assert len(expr.args) == len(mock_criteria) - assert combination.operator == operator - assert len(combination) == len(mock_criteria) - - for idx, criterion in enumerate(combination): + for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) @pytest.mark.parametrize( - "operator", + "expr_class", [ - LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND, None - ), - LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AT_LEAST, 10 - ), + logic.And, + lambda *args: logic.MinCount(*args, threshold=10), ], ) - def test_criterion_combination_serialization(self, operator, mock_criteria): + def test_criterion_combination_serialization(self, expr_class, mock_criteria): # Register the mock criterion class - from execution_engine.omop.criterion import factory - factory.register_criterion_class("MockCriterion", MockCriterion) + expr = expr_class(*mock_criteria) - combination = LogicalCriterionCombination( - operator=operator, - ) 
- for criterion in mock_criteria: - combination.add(criterion) + json = expr.json() + deserialized = logic.Expr.from_json(json) - json = combination.json() - deserialized = LogicalCriterionCombination.from_json(json) - assert combination == deserialized + assert expr == deserialized def test_noncommutative_logical_criterion_combination_serialization( self, mock_criteria ): - # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) - - operator = NonCommutativeLogicalCriterionCombination.Operator( - NonCommutativeLogicalCriterionCombination.Operator.CONDITIONAL_FILTER - ) - combination = NonCommutativeLogicalCriterionCombination( - operator=operator, + expr = logic.ConditionalFilter( left=mock_criteria[0], right=mock_criteria[1], ) - json = combination.json() - deserialized = NonCommutativeLogicalCriterionCombination.from_json(json) - assert combination == deserialized + json = expr.json() + deserialized = logic.Expr.from_json(json) - @pytest.mark.parametrize("operator", ["AT_LEAST", "AT_MOST", "EXACTLY"]) - def test_operator_with_threshold(self, operator): - with pytest.raises( - AssertionError, match=f"Threshold must be set for operator {operator}" - ): - LogicalCriterionCombination.Operator(operator) + assert expr == deserialized - def test_operator(self): - with pytest.raises(AssertionError, match=""): - LogicalCriterionCombination.Operator("invalid") - - @pytest.mark.parametrize( - "operator, threshold", - [("AND", None), ("OR", None), ("AT_LEAST", 1), ("AT_MOST", 1), ("EXACTLY", 1)], + @pytest.mark.skip( + reason="repr does not have a fixed argument order, therefore test fails randomly" ) - def test_operator_str(self, operator, threshold): - op = LogicalCriterionCombination.Operator(operator, threshold) - - if operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - repr(op) - == f'LogicalCriterionCombination.Operator(operator="{operator}", 
threshold={threshold})' - ) - assert str(op) == f"{operator}(threshold={threshold})" - else: - assert ( - repr(op) - == f'LogicalCriterionCombination.Operator(operator="{operator}")' - ) - assert str(op) == f"{operator}" - - def test_repr(self): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - assert repr(combination) == ("LogicalCriterionCombination.And(\n" ")") - - def test_add_all(self): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - assert len(combination) == 0 + def test_repr(self, mock_criteria): + expr = logic.And(*mock_criteria) - combination.add_all([MockCriterion("c1"), MockCriterion("c2")]) + assert repr(expr) == ("LogicalCriterionCombination.And(\n" ")") - assert len(combination) == 2 + def test_expr_contains_criteria(self, mock_criteria): + expr = logic.And(*mock_criteria) + assert len(expr.args) == len(mock_criteria) - assert str(combination[0]) == str(MockCriterion("c1")) - assert str(combination[1]) == str(MockCriterion("c2")) + for i in range(len(mock_criteria)): + assert expr.args[i] == mock_criteria[i] class TestCriterionCombinationDatabase(TestCriterion): @@ -306,41 +190,43 @@ def run_criteria_test( threshold = None if c.func == sympy.And: - operator = LogicalCriterionCombination.Operator.AND + cls = logic.And elif c.func == sympy.Or: - operator = LogicalCriterionCombination.Operator.OR + cls = logic.Or elif isinstance(c.func, sympy.core.function.UndefinedFunction): if c.func.name in ["MinCount", "MaxCount", "ExactCount"]: assert args[0].is_number, "First argument must be a number (threshold)" - threshold = args[0] + threshold = int(args[0]) args = args[1:] if c.func.name == "MinCount": - operator = LogicalCriterionCombination.Operator.AT_LEAST + cls = lambda *args: logic.MinCount(*args, 
threshold=threshold) elif c.func.name == "MaxCount": - operator = LogicalCriterionCombination.Operator.AT_MOST + cls = lambda *args: logic.MaxCount(*args, threshold=threshold) elif c.func.name == "ExactCount": - operator = LogicalCriterionCombination.Operator.EXACTLY + cls = lambda *args: logic.ExactCount(*args, threshold=threshold) elif c.func.name == "AllOrNone": - operator = LogicalCriterionCombination.Operator.ALL_OR_NONE + cls = lambda *args: logic.AllOrNone(*args) elif c.func.name == "ConditionalFilter": - operator = None + cls = lambda *args: logic.ConditionalFilter(*args) else: raise ValueError(f"Unknown operator {c.func}") else: raise ValueError(f"Unknown operator {c.func}") - c1, c2, c3 = [ - c.copy() for c in criteria - ] # TODO(jmoringe): copy should no longer be necessary + # c1, c2, c3 = [ + # c for c in criteria + # ] # TODO(jmoringe): copy should no longer be necessary + + c1, c2, c3 = criteria for arg in args: if arg.is_Not: if arg.args[0].name == "c1": - c1 = LogicalCriterionCombination.Not(c1) + c1 = logic.Not(c1) elif arg.args[0].name == "c2": - c2 = LogicalCriterionCombination.Not(c2) + c2 = logic.Not(c2) elif arg.args[0].name == "c3": - c3 = LogicalCriterionCombination.Not(c3) + c3 = logic.Not(c3) else: raise ValueError(f"Unknown criterion {arg.args[0].name}") @@ -349,33 +235,25 @@ def run_criteria_test( if hasattr(c.func, "name") and c.func.name == "ConditionalFilter": assert len(c.args) == 2 - comb = NonCommutativeLogicalCriterionCombination.ConditionalFilter( + comb = logic.ConditionalFilter( left=symbols[str(c.args[0])], right=symbols[str(c.args[1])], ) else: - comb = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator( - operator, threshold=threshold - ), + comb = cls( + *[symbols[str(symbol)] for symbol in c.atoms() if not symbol.is_number] ) - for symbol in c.atoms(): - if symbol.is_number: - continue - else: - comb.add(symbols[str(symbol)]) - if exclude: - comb = LogicalCriterionCombination.Not(comb) + comb = 
logic.Not(comb) noop_criterion = NoopCriterion() noop_criterion.set_id(1005) - noop_intervention = LogicalCriterionCombination.And(noop_criterion) + noop_intervention = logic.NonSimplifiableAnd(noop_criterion) self.register_criterion(noop_criterion, db_session) - self.insert_criterion_combination( + self.insert_expression( db_session, population=comb, intervention=noop_intervention, diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 7501c6fa..b5c7a4f9 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -1,5 +1,4 @@ import datetime -from typing import Any, Dict, Self import pandas as pd import pendulum @@ -8,20 +7,15 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion, column_interval_type -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - FixedWindowTemporalIndicatorCombination, - PersonalWindowTemporalIndicatorCombination, - TimeIntervalType, -) from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence from execution_engine.omop.criterion.drug_exposure import DrugExposure from execution_engine.omop.criterion.measurement import Measurement +from execution_engine.omop.criterion.noop import NoopCriterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.omop.db.omop import tables from execution_engine.task.process import get_processing_module +from execution_engine.util import logic, temporal_logic_util +from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType from execution_engine.util.types 
import Dosage, TimeRange from execution_engine.util.value import ValueNumber @@ -36,7 +30,6 @@ concept_unit_mg, ) from tests._testdata import concepts -from tests.execution_engine.omop.criterion.combination import NoopCriterion from tests.execution_engine.omop.criterion.test_criterion import TestCriterion from tests.functions import ( create_condition, @@ -70,206 +63,128 @@ class TestFixedWindowTemporalIndicatorCombination: def mock_criteria(self): return [MockCriterion(f"c{i}") for i in range(1, 6)] - def test_criterion_combination_init(self, mock_criteria): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - start_time=datetime.time(8, 0), - end_time=datetime.time(16, 0), - ) - - assert combination.operator == operator - assert len(combination) == 0 - - def test_criterion_combination_add(self, mock_criteria): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - start_time=datetime.time(8, 0), - end_time=datetime.time(16, 0), - ) - - for criterion in mock_criteria: - combination.add(criterion) - - assert len(combination) == len(mock_criteria) - - for idx, criterion in enumerate(combination): - assert criterion == mock_criteria[idx] - def test_criterion_combination_dict(self, mock_criteria): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, + + expr = logic.TemporalMinCount( + *mock_criteria, + threshold=1, start_time=datetime.time(8, 0), end_time=datetime.time(16, 0), ) - for criterion in mock_criteria: - combination.add(criterion) - - combination_dict = 
combination.dict() - assert combination_dict == { - "operator": "AT_LEAST", - "threshold": 1, - "start_time": "08:00:00", - "end_time": "16:00:00", - "interval_type": None, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], - "root": False, + expr_dict = expr.dict() + assert expr_dict == { + "type": "TemporalMinCount", + "data": { + "threshold": 1, + "start_time": "08:00:00", + "end_time": "16:00:00", + "interval_type": None, + "interval_criterion": None, + "args": [criterion.dict() for criterion in mock_criteria], + }, } def test_criterion_combination_from_dict(self, mock_criteria): - # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - - combination_data = { - "id": None, - "operator": "AT_LEAST", - "threshold": 1, - "start_time": "08:00:00", - "end_time": "16:00:00", - "interval_type": None, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], + expr_dict = { + "type": "TemporalMinCount", + "data": { + "threshold": 1, + "start_time": "08:00:00", + "end_time": "16:00:00", + "interval_type": None, + "interval_criterion": None, + "args": [criterion.dict() for criterion in mock_criteria], + }, } - combination = FixedWindowTemporalIndicatorCombination.from_dict( - combination_data - ) + expr = logic.Expr.from_dict(expr_dict) - assert combination.operator == operator - assert len(combination) == len(mock_criteria) - assert combination.start_time == datetime.time(8, 0) - assert combination.end_time == datetime.time(16, 0) - assert combination.interval_type is None + assert len(expr.args) == len(mock_criteria) + assert expr.start_time == datetime.time(8, 0) + assert expr.end_time == datetime.time(16, 0) + 
assert expr.interval_type is None + assert expr.interval_criterion is None - for idx, criterion in enumerate(combination): + for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) - combination_data = { - "operator": "AT_LEAST", - "threshold": 1, - "start_time": None, - "end_time": None, - "interval_type": TimeIntervalType.MORNING_SHIFT, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], + expr_dict = { + "type": "TemporalMinCount", + "data": { + "threshold": 1, + "start_time": None, + "end_time": None, + "interval_type": TimeIntervalType.MORNING_SHIFT, + "interval_criterion": None, + "args": [criterion.dict() for criterion in mock_criteria], + }, } - combination = FixedWindowTemporalIndicatorCombination.from_dict( - combination_data - ) + expr = logic.Expr.from_dict(expr_dict) - assert combination.operator == operator - assert len(combination) == len(mock_criteria) - assert combination.start_time is None - assert combination.end_time is None - assert combination.interval_type == TimeIntervalType.MORNING_SHIFT + assert len(expr.args) == len(mock_criteria) + assert expr.start_time is None + assert expr.end_time is None + assert expr.interval_type == TimeIntervalType.MORNING_SHIFT + assert expr.interval_criterion is None - for idx, criterion in enumerate(combination): + for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) - @pytest.mark.parametrize("operator", ["AT_LEAST", "AT_MOST", "EXACTLY"]) - def test_operator_with_threshold(self, operator): - with pytest.raises( - AssertionError, match=f"Threshold must be set for operator {operator}" - ): - FixedWindowTemporalIndicatorCombination.Operator(operator) - - def test_operator(self): - with pytest.raises(AssertionError, match=""): - FixedWindowTemporalIndicatorCombination.Operator("invalid") - - @pytest.mark.parametrize( - "operator, threshold", - [("AT_LEAST", 1), ("AT_MOST", 
1), ("EXACTLY", 1)], - ) - def test_operator_str(self, operator, threshold): - op = FixedWindowTemporalIndicatorCombination.Operator(operator, threshold) - - if operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - repr(op) - == f'TemporalIndicatorCombination.Operator(operator="{operator}", threshold={threshold})' - ) - assert str(op) == f"{operator}(threshold={threshold})" - else: - assert ( - repr(op) - == f'TemporalIndicatorCombination.Operator(operator="{operator}")' - ) - assert str(op) == f"{operator}" - - def test_repr(self): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - interval_type=TimeIntervalType.MORNING_SHIFT, - ) + # @pytest.mark.skip( + # reason="the repr does not return arguments in a consistent manner" + # ) + def test_repr(self, mock_criteria): + expr = temporal_logic_util.MorningShift(mock_criteria[0]) assert ( - repr(combination) == "FixedWindowTemporalIndicatorCombination(\n" - " interval_type=TimeIntervalType.MORNING_SHIFT,\n" + repr(expr) == "TemporalMinCount(\n" + " MockCriterion(\n" + " name='c1'\n" + " ),\n" + " threshold=1,\n" " start_time=None,\n" " end_time=None,\n" - ' operator=TemporalIndicatorCombination.Operator(operator="AT_LEAST", threshold=1),\n' + " interval_type=TimeIntervalType.MORNING_SHIFT,\n" + " interval_criterion=None\n" ")" ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, + expr = logic.TemporalMinCount( + mock_criteria[0], start_time=datetime.time(8, 0), end_time=datetime.time(16, 0), + threshold=1, ) assert ( - repr(combination) == "FixedWindowTemporalIndicatorCombination(\n" + repr(expr) == "TemporalMinCount(\n" + " MockCriterion(\n" + " name='c1'\n" + " ),\n" + " threshold=1,\n" + " start_time='08:00:00',\n" + " end_time='16:00:00',\n" " interval_type=None,\n" - " start_time=datetime.time(8, 0),\n" - " 
end_time=datetime.time(16, 0),\n" - ' operator=TemporalIndicatorCombination.Operator(operator="AT_LEAST", threshold=1),\n' + " interval_criterion=None\n" ")" ) - def test_add_all(self): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_MOST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - interval_type=TimeIntervalType.MORNING_SHIFT, - ) - - assert len(combination) == 0 + def test_expr_contains_criteria(self, mock_criteria): + with pytest.raises( + TypeError, + match=r"MinCount\(\) takes 1 positional argument but 5 were given", + ): + expr = temporal_logic_util.MinCount(*mock_criteria) - combination.add_all([MockCriterion("c1"), MockCriterion("c2")]) + expr = logic.TemporalCount(*mock_criteria, threshold=1) - assert len(combination) == 2 + assert len(expr.args) == len(mock_criteria) - assert str(combination[0]) == str(MockCriterion("c1")) - assert str(combination[1]) == str(MockCriterion("c2")) + for i in range(len(mock_criteria)): + assert expr.args[i] == mock_criteria[i] c1 = DrugExposure( @@ -334,8 +249,7 @@ def criteria(self, db_session): delir_screening, ] for i, c in enumerate(criteria): - if not c.is_persisted(): - c.set_id(i + 1) + c.set_id(i + 1, overwrite=True) self.register_criterion(c, db_session) return criteria @@ -351,11 +265,11 @@ def run_criteria_test( ): noop_criterion = NoopCriterion() - noop_criterion.set_id(1005) - noop_intervention = LogicalCriterionCombination.And(noop_criterion) + noop_criterion.set_id(1005, overwrite=True) + noop_intervention = logic.And(noop_criterion) self.register_criterion(noop_criterion, db_session) - self.insert_criterion_combination( + self.insert_expression( db_session, population=combination, intervention=noop_intervention, @@ -440,7 +354,7 @@ def patient_events(self, db_session, person_visit): # Full Day #################### ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( 
c1, ), { @@ -455,7 +369,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c2, ), { @@ -470,7 +384,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c3, ), { @@ -488,7 +402,7 @@ def patient_events(self, db_session, person_visit): # Explicit Times #################### ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c1, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -509,7 +423,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c1, start_time=datetime.time(8, 30), end_time=datetime.time(18, 59), @@ -530,7 +444,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c2, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -547,7 +461,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c3, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -555,7 +469,7 @@ def patient_events(self, db_session, person_visit): {1: set(), 2: set(), 3: set()}, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c3, start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), @@ -575,7 +489,7 @@ def patient_events(self, db_session, person_visit): # Morning Shifts #################### ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c1, ), { @@ -594,7 +508,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c2, ), { @@ -609,7 +523,7 @@ def 
patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c3, ), {1: set(), 2: set(), 3: set()}, @@ -618,7 +532,7 @@ def patient_events(self, db_session, person_visit): # Afternoon Shifts #################### ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c1, ), { @@ -637,7 +551,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c2, ), { @@ -656,7 +570,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c3, ), { @@ -674,7 +588,7 @@ def patient_events(self, db_session, person_visit): # Night Shifts #################### ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c1, ), { @@ -693,7 +607,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c2, ), { @@ -708,7 +622,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c3, ), {1: set(), 2: set(), 3: set()}, @@ -717,7 +631,7 @@ def patient_events(self, db_session, person_visit): # Partial Night Shifts (before midnight) ####################### ( - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight( + temporal_logic_util.NightShiftBeforeMidnight( c1, ), { @@ -736,7 +650,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight( + temporal_logic_util.NightShiftBeforeMidnight( c2, ), { @@ -751,7 +665,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight( + 
temporal_logic_util.NightShiftBeforeMidnight( c3, ), {1: set(), 2: set(), 3: set()}, @@ -760,7 +674,7 @@ def patient_events(self, db_session, person_visit): # Partial Night Shifts (after midnight) ####################### ( - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight( + temporal_logic_util.NightShiftAfterMidnight( c1, ), { @@ -775,7 +689,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight( + temporal_logic_util.NightShiftAfterMidnight( c2, ), { @@ -790,7 +704,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight( + temporal_logic_util.NightShiftAfterMidnight( c3, ), {1: set(), 2: set(), 3: set()}, @@ -873,7 +787,7 @@ def patient_events(self, db_session, visit_occurrence): # Full Day #################### ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c1, ), { @@ -888,7 +802,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c2, ), { @@ -903,7 +817,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c3, ), { @@ -921,7 +835,7 @@ def patient_events(self, db_session, visit_occurrence): # Explicit Times #################### ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c1, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -970,7 +884,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c2, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -995,7 +909,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c3, 
start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), @@ -1027,7 +941,7 @@ def patient_events(self, db_session, visit_occurrence): # # Morning Shifts # #################### ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c1, ), { @@ -1074,7 +988,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c2, ), { @@ -1097,7 +1011,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c3, ), { @@ -1127,7 +1041,7 @@ def patient_events(self, db_session, visit_occurrence): # # Afternoon Shifts # #################### ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c1, ), { @@ -1174,7 +1088,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c2, ), { @@ -1201,7 +1115,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c3, ), { @@ -1231,7 +1145,7 @@ def patient_events(self, db_session, visit_occurrence): # # Night Shifts # #################### ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c1, ), { @@ -1282,7 +1196,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c2, ), { @@ -1305,7 +1219,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c3, ), { @@ -1407,7 +1321,7 @@ def patient_events(self, db_session, visit_occurrence): "combination,expected", [ ( - 
FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( bodyweight_measurement_without_forward_fill, ), { @@ -1420,13 +1334,13 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( bodyweight_measurement_without_forward_fill, ), {1: set()}, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( bodyweight_measurement_with_forward_fill, ), { @@ -1447,7 +1361,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( bodyweight_measurement_with_forward_fill, ), { @@ -1518,25 +1432,12 @@ def __init__(self) -> None: super().__init__() self._table = tables.ProcedureOccurrence.__table__.alias("po") - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create an object from a dictionary. - """ - return cls() - def description(self) -> str: """ Get a description of the criterion. """ return self.__class__.__name__ - def dict(self) -> dict: - """ - Get a dictionary representation of the object. - """ - return {} - def _create_query(self) -> sa.Select: """ Get the SQL Select query for data required by this criterion. @@ -1567,17 +1468,13 @@ def _create_query(self) -> sa.Select: class TestPersonalWindowTemporalIndicatorCombination(TestCriterionCombinationDatabase): """ Test class for testing criterion combinations on the database with individual windows, - i.e. windows whose length is dependend on some patient-specific event (here: surgery) + i.e. 
windows whose length is dependent on some patient-specific event (here: surgery) """ @pytest.fixture def criteria(self, db_session): - # c4.set_id(4) # surgical procedure - - if not c_preop.is_persisted(): - c_preop.set_id(4) - if not bodyweight_measurement_without_forward_fill.is_persisted(): - bodyweight_measurement_without_forward_fill.set_id(5) + c_preop.set_id(4, overwrite=True) + bodyweight_measurement_without_forward_fill.set_id(5, overwrite=True) self.register_criterion(c_preop, db_session) self.register_criterion(bodyweight_measurement_without_forward_fill, db_session) @@ -1622,7 +1519,7 @@ def patient_events(self, db_session, person_visit): # Explicit Times #################### ( - PersonalWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( bodyweight_measurement_without_forward_fill, interval_criterion=c_preop, ), @@ -1740,17 +1637,17 @@ def patient_events(self, db_session, visit_occurrence): "population,intervention,expected", [ ( - LogicalCriterionCombination.And(c2), - LogicalCriterionCombination.CappedAtLeast( + logic.And(c2), + logic.CappedMinCount( *[ - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( criterion=shift_class(criterion=delir_screening), ) for shift_class in [ - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight, - FixedWindowTemporalIndicatorCombination.MorningShift, - FixedWindowTemporalIndicatorCombination.AfternoonShift, - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight, + temporal_logic_util.NightShiftAfterMidnight, + temporal_logic_util.MorningShift, + temporal_logic_util.AfternoonShift, + temporal_logic_util.NightShiftBeforeMidnight, ] ], threshold=2, @@ -1803,7 +1700,7 @@ def test_at_least_combination_on_database( db_session.add_all(vos) db_session.commit() - self.insert_criterion_combination( + self.insert_expression( db_session, population, intervention, base_criterion, observation_window ) @@ -1828,17 +1725,17 @@ def 
test_at_least_combination_on_database( "population,intervention,expected", [ ( - LogicalCriterionCombination.And(c2), - LogicalCriterionCombination.CappedAtLeast( + logic.And(c2), + logic.CappedMinCount( *[ - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( criterion=shift_class(criterion=delir_screening), ) for shift_class in [ - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight, - FixedWindowTemporalIndicatorCombination.MorningShift, - FixedWindowTemporalIndicatorCombination.AfternoonShift, - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight, + temporal_logic_util.NightShiftAfterMidnight, + temporal_logic_util.MorningShift, + temporal_logic_util.AfternoonShift, + temporal_logic_util.NightShiftBeforeMidnight, ] ], threshold=2, @@ -1882,7 +1779,7 @@ def test_at_least_combination_on_database_no_measurements( db_session.add_all([c1]) db_session.commit() - self.insert_criterion_combination( + self.insert_expression( db_session, population, intervention, base_criterion, observation_window ) diff --git a/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py b/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py index f349c331..1996fc60 100644 --- a/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py +++ b/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py @@ -36,7 +36,7 @@ class TestTidalVolumePerIdealBodyWeight(TestCriterion): @pytest.fixture def concept(self): return CustomConcept( - name="Tidal volume / ideal body weight (ARDSnet)", + concept_name="Tidal volume / ideal body weight (ARDSnet)", concept_code="tvpibw", domain_id="Measurement", vocabulary_id="CODEX-CELIDA", diff --git a/tests/execution_engine/omop/criterion/test_criterion.py b/tests/execution_engine/omop/criterion/test_criterion.py index 475184cd..ec095ef9 100644 --- a/tests/execution_engine/omop/criterion/test_criterion.py +++ b/tests/execution_engine/omop/criterion/test_criterion.py @@ -5,14 +5,15 
@@ import pandas as pd import pendulum import pytest -from sqlalchemy import Column, Date, Integer, MetaData, Table, select +from sqlalchemy import Column, Date, Integer, MetaData, Table, select, update import execution_engine.omop.db.celida.tables as celida_tables from execution_engine.constants import CohortCategory from execution_engine.execution_graph import ExecutionGraph +from execution_engine.omop.cohort import PopulationInterventionPairExpr +from execution_engine.omop.cohort.graph_builder import RecommendationGraphBuilder from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination from execution_engine.omop.criterion.visit_occurrence import PatientsActiveDuringPeriod from execution_engine.omop.db.celida.views import ( full_day_coverage, @@ -25,7 +26,7 @@ task, ) from execution_engine.task.process import get_processing_module -from execution_engine.util import logic +from execution_engine.util import datetime_converter, logic from execution_engine.util.db import add_result_insert from execution_engine.util.interval import IntervalType from execution_engine.util.types import TimeRange @@ -62,6 +63,27 @@ def date_set(tc: Iterable): return set(pendulum.parse(t).date() for t in tc) +def store_execution_graph(graph: ExecutionGraph, db_session, recommendation_id: int): + import json + + from execution_engine.omop.db.celida import tables as result_db + + rec_graph: bytes = json.dumps( + graph.to_cytoscape_dict(), sort_keys=True, default=datetime_converter + ).encode() + + update_query = ( + update(result_db.Recommendation) + .where(result_db.Recommendation.recommendation_id == recommendation_id) + .values( + recommendation_execution_graph=rec_graph, + ) + ) + + db_session.execute(update_query) + db_session.commit() + + class TestCriterion: result_table = celida_tables.ResultInterval.__table__ run_id = 1234 @@ -187,7 +209,7 
@@ def register_criterion(cls, criterion: Criterion, db_session): exists = db_session.query( db_session.query(celida_tables.Criterion) - .filter_by(criterion_id=criterion.id) + .filter_by(criterion_hash=str(hash(criterion))) .exists() ).scalar() @@ -195,7 +217,7 @@ def register_criterion(cls, criterion: Criterion, db_session): new_criterion = celida_tables.Criterion( criterion_id=criterion.id, criterion_description=criterion.description(), - criterion_hash=hash(criterion), + criterion_hash=str(hash(criterion)), ) db_session.add(new_criterion) db_session.commit() @@ -251,7 +273,11 @@ def create_occurrence( ) def insert_criterion(self, db_session, criterion, observation_window: TimeRange): - criterion.set_id(self.criterion_id) + + criterion.set_id( + self.criterion_id + 1 + ) # +1 to avoid collision with the criterion saved in + # omop_fixture.py::celida_recommendation() self.register_criterion(criterion, db_session) query = criterion.create_query() @@ -287,44 +313,30 @@ def insert_criterion(self, db_session, criterion, observation_window: TimeRange) db_session.commit() - def insert_criterion_combination( + def insert_expression( self, db_session, - population: CriterionCombination, - intervention: CriterionCombination, + population: logic.Expr, + intervention: logic.Expr, base_criterion: Criterion, observation_window: TimeRange, ): - graph = ExecutionGraph.from_criterion_combination( - population, intervention, base_criterion - ) - - # we need to add a NoDataPreservingAnd node to insert NEGATIVE intervals - p_sink_node = graph.sink_node(CohortCategory.POPULATION) - pi_sink_node = graph.sink_node(CohortCategory.POPULATION_INTERVENTION) - - p_combination_node = logic.NoDataPreservingAnd( - p_sink_node, category=CohortCategory.POPULATION + # population_expr is assigned a NoDataPreservingAnd to ensure creation of negative intervals + pi_pair = PopulationInterventionPairExpr( + population_expr=logic.NonSimplifiableAnd(population), + intervention_expr=intervention, + 
base_criterion=base_criterion, + name="Test", + url="https://example.com", ) + pi_pair.set_id(self.pi_pair_id) - pi_combination_node = logic.NoDataPreservingAnd( - pi_sink_node, category=CohortCategory.POPULATION_INTERVENTION - ) + graph = RecommendationGraphBuilder.build(pi_pair, base_criterion) - graph.add_node( - p_combination_node, store_result=True, category=CohortCategory.POPULATION - ) - graph.add_node( - pi_combination_node, - store_result=True, - category=CohortCategory.POPULATION_INTERVENTION, + store_execution_graph( + graph=graph, db_session=db_session, recommendation_id=self.recommendation_id ) - graph.add_edge(p_sink_node, p_combination_node) - graph.add_edge(pi_sink_node, pi_combination_node) - - graph.set_sink_nodes_store(bind_params={"pi_pair_id": self.pi_pair_id}) - params = observation_window.model_dump() | {"run_id": self.run_id} task_runner = runner.SequentialTaskRunner(graph) diff --git a/tests/execution_engine/omop/criterion/test_drug_exposure.py b/tests/execution_engine/omop/criterion/test_drug_exposure.py index b405a999..cf87d007 100644 --- a/tests/execution_engine/omop/criterion/test_drug_exposure.py +++ b/tests/execution_engine/omop/criterion/test_drug_exposure.py @@ -4,10 +4,8 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.criterion.drug_exposure import DrugExposure +from execution_engine.util import logic from execution_engine.util.enum import TimeUnit from execution_engine.util.types import Dosage from execution_engine.util.value import ValueNumber @@ -43,7 +41,7 @@ def _run_drug_exposure( route=route, ) if exclude: - criterion = LogicalCriterionCombination.Not(criterion) + criterion = logic.Not(criterion) self.insert_criterion(db_session, criterion, observation_window) diff --git a/tests/execution_engine/omop/test_concepts.py 
b/tests/execution_engine/omop/test_concepts.py index 6a933224..283ab505 100644 --- a/tests/execution_engine/omop/test_concepts.py +++ b/tests/execution_engine/omop/test_concepts.py @@ -81,7 +81,7 @@ def test_is_custom(self): class TestCustomConcept: def test_init(self): custom_concept = CustomConcept( - name="Test Custom Concept", + concept_name="Test Custom Concept", concept_code="CC123", domain_id="Test Domain", vocabulary_id="Test Vocabulary", @@ -97,7 +97,7 @@ def test_init(self): def test_id_property(self): custom_concept = CustomConcept( - name="Test Custom Concept", + concept_name="Test Custom Concept", concept_code="CC123", domain_id="Test Domain", vocabulary_id="Test Vocabulary", @@ -107,7 +107,7 @@ def test_id_property(self): def test_str(self): custom_concept = CustomConcept( - name="Test Custom Concept", + concept_name="Test Custom Concept", concept_code="CC123", domain_id="Test Domain", vocabulary_id="Test Vocabulary", diff --git a/tests/execution_engine/util/test_logic.py b/tests/execution_engine/util/test_logic.py index 86d34a07..f823ea13 100644 --- a/tests/execution_engine/util/test_logic.py +++ b/tests/execution_engine/util/test_logic.py @@ -3,7 +3,6 @@ import pytest -from execution_engine.constants import CohortCategory from execution_engine.util.enum import TimeIntervalType from execution_engine.util.logic import ( AllOrNone, @@ -28,34 +27,29 @@ dummy_criterion = MockCriterion( name="dummy_criterion", - exclude=False, ) x, y, z = ( - Symbol(MockCriterion("x", False), category=CohortCategory.POPULATION), - Symbol(MockCriterion("y", False), category=CohortCategory.POPULATION), - Symbol(MockCriterion("z", False), category=CohortCategory.POPULATION), + MockCriterion("x"), + MockCriterion("y"), + MockCriterion("z"), ) # Tests for Expr class TestExpr: - def test_creation_with_category(self): - expr = Expr(category=CohortCategory.POPULATION) - assert expr.category == CohortCategory.POPULATION - def test_is_Atom_false(self): - expr = 
Expr(category=CohortCategory.POPULATION) + expr = Expr() assert not expr.is_Atom # Tests for Symbol class TestSymbol: def test_symbol_creation(self): - symbol = Symbol(dummy_criterion, category=CohortCategory.POPULATION) - assert str(symbol) == "MockCriterion[dummy_criterion]" - assert symbol.criterion == dummy_criterion + symbol = dummy_criterion + assert repr(symbol) == "MockCriterion(\n name='dummy_criterion'\n)" + assert symbol == dummy_criterion def test_is_Atom_true(self): symbol = x @@ -68,11 +62,11 @@ def test_is_Not_false(self): class TestBooleanFunction: def test_or_creation(self): - or_expr = Or(x, y, category=CohortCategory.POPULATION) + or_expr = Or(x, y) assert isinstance(or_expr, Or) def test_or_is_Atom_false(self): - or_expr = Or(x, y, category=CohortCategory.POPULATION) + or_expr = Or(x, y) assert not or_expr.is_Atom @@ -81,24 +75,35 @@ def test_and_creation_with_multiple_args(self): and_expr = And( x, y, - Not(z, category=CohortCategory.POPULATION), - category=CohortCategory.POPULATION, + Not(z), ) assert isinstance(and_expr, And) assert ( - str(and_expr) == "(MockCriterion[x] & MockCriterion[y] & ~MockCriterion[z])" + repr(and_expr) == "And(\n" + " MockCriterion(\n" + " name='x'\n" + " ),\n" + " MockCriterion(\n" + " name='y'\n" + " ),\n" + " Not(\n" + " MockCriterion(\n" + " name='z'\n" + " )\n" + " )\n" + ")" ) assert and_expr.args[0] == x assert and_expr.args[1] == y - assert and_expr.args[2] == Not(z, category=CohortCategory.POPULATION) + assert and_expr.args[2] == Not(z) assert not and_expr.is_Not assert not and_expr.is_Atom def test_and_creation_with_single_arg(self): single_arg = x - and_expr = And(single_arg, category=CohortCategory.POPULATION) + and_expr = And(single_arg) assert and_expr is single_arg @@ -107,30 +112,41 @@ def test_or_creation_with_multiple_args(self): or_expr = Or( x, y, - Not(z, category=CohortCategory.POPULATION), - category=CohortCategory.POPULATION, + Not(z), ) assert isinstance(or_expr, Or) assert ( - str(or_expr) 
== "(MockCriterion[x] | MockCriterion[y] | ~MockCriterion[z])" + repr(or_expr) == "Or(\n" + " MockCriterion(\n" + " name='x'\n" + " ),\n" + " MockCriterion(\n" + " name='y'\n" + " ),\n" + " Not(\n" + " MockCriterion(\n" + " name='z'\n" + " )\n" + " )\n" + ")" ) assert or_expr.args[0] == x assert or_expr.args[1] == y - assert or_expr.args[2] == Not(z, category=CohortCategory.POPULATION) + assert or_expr.args[2] == Not(z) assert not or_expr.is_Not assert not or_expr.is_Atom def test_or_creation_with_single_arg(self): single_arg = x - or_expr = Or(single_arg, category=CohortCategory.POPULATION) + or_expr = Or(single_arg) assert or_expr is single_arg class TestNot: def test_not_creation(self): - not_expr = Not(x, category=CohortCategory.POPULATION) + not_expr = Not(x) assert isinstance(not_expr, Not) assert str(not_expr) == "~MockCriterion[x]" assert not_expr.args[0] == x @@ -140,33 +156,31 @@ def test_not_creation(self): def test_not_creation_with_multiple_args(self): with pytest.raises(ValueError): - Not(x, y, category=CohortCategory.POPULATION) + Not(x, y) class TestNonSimplifiableAnd: def test_non_simplifiable_and_creation(self): - non_simp_and = NonSimplifiableAnd(x, y, category=CohortCategory.POPULATION) + non_simp_and = NonSimplifiableAnd(x, y) assert isinstance(non_simp_and, NonSimplifiableAnd) assert not non_simp_and.is_Not assert not non_simp_and.is_Atom def test_non_simplifiable_and_single_arg(self): single_arg = x - non_simp_and = NonSimplifiableAnd( - single_arg, category=CohortCategory.POPULATION - ) + non_simp_and = NonSimplifiableAnd(single_arg) assert isinstance(non_simp_and, NonSimplifiableAnd) assert non_simp_and.args[0] is single_arg def test_non_simplifiable_and_equality(self): - non_simp_and1 = NonSimplifiableAnd(x, category=CohortCategory.POPULATION) - non_simp_and2 = NonSimplifiableAnd(x, category=CohortCategory.POPULATION) - assert non_simp_and1 != non_simp_and2 + non_simp_and1 = NonSimplifiableAnd(x) + non_simp_and2 = NonSimplifiableAnd(x) + 
assert non_simp_and1 == non_simp_and2 class TestNoDataPreservingAnd: def test_no_data_preserving_and_creation(self): - no_data_and = NoDataPreservingAnd(x, y, category=CohortCategory.POPULATION) + no_data_and = NoDataPreservingAnd(x, y) assert isinstance(no_data_and, NoDataPreservingAnd) assert no_data_and.args[0] == x assert no_data_and.args[1] == y @@ -177,7 +191,7 @@ def test_no_data_preserving_and_creation(self): class TestNoDataPreservingOr: def test_no_data_preserving_or_creation(self): - no_data_or = NoDataPreservingOr(x, y, category=CohortCategory.POPULATION) + no_data_or = NoDataPreservingOr(x, y) assert isinstance(no_data_or, NoDataPreservingOr) assert no_data_or.args[0] == x assert no_data_or.args[1] == y @@ -190,9 +204,7 @@ class TestLeftDependentToggle: def test_left_dependent_toggle_creation(self): left_expr = x right_expr = y - left_toggle = LeftDependentToggle( - left=left_expr, right=right_expr, category=CohortCategory.POPULATION - ) + left_toggle = LeftDependentToggle(left=left_expr, right=right_expr) assert isinstance(left_toggle, LeftDependentToggle) assert left_toggle.left == left_expr == left_toggle.args[0] assert left_toggle.right == right_expr == left_toggle.args[1] @@ -213,20 +225,20 @@ def worker(queue: Queue, symbol: Symbol): class TestSymbolMultiprocessing: @pytest.fixture( params=[ - Expr(1, 2, 3, category=CohortCategory.POPULATION), - Symbol(dummy_criterion, category=CohortCategory.POPULATION), - BooleanFunction(1, 2, 3, category=CohortCategory.POPULATION), - Or(1, 2, 3, category=CohortCategory.POPULATION), - And(1, 2, 3, category=CohortCategory.POPULATION), - Not(1, category=CohortCategory.POPULATION), - MinCount(1, 2, 3, threshold=2, category=CohortCategory.POPULATION), - MaxCount(1, 2, 3, threshold=2, category=CohortCategory.POPULATION), - ExactCount(1, 2, 3, threshold=2, category=CohortCategory.POPULATION), - AllOrNone(1, 2, 3, category=CohortCategory.POPULATION), - NonSimplifiableAnd(1, 2, 3, category=CohortCategory.POPULATION), - 
NoDataPreservingAnd(1, 2, 3, category=CohortCategory.POPULATION), - NoDataPreservingOr(1, 2, 3, category=CohortCategory.POPULATION), - LeftDependentToggle(left=1, right=2, category=CohortCategory.POPULATION), + Expr(1, 2, 3), + Symbol(dummy_criterion), + BooleanFunction(1, 2, 3), + Or(1, 2, 3), + And(1, 2, 3), + Not(1), + MinCount(1, 2, 3, threshold=2), + MaxCount(1, 2, 3, threshold=2), + ExactCount(1, 2, 3, threshold=2), + AllOrNone(1, 2, 3), + NonSimplifiableAnd(1, 2, 3), + NoDataPreservingAnd(1, 2, 3), + NoDataPreservingOr(1, 2, 3), + LeftDependentToggle(left=1, right=2), TemporalMinCount( 1, 2, @@ -236,7 +248,6 @@ class TestSymbolMultiprocessing: end_time=None, interval_type=TimeIntervalType.DAY, interval_criterion=None, - category=CohortCategory.POPULATION, ), TemporalMaxCount( 1, @@ -247,7 +258,6 @@ class TestSymbolMultiprocessing: end_time=None, interval_type=TimeIntervalType.MORNING_SHIFT, interval_criterion=None, - category=CohortCategory.POPULATION, ), TemporalExactCount( 1, @@ -258,7 +268,6 @@ class TestSymbolMultiprocessing: end_time=None, interval_type=TimeIntervalType.NIGHT_SHIFT, interval_criterion=None, - category=CohortCategory.POPULATION, ), ], ids=lambda expr: expr.__class__.__name__, diff --git a/tests/execution_engine/util/test_types.py b/tests/execution_engine/util/test_types.py index a15e6686..350fd66a 100644 --- a/tests/execution_engine/util/test_types.py +++ b/tests/execution_engine/util/test_types.py @@ -67,9 +67,6 @@ def test_dosage_creation(self): assert dosage.count.value == 5 assert dosage.frequency.value == 10 - @pytest.mark.skip( - reason="Timing in Dosage is another type that does not use the use use_enum_values flag" - ) def test_enum_values(self): dosage = Dosage( dose=ValueNumber(value=10, unit=concept_unit_mg), @@ -81,7 +78,7 @@ def test_enum_values(self): assert ( dosage.interval.unit == "h" ) # Check if the enum value is used instead of the enum name - assert not isinstance( + assert isinstance( dosage.interval.unit, TimeUnit 
) # Ensure it's not returning the enum member diff --git a/tests/mocks/criterion.py b/tests/mocks/criterion.py index 3559a36c..f2ff7c25 100644 --- a/tests/mocks/criterion.py +++ b/tests/mocks/criterion.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, Self - from sqlalchemy import Select from execution_engine.omop.concepts import Concept @@ -13,12 +11,9 @@ def _create_query(self) -> Select: def __init__( self, name: str, - exclude: bool = False, ): self._id = None self._name = name - self._exclude = exclude - assert not exclude def unique_name(self) -> str: return self._name @@ -35,22 +30,6 @@ def _sql_select_data(self, query: Select) -> Select: def description(self) -> str: return f"MockCriterion[{self._name}]" - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - return cls( - name=data["name"], - exclude=data["exclude"], - ) - - def dict(self) -> dict[str, Any]: - return { - "name": self._name, - "exclude": self._exclude, - } - - def __repr__(self) -> str: - return self.description() - @property def concept(self) -> Concept: # type: ignore pass From 1afbf52a199d375809987c70452119e578958084 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 17 Mar 2025 17:08:05 +0100 Subject: [PATCH 09/16] refactor: remove NoDataPreserving operators --- apps/viz-backend/app/main.py | 2 - .../converter/characteristic/abstract.py | 4 +- .../converter/recommendation_factory.py | 2 +- execution_engine/execution_engine.py | 81 ++++++++----------- execution_engine/execution_graph/graph.py | 2 - execution_engine/omop/cohort/graph_builder.py | 2 +- execution_engine/omop/criterion/abstract.py | 19 ++++- execution_engine/omop/criterion/concept.py | 4 +- execution_engine/task/process/__init__.py | 5 +- execution_engine/task/task.py | 67 ++------------- .../util/interval/typed_interval.py | 1 + execution_engine/util/logic.py | 38 ++++----- execution_engine/util/value/value.py | 3 + scripts/execute.py | 6 +- .../omop/criterion/test_criterion.py | 2 +- 
tests/execution_engine/util/test_logic.py | 26 ------ .../test_recommendation_base_v2.py | 2 + 17 files changed, 90 insertions(+), 176 deletions(-) diff --git a/apps/viz-backend/app/main.py b/apps/viz-backend/app/main.py index 6c74641d..6b84ea4b 100644 --- a/apps/viz-backend/app/main.py +++ b/apps/viz-backend/app/main.py @@ -89,8 +89,6 @@ def get_execution_graph(recommendation_id: int, db: Session = Depends(get_db)) - if not result: raise HTTPException(status_code=404, detail="Recommendation not found") - print(result) - # Decode the bytes to a string and parse it as JSON execution_graph = json.loads(result.recommendation_execution_graph.decode("utf-8")) diff --git a/execution_engine/converter/characteristic/abstract.py b/execution_engine/converter/characteristic/abstract.py index 176d9302..dd4c5594 100644 --- a/execution_engine/converter/characteristic/abstract.py +++ b/execution_engine/converter/characteristic/abstract.py @@ -61,9 +61,9 @@ def type(self) -> Concept: return self._type @type.setter - def type(self, type: Concept) -> None: + def type(self, type_: Concept) -> None: """Sets the type of this characteristic.""" - self._type = type + self._type = type_ @property def value(self) -> Any: diff --git a/execution_engine/converter/recommendation_factory.py b/execution_engine/converter/recommendation_factory.py index b42b3c14..dbcb7a5f 100644 --- a/execution_engine/converter/recommendation_factory.py +++ b/execution_engine/converter/recommendation_factory.py @@ -72,7 +72,7 @@ def parse_recommendation_from_url( # parse intervention and create criteria actions = parser.parse_actions(rec_plan.actions, rec_plan) - # population_expr is assigned a NoDataPreservingAnd to ensure creation of negative intervals + # population_expr is assigned a NonSimplifiableAnd to ensure creation of negative intervals # todo: not sure we really need this - we can just always create negative intervals when store_results=True # in the graph pi_pair = PopulationInterventionPairExpr( diff 
--git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 1432e8c8..9bef1e7f 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -198,29 +198,7 @@ def load_recommendation_from_database( # we'll run registration again to set all ids recommendation.set_id(rec_db.recommendation_id) - assert ( - recommendation.base_criterion is not None - ), "Base Criterion must be set" - self.register_criterion(recommendation.base_criterion) - - for pi_pair in recommendation.population_intervention_pairs(): - self.register_population_intervention_pair( - pi_pair, rec_db.recommendation_id - ) - - for criterion in recommendation.atoms(): - self.register_criterion(criterion) - - # All objects in the deserialized object graph must have an id. - assert recommendation.id is not None - assert recommendation.base_criterion is not None - assert recommendation.base_criterion.id is not None - - for pi_pair in recommendation.population_intervention_pairs(): - assert pi_pair.id is not None - - for criterion in recommendation.atoms(): - assert criterion.id is not None + self.register_children(recommendation) return recommendation @@ -280,10 +258,39 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None result = con.execute(query) recommendation.set_id(result.fetchone().recommendation_id) - # Register all child objects. After that, the recommendation - # and all child objects have valid ids (either restored or - # fresh). 
+ self.register_children(recommendation) + + # Update the recommendation in the database with the final + # JSON representation and execution graph (now that + # recommendation id, criteria ids and pi pair is are known) + # TODO(jmoringe): only when necessary + with self._db.begin() as con: + rec_graph: bytes = json.dumps( + recommendation.execution_graph().to_cytoscape_dict(), sort_keys=True + ).encode() + + rec_json: bytes = recommendation.json() + + update_query = ( + update(recommendation_table) + .where(recommendation_table.recommendation_id == recommendation.id) + .values( + recommendation_json=rec_json, + recommendation_execution_graph=rec_graph, + ) + ) + + con.execute(update_query) + + def register_children(self, recommendation: cohort.Recommendation) -> None: + """ + Registers all child objects of the recommendation in the result database. + :param recommendation: The Recommendation object. + :type recommendation: cohort.Recommendation + :raises ValueError: If the base criterion is not set. + :raises AssertionError: If the recommendation or any of its child objects do not have a valid ID. 
+ """ if recommendation.base_criterion is None: raise ValueError("Base Criterion must be set when storing recommendation") @@ -307,28 +314,6 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None for criterion in recommendation.atoms(): assert criterion.id is not None - # Update the recommendation in the database with the final - # JSON representation and execution graph (now that - # recommendation id, criteria ids and pi pair is are known) - # TODO(jmoringe): only when necessary - with self._db.begin() as con: - rec_graph: bytes = json.dumps( - recommendation.execution_graph().to_cytoscape_dict(), sort_keys=True - ).encode() - - rec_json: bytes = recommendation.json() - - update_query = ( - update(recommendation_table) - .where(recommendation_table.recommendation_id == recommendation.id) - .values( - recommendation_json=rec_json, - recommendation_execution_graph=rec_graph, - ) - ) - - con.execute(update_query) - def register_population_intervention_pair( self, pi_pair: PopulationInterventionPairExpr, recommendation_id: int ) -> None: diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index 31dd668d..8c5d5dc3 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -292,8 +292,6 @@ def plot(self) -> None: "LeftDependentToggle": "=>", "NonSimplifiableOr": "!|", "NonSimplifiableAnd": "!&", - "NoDataPreservingAnd": "NDP-&", - "NoDataPreservingOr": "NPD-|", "MinCount": "Min", "MaxCount": "Max", "ExactCount": "Exact", diff --git a/execution_engine/omop/cohort/graph_builder.py b/execution_engine/omop/cohort/graph_builder.py index bf4d92ae..10926d46 100644 --- a/execution_engine/omop/cohort/graph_builder.py +++ b/execution_engine/omop/cohort/graph_builder.py @@ -120,7 +120,7 @@ def build(cls, expr: logic.Expr, base_criterion: Criterion) -> ExecutionGraph: bind_params={}, desired_category=CohortCategory.POPULATION_INTERVENTION ) - 
p_combination_node = logic.NoDataPreservingOr(*p_sink_nodes) + p_combination_node = logic.NonSimplifiableOr(*p_sink_nodes) graph.add_node( p_combination_node, store_result=True, category=CohortCategory.POPULATION ) diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index 5b9a42d1..8ceaece8 100644 --- a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -29,7 +29,15 @@ from execution_engine.util.sql import SelectInto, select_into from execution_engine.util.types import PersonIntervals, TimeRange -__all__ = ["Criterion"] +__all__ = [ + "Criterion", + "column_interval_type", + "create_conditional_interval_column", + "SQL_ONE_SECOND", + "observation_start_datetime", + "observation_end_datetime", + "run_id", +] Domain = TypedDict( "Domain", @@ -283,9 +291,12 @@ def create_query(self) -> Select: ), "Query must select 4 columns: person_id, interval_start, interval_end, interval_type" # assert that the output columns are person_id, interval_start, interval_end, type - assert set([c.name for c in query.selected_columns]) == set( - ["person_id", "interval_start", "interval_end", "interval_type"] - ), "Query must select 4 columns: person_id, interval_start, interval_end, interval_type" + assert set([c.name for c in query.selected_columns]) == { + "person_id", + "interval_start", + "interval_end", + "interval_type", + }, "Query must select 4 columns: person_id, interval_start, interval_end, interval_type" return query diff --git a/execution_engine/omop/criterion/concept.py b/execution_engine/omop/criterion/concept.py index d5d7b39e..ac33ad2c 100644 --- a/execution_engine/omop/criterion/concept.py +++ b/execution_engine/omop/criterion/concept.py @@ -1,3 +1,5 @@ +from abc import ABC + from sqlalchemy.sql import Select from execution_engine.constants import OMOPConcepts @@ -21,7 +23,7 @@ # TODO: Only use weight etc from the current encounter/visit! 
-class ConceptCriterion(Criterion): +class ConceptCriterion(Criterion, ABC): """ Abstract class for a criterion based on an OMOP concept and optional value. diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index 72b4589c..b9e4444e 100644 --- a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -15,6 +15,7 @@ def get_processing_module( - rectangle (faster, using rectangles intersection/union) :param name: name of the processing module + :param version: version of the processing module """ if name == "rectangle": @@ -38,4 +39,6 @@ def get_processing_module( Interval = namedtuple("Interval", ["lower", "upper", "type"]) IntervalWithCount = namedtuple("IntervalWithCount", ["lower", "upper", "type", "count"]) -IntervalWithTypeCounts = namedtuple("IntervalWithTypeCounts", ["lower", "upper", "counts"]) +IntervalWithTypeCounts = namedtuple( + "IntervalWithTypeCounts", ["lower", "upper", "counts"] +) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index a8bc153c..107e48ef 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -173,41 +173,25 @@ def run( result = self.handle_unary_logical_operator( data, base_data, observation_window ) + elif isinstance(self.expr, logic.TemporalCount): + result = self.handle_temporal_operator(data, observation_window) elif isinstance( self.expr, - ( - logic.And, - logic.Or, - logic.NonSimplifiableAnd, - logic.Count, - logic.CappedCount, - logic.AllOrNone, - ), + (logic.CommutativeOperator), ): result = self.handle_binary_logical_operator(data) - elif isinstance(self.expr, (logic.BinaryNonCommutativeOperator)): + elif isinstance(self.expr, logic.BinaryNonCommutativeOperator): result = self.handle_left_dependent_toggle( left=self.select_predecessor_result(self.expr.left, data), right=self.select_predecessor_result(self.expr.right, data), base_data=base_data, observation_window=observation_window, ) - 
elif isinstance( - self.expr, (logic.NoDataPreservingAnd, logic.NoDataPreservingOr) - ): - result = self.handle_no_data_preserving_operator( - data, base_data, observation_window - ) - elif isinstance(self.expr, logic.TemporalCount): - result = self.handle_temporal_operator(data, observation_window) else: raise ValueError(f"Unsupported expression type: {type(self.expr)}") if self.store_result: - if ( - not isinstance(self.expr, logic.NoDataPreservingAnd) - and not self.expr.is_Atom - ): + if not self.expr.is_Atom: result = self.insert_negative_intervals( result, base_data, observation_window ) @@ -296,7 +280,7 @@ def handle_binary_logical_operator( self, data: list[PersonIntervals] ) -> PersonIntervals: """ - Handles a binary logical operator (And or Or) by merging or intersecting the intervals of the + Handles a binary logical operator by using the appropriate processing function. :param data: The input data. :return: A DataFrame with the merged or intersected intervals. @@ -311,7 +295,7 @@ def handle_binary_logical_operator( if isinstance(self.expr, (logic.And, logic.NonSimplifiableAnd)): result = process.intersect_intervals(data) - elif isinstance(self.expr, logic.Or): + elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): result = process.union_intervals(data) elif isinstance(self.expr, logic.Count): result = process.count_intervals(data) @@ -353,43 +337,6 @@ def handle_binary_logical_operator( return result - def handle_no_data_preserving_operator( - self, - data: list[PersonIntervals], - base_data: PersonIntervals, - observation_window: TimeRange, - ) -> PersonIntervals: - """ - Handles a NoDataPreservingAnd/Or operator. - - These are used to combine POPULATION, INTERVENTION and POPULATION/INTERVENTION results from different - population/intervention pairs into a single result (i.e. the full recommendation's POPULATION etc.). 
- - The POSITIVE intervals are intersected (And) or merged (Or), the NO_DATA intervals are intersected and the - remaining intervals are set to NEGATIVE. - - :param data: The input data. - :param base_data: The result of the base criterion. - :param observation_window: The observation window. - :return: A DataFrame with the merged intervals. - """ - assert isinstance( - self.expr, (logic.NoDataPreservingAnd, logic.NoDataPreservingOr) - ), "Dependency is not a NoDataPreservingAnd / NoDataPreservingOr expression." - - if isinstance(self.expr, logic.NoDataPreservingAnd): - result = process.intersect_intervals(data) - elif isinstance(self.expr, logic.NoDataPreservingOr): - result = process.union_intervals(data) - else: - raise ValueError(f"Unsupported expression type: {type(self.expr)}") - - # todo: the only difference between this function and handle_binary_logical_operator is the following lines - # - can we merge? - return self.insert_negative_intervals( - data=result, base_data=base_data, observation_window=observation_window - ) - def handle_left_dependent_toggle( self, left: PersonIntervals, diff --git a/execution_engine/util/interval/typed_interval.py b/execution_engine/util/interval/typed_interval.py index 3c2e4eac..aa35696f 100644 --- a/execution_engine/util/interval/typed_interval.py +++ b/execution_engine/util/interval/typed_interval.py @@ -1418,6 +1418,7 @@ def interval_datetime( :param lower: The lower bound. :param upper: The upper bound. + :param type_: The type of the interval. :return: The new datetime interval. 
""" return DateTimeInterval.from_atomic(Bound.CLOSED, lower, upper, Bound.CLOSED, type_) # type: ignore # mypy expects "IntervalT", not sure why diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index c140ea1e..f28ff219 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -90,7 +90,7 @@ def update_args(self, *args: Any) -> None: """ raise NotImplementedError("update_args must be implemented by subclasses") - def __new__(cls, *args: Any) -> "Expr": + def __new__(cls, *args: Any, **kwargs: Any) -> "Expr": """ Initialize an expression with given arguments. @@ -610,7 +610,7 @@ def _validate_time_inputs( if start_time >= end_time: raise ValueError("start_time must be less than end_time") - elif interval_criterion and not isinstance(interval_criterion, (BaseExpr)): + elif interval_criterion and not isinstance(interval_criterion, BaseExpr): raise ValueError( f"Invalid criterion - expected Criterion or CriterionCombination, got {type(interval_criterion)}" ) @@ -691,35 +691,25 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableAnd": return cast(NonSimplifiableAnd, super().__new__(cls, *args, **kwargs)) -# todo: can we rename to more meaningful name? -class NoDataPreservingAnd(CommutativeOperator): +class NonSimplifiableOr(CommutativeOperator): """ - A And object represents a logical AND operation. + A NonSimplifiableOr object represents a logical Or operation that cannot be simplified. - See Task.handle_no_data_preserving_operator for the rules. Currently, the only difference between this operator - and the And operator is that during handling of this operator, the negative intervals are added explicitly. - """ - - def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingAnd": - """ - Create a new NoDataPreservingAnd object. 
- """ - return cast(NoDataPreservingAnd, super().__new__(cls, *args, **kwargs)) - - -class NoDataPreservingOr(CommutativeOperator): - """ - A Or object represents a logical OR operation. + Simplified here means that when this operator is used on a single argument, still this operator is returned + instead of the argument itself, as is the case with the sympy.Or operator. - See Task.handle_no_data_preserving_operator for the rules. Currently, the only difference between this operator - and the And operator is that during handling of this operator, the negative intervals are added explicitly. + The reason for this operator is that if there is a single population or intervention criterion, the And/Or + operators would simplify to the criterion itself. In that case, the _whole_ population or intervention expression of + the respective population/intervention pair, which should be written to the database with criterion_id = None, would + be lost (i.e. not written, because there is no graph execution node that would perform this. This operator prevents + that. """ - def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingOr": + def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableOr": """ - Create a new NoDataPreservingOr object. + Create a new NonSimplifiableOr object. 
""" - return cast(NoDataPreservingOr, super().__new__(cls, *args, **kwargs)) + return cast(NonSimplifiableOr, super().__new__(cls, *args, **kwargs)) class BinaryNonCommutativeOperator(BooleanFunction, SerializableABC): diff --git a/execution_engine/util/value/value.py b/execution_engine/util/value/value.py index 7b795716..fcda3dc6 100644 --- a/execution_engine/util/value/value.py +++ b/execution_engine/util/value/value.py @@ -18,7 +18,10 @@ "ValueNumeric", "ValueNumber", "ValueConcept", + "ValueScalar", "check_int", + "check_unit_none", + "check_value_min_max_none", ] from execution_engine.util import serializable diff --git a/scripts/execute.py b/scripts/execute.py index 5be3b871..c17ae152 100644 --- a/scripts/execute.py +++ b/scripts/execute.py @@ -92,13 +92,13 @@ recommendation_package_version = "v1.5.2" urls = [ - # "sepsis/recommendation/ventilation-plan-ards-tidal-volume", - # "covid19-inpatient-therapy/recommendation/no-therapeutic-anticoagulation", - # "covid19-inpatient-therapy/recommendation/ventilation-plan-ards-tidal-volume", + "covid19-inpatient-therapy/recommendation/no-therapeutic-anticoagulation", + "covid19-inpatient-therapy/recommendation/ventilation-plan-ards-tidal-volume", "covid19-inpatient-therapy/recommendation/covid19-ventilation-plan-peep", "covid19-inpatient-therapy/recommendation/prophylactic-anticoagulation", "covid19-inpatient-therapy/recommendation/therapeutic-anticoagulation", "covid19-inpatient-therapy/recommendation/covid19-abdominal-positioning-ards", + "sepsis/recommendation/ventilation-plan-ards-tidal-volume", ] start_datetime = pendulum.parse("2020-01-01 00:00:00+01:00") diff --git a/tests/execution_engine/omop/criterion/test_criterion.py b/tests/execution_engine/omop/criterion/test_criterion.py index ec095ef9..092e2e9c 100644 --- a/tests/execution_engine/omop/criterion/test_criterion.py +++ b/tests/execution_engine/omop/criterion/test_criterion.py @@ -321,7 +321,7 @@ def insert_expression( base_criterion: Criterion, 
observation_window: TimeRange, ): - # population_expr is assigned a NoDataPreservingAnd to ensure creation of negative intervals + # population_expr is assigned a NonSimplifiableAnd to ensure creation of negative intervals pi_pair = PopulationInterventionPairExpr( population_expr=logic.NonSimplifiableAnd(population), intervention_expr=intervention, diff --git a/tests/execution_engine/util/test_logic.py b/tests/execution_engine/util/test_logic.py index f823ea13..feef0886 100644 --- a/tests/execution_engine/util/test_logic.py +++ b/tests/execution_engine/util/test_logic.py @@ -13,8 +13,6 @@ LeftDependentToggle, MaxCount, MinCount, - NoDataPreservingAnd, - NoDataPreservingOr, NonSimplifiableAnd, Not, Or, @@ -178,28 +176,6 @@ def test_non_simplifiable_and_equality(self): assert non_simp_and1 == non_simp_and2 -class TestNoDataPreservingAnd: - def test_no_data_preserving_and_creation(self): - no_data_and = NoDataPreservingAnd(x, y) - assert isinstance(no_data_and, NoDataPreservingAnd) - assert no_data_and.args[0] == x - assert no_data_and.args[1] == y - - assert not no_data_and.is_Not - assert not no_data_and.is_Atom - - -class TestNoDataPreservingOr: - def test_no_data_preserving_or_creation(self): - no_data_or = NoDataPreservingOr(x, y) - assert isinstance(no_data_or, NoDataPreservingOr) - assert no_data_or.args[0] == x - assert no_data_or.args[1] == y - - assert not no_data_or.is_Not - assert not no_data_or.is_Atom - - class TestLeftDependentToggle: def test_left_dependent_toggle_creation(self): left_expr = x @@ -236,8 +212,6 @@ class TestSymbolMultiprocessing: ExactCount(1, 2, 3, threshold=2), AllOrNone(1, 2, 3), NonSimplifiableAnd(1, 2, 3), - NoDataPreservingAnd(1, 2, 3), - NoDataPreservingOr(1, 2, 3), LeftDependentToggle(left=1, right=2), TemporalMinCount( 1, diff --git a/tests/recommendation/test_recommendation_base_v2.py b/tests/recommendation/test_recommendation_base_v2.py index 5bc1a584..c87c92a7 100644 --- a/tests/recommendation/test_recommendation_base_v2.py 
+++ b/tests/recommendation/test_recommendation_base_v2.py @@ -553,6 +553,8 @@ def setup_testdata(self, db_session, run_slow_tests): for item in generate_combinations(c, self.invalid_combinations) ] + # combinations = [combinations[0]] + self.insert_criteria_into_database(db_session, combinations) df_criterion_entries = self.generate_criterion_entries(combinations) From a4e49f1bcdeb4e2a8588f8dbc0b10435f747df59 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 17 Mar 2025 20:57:48 +0100 Subject: [PATCH 10/16] feat: allow nested RecommendationPlans --- .../converter/characteristic/value.py | 7 +- execution_engine/converter/parser/base.py | 12 ++ .../converter/recommendation_factory.py | 78 ++++++---- execution_engine/fhir/recommendation.py | 139 ++++++++++++++++-- execution_engine/fhir/terminology.py | 21 ++- .../omop/cohort/recommendation.py | 4 +- .../fhir/test_recommendation.py | 21 +-- 7 files changed, 219 insertions(+), 63 deletions(-) diff --git a/execution_engine/converter/characteristic/value.py b/execution_engine/converter/characteristic/value.py index 5512eaa5..51fb91a0 100644 --- a/execution_engine/converter/characteristic/value.py +++ b/execution_engine/converter/characteristic/value.py @@ -22,7 +22,12 @@ def from_fhir( """Creates a new Characteristic instance from a FHIR EvidenceVariable.characteristic.""" assert cls.valid(characteristic), "Invalid characteristic definition" - type_omop_concept = parse_code(characteristic.definitionByTypeAndValue.type) + try: + type_omop_concept = parse_code(characteristic.definitionByTypeAndValue.type) + except ValueError: + type_omop_concept = parse_code( + characteristic.definitionByTypeAndValue.type, standard=False + ) value = parse_value( value_parent=characteristic.definitionByTypeAndValue, value_prefix="value" ) diff --git a/execution_engine/converter/parser/base.py b/execution_engine/converter/parser/base.py index 8f75aad1..0c55f1c0 100644 --- a/execution_engine/converter/parser/base.py +++ 
b/execution_engine/converter/parser/base.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod +from typing import Callable, Type from fhir.resources.evidencevariable import EvidenceVariable +from fhir.resources.plandefinition import PlanDefinition, PlanDefinitionAction from execution_engine import fhir from execution_engine.converter.criterion import CriterionConverterFactory @@ -41,3 +43,13 @@ def parse_actions( corresponding action selection behavior. """ raise NotImplementedError() + + @abstractmethod + def parse_action_combination_method( + self, action_parent: PlanDefinition | PlanDefinitionAction + ) -> Type[logic.BooleanFunction] | Callable: + """ + Parses the action combination method of a Recommendation (PlanDefinition) and returns the corresponding + combination method in form of a logical expression. + """ + raise NotImplementedError() diff --git a/execution_engine/converter/recommendation_factory.py b/execution_engine/converter/recommendation_factory.py index dbcb7a5f..64f57978 100644 --- a/execution_engine/converter/recommendation_factory.py +++ b/execution_engine/converter/recommendation_factory.py @@ -1,7 +1,12 @@ from execution_engine import fhir from execution_engine.builder import ExecutionEngineBuilder +from execution_engine.converter.parser.base import FhirRecommendationParserInterface from execution_engine.converter.parser.factory import FhirRecommendationParserFactory from execution_engine.fhir.client import FHIRClient +from execution_engine.fhir.recommendation import ( + RecommendationPlan, + RecommendationPlanCollection, +) from execution_engine.omop import cohort from execution_engine.omop.cohort.population_intervention_pair import ( PopulationInterventionPairExpr, @@ -60,34 +65,12 @@ def parse_recommendation_from_url( fhir_connector=fhir_client, ) - pi_pairs: list[PopulationInterventionPairExpr] = [] - - base_criterion = PatientsActiveDuringPeriod() - base_criterion.dict() - for rec_plan in rec.plans(): - - # parse population and create 
criteria - population_criteria = parser.parse_characteristics(rec_plan.population) - - # parse intervention and create criteria - actions = parser.parse_actions(rec_plan.actions, rec_plan) - - # population_expr is assigned a NonSimplifiableAnd to ensure creation of negative intervals - # todo: not sure we really need this - we can just always create negative intervals when store_results=True - # in the graph - pi_pair = PopulationInterventionPairExpr( - population_expr=population_criteria, - intervention_expr=actions, - name=rec_plan.name, - url=rec_plan.url, - base_criterion=base_criterion, - ) - - pi_pairs.append(pi_pair) + # Recursively build a single expression from the nested plans/collections: + expr = self._parse_collection(rec.plans(), parser) recommendation = cohort.Recommendation( - expr=logic.Or(*pi_pairs), - base_criterion=base_criterion, + expr=expr, + base_criterion=PatientsActiveDuringPeriod(), url=rec.url, name=rec.name, title=rec.title, @@ -97,3 +80,46 @@ def parse_recommendation_from_url( ) return recommendation + + def _parse_collection( + self, + plan_or_collection: RecommendationPlanCollection | RecommendationPlan, + parser: FhirRecommendationParserInterface, + ) -> logic.Expr: + """ + Recursively parse a single RecommendationPlan or a nested RecommendationPlanCollection + into a logic.Expr. If it's a collection, gather each sub-item's expression and combine + them using parser.parse_action_combination_method(...). + """ + if isinstance(plan_or_collection, RecommendationPlan): + # Parse a leaf plan: build a population + interventions expression (e.g. 
PopulationInterventionPairExpr) + population_criteria = parser.parse_characteristics( + plan_or_collection.population + ) + intervention_criteria = parser.parse_actions( + plan_or_collection.actions, plan_or_collection + ) + return PopulationInterventionPairExpr( + population_expr=population_criteria, + intervention_expr=intervention_criteria, + name=plan_or_collection.name, + url=plan_or_collection.url, + base_criterion=PatientsActiveDuringPeriod(), + ) + + elif isinstance(plan_or_collection, RecommendationPlanCollection): + # Recursively parse all sub-items + sub_exprs = [ + self._parse_collection(sub_item, parser) + for sub_item in plan_or_collection.plans + ] + + combination_op = parser.parse_action_combination_method( + plan_or_collection.fhir + ) + + # Combine all sub-expressions with the appropriate operator + return combination_op(*sub_exprs) + + else: + raise TypeError("Unknown plan_or_collection type.") diff --git a/execution_engine/fhir/recommendation.py b/execution_engine/fhir/recommendation.py index 98f6d650..a45e9da7 100644 --- a/execution_engine/fhir/recommendation.py +++ b/execution_engine/fhir/recommendation.py @@ -1,4 +1,5 @@ import logging +from typing import Tuple, Union, cast import fhir from fhir.resources.activitydefinition import ActivityDefinition @@ -23,7 +24,9 @@ def __init__( self.fhir = fhir_connector self._recommendation: PlanDefinition | None = None - self._recommendation_plans: list[RecommendationPlan] = [] + self._recommendation_plans: RecommendationPlanCollection = ( + RecommendationPlanCollection(fhir=None) + ) self.load(url) @@ -96,7 +99,7 @@ def description(self) -> str: return self._recommendation.description - def plans(self) -> list["RecommendationPlan"]: + def plans(self) -> "RecommendationPlanCollection": """ Return the recommendation plans of this recommendation (i.e. the individual, non-overlapping parts or steps of the recommendation). 
@@ -107,19 +110,86 @@ def load(self, url: str) -> None: """ Load the recommendation from the FHIR server. """ - self._recommendation = self.fetch_recommendation(url) + self._recommendation, self._recommendation_plans = self.fetch_recommendation( + url + ) + logging.info("Recommendation loaded.") + + def _build_population_intervention_pair_collection( + self, plan_def: PlanDefinition + ) -> "RecommendationPlanCollection": + """ + Build and return a RecommendationPlanCollection for the given PlanDefinition. + + If the PlanDefinition is of type "eca-rule", we create a single RecommendationPlan + and add it as the sole entry in the collection. If it is of type "workflow-definition", + we recursively process any sub-actions referencing other PlanDefinitions and add + their collections/plans as entries. + + :param plan_def: The loaded PlanDefinition resource. + :return: A RecommendationPlanCollection containing one or more RecommendationPlans (or nested collections). + """ + + cc = get_coding(plan_def.type, CS_PLAN_DEFINITION_TYPE) + + collection = RecommendationPlanCollection(fhir=plan_def) - for rec_action in self._recommendation.action: - rec_plan = RecommendationPlan( - rec_action.definitionCanonical, + if cc.code == "eca-rule": + plan = RecommendationPlan( + plan_def.url, package_version=self._package_version, fhir_connector=self.fhir, ) - self._recommendation_plans.append(rec_plan) + collection.add_plan(plan) + + elif cc.code == "workflow-definition": + # Recursively build sub-items for each referenced PlanDefinition + if plan_def.action: + for sub_action in plan_def.action: + if sub_action.definitionCanonical: + sub_plan_def = cast( + PlanDefinition, + self.fhir.fetch_resource( + "PlanDefinition", + sub_action.definitionCanonical, + self._package_version, + ), + ) + sub_item = self._build_population_intervention_pair_collection( + sub_plan_def + ) + collection.add_plan(sub_item) + else: + raise ValueError(f"Unknown PlanDefinition type: {cc.code}") + + return 
collection + + def fetch_recommendation( + self, canonical_url: str + ) -> Tuple[PlanDefinition, "RecommendationPlanCollection"]: + """ + Fetches the PlanDefinition, then checks if it's an 'eca-rule' or 'workflow-definition'. + If it's 'eca-rule', this is effectively a single RecommendationPlan; + if it's 'workflow-definition', it's a RecommendationPlanCollection. + + (Only changes shown below: how you might branch to a new or refactored _build_plan_or_collection.) + """ + plan_def = cast( + PlanDefinition, + self.fhir.fetch_resource( + "PlanDefinition", canonical_url, self._package_version + ), + ) + cc = get_coding(plan_def.type, CS_PLAN_DEFINITION_TYPE) - logging.info("Recommendation loaded.") + plan_collection = self._build_population_intervention_pair_collection(plan_def) + + if cc.code not in ("eca-rule", "workflow-definition"): + raise ValueError(f"Unknown recommendation type: {cc.code}") - def fetch_recommendation(self, canonical_url: str) -> PlanDefinition: + return plan_def, plan_collection + + def fetch_recommendation_DELETEME(self, canonical_url: str) -> PlanDefinition: """ Fetch the recommendation specified by the canonical URL from the FHIR server. @@ -286,8 +356,8 @@ def fetch_recommendation_plan(self, canonical_url: str) -> PlanDefinition: This method checks if the PlanDefinition resource referenced by the canonical URL is a recommendation (i.e. PlanDefinition.type = #workflow-definition). If - it is an recommendation-plan instead (i.e. PlanDefinition.type = #eca-rule), - it will fetch the PlanDefinition that is referenced by the extension[partOf]. + it is a recommendation-plan instead (i.e. PlanDefinition.type = #eca-rule), + it will recursively fetch the PlanDefinitions. :param canonical_url: Canonical URL of the recommendation :return: FHIR PlanDefinition @@ -364,3 +434,50 @@ def is_combination_definition( characteristics defined using EvidenceVariableCharacteristic.definitionByCombination. 
""" return characteristic.definitionByCombination is not None + + +class RecommendationPlanCollection: + """ + Represents a collection of RecommendationPlan objects or nested RecommendationPlanCollection objects. + + This class is used to aggregate multiple recommendation plans into a single structure, + enabling the recursive composition of recommendation workflows. + """ + + def __init__(self, fhir: PlanDefinition | None) -> None: + """ + Create a collection of RecommendationPlans or nested RecommendationPlanCollection objects. + + :param combination_method: A string indicating how these plans should be combined + (e.g. "AND", "OR", "SEQUENCE"). + """ + self._fhir = fhir + self._plans: list[RecommendationPlan | RecommendationPlanCollection] = [] + + @property + def fhir(self) -> PlanDefinition | None: + """ + Retrieve the FHIR PlanDefinition resource associated with this collection. + + :return: The FHIR PlanDefinition if set; otherwise, None. + """ + return self._fhir + + @property + def plans(self) -> list["RecommendationPlan | RecommendationPlanCollection"]: + """ + Retrieve the list of recommendation plans or nested recommendation plan collections. + + :return: A list containing RecommendationPlan objects or nested RecommendationPlanCollection objects. + """ + return self._plans + + def add_plan( + self, plan: Union[RecommendationPlan, "RecommendationPlanCollection"] + ) -> None: + """ + Add a recommendation plan or a nested recommendation plan collection to this collection. + + :param plan: The RecommendationPlan or RecommendationPlanCollection to be added. + """ + self._plans.append(plan) diff --git a/execution_engine/fhir/terminology.py b/execution_engine/fhir/terminology.py index 1f3d86a9..180d5394 100644 --- a/execution_engine/fhir/terminology.py +++ b/execution_engine/fhir/terminology.py @@ -1,3 +1,6 @@ +import logging +from json import JSONDecodeError + import requests # fixme use caching? 
@@ -121,7 +124,9 @@ def code_in_valueset( :param system: The coding system the code belongs to :return: True if the code is valid within the ValueSet, False otherwise """ - # Prepare the request data + + logging.debug(f"Validating code {code} from system {system} in ValueSet") + data = { "resourceType": "Parameters", "parameter": [ @@ -145,15 +150,23 @@ def code_in_valueset( timeout=30, ) - json_response = response.json() - if response.status_code == 200: - # Look for a parameter named 'result' and check if its value is True + + json_response = response.json() + for param in json_response.get("parameter", []): if param.get("name") == "result": return param.get("valueBoolean", False) return False else: + + try: + json_response = response.json() + except JSONDecodeError: + raise FHIRTerminologyServerException( + f"Error validating code: HTTP Status {response.status_code}" + ) + issues = [ issue["severity"] + ": " + issue["details"]["text"] for issue in json_response["issue"] diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index ee757143..695805b6 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -38,7 +38,7 @@ class Recommendation(SerializableDataClass): def __init__( self, - expr: logic.BooleanFunction, + expr: logic.Expr, base_criterion: Criterion, name: str, title: str, @@ -47,7 +47,7 @@ def __init__( description: str, package_version: str | None = None, ) -> None: - self._expr: logic.BooleanFunction = expr + self._expr: logic.Expr = expr self._base_criterion: Criterion = base_criterion self._name: str = name self._title: str = title diff --git a/tests/execution_engine/fhir/test_recommendation.py b/tests/execution_engine/fhir/test_recommendation.py index 0aa496ca..b69a0e8f 100644 --- a/tests/execution_engine/fhir/test_recommendation.py +++ b/tests/execution_engine/fhir/test_recommendation.py @@ -65,7 +65,7 @@ def 
mock_fetch_recommendation(self, test_class): {"status": "draft", "action": []} ) with patch.object( - test_class, "fetch_recommendation", return_value=plan_definition + test_class, "fetch_recommendation", return_value=(plan_definition, None) ) as _fixture: yield _fixture @@ -125,24 +125,7 @@ def test_recommendation_load_with_unknown_type(self, mock_fetch_resource_unknown # Test with pytest.raises( - ValueError, match=r"Unknown recommendation type: unknown-type" - ): - _ = Recommendation( - canonical_url, - package_version="latest", - fhir_connector=FHIRClient("http://fhir.example.com"), - ) - - def test_recommendation_fetch_with_no_partof_extension( - self, mock_fetch_resource_no_partOf - ): - # Setup - canonical_url = "http://test.com/PlanDefinition/123" - - # Test - with pytest.raises( - ValueError, - match=r"No partOf extension found in PlanDefinition, can't fetch recommendation.", + ValueError, match=r"Unknown PlanDefinition type: unknown-type" ): _ = Recommendation( canonical_url, From 1587c966bc31c6bccaf1f436775b7706f9b596c9 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 17 Mar 2025 22:56:25 +0100 Subject: [PATCH 11/16] fix: task serialization --- execution_engine/task/creator.py | 7 +++++++ execution_engine/task/runner.py | 7 +++++++ execution_engine/task/task.py | 2 +- execution_engine/util/logic.py | 13 +++++++++++++ execution_engine/util/serializable.py | 2 +- 5 files changed, 29 insertions(+), 2 deletions(-) diff --git a/execution_engine/task/creator.py b/execution_engine/task/creator.py index cd445827..87b23809 100644 --- a/execution_engine/task/creator.py +++ b/execution_engine/task/creator.py @@ -43,6 +43,13 @@ def node_to_task(expr: logic.Expr, attr: dict) -> Task: flattened_tasks = list(tasks.values()) + # we will make sure all tasks are depickled correctly + for i, node in enumerate(tasks): + if logic.Expr.from_dict(node.dict(include_id=True)) != node: + raise RuntimeError( + "Expected depickled node to be the same as initial node." 
+ ) + assert ( len(set(flattened_tasks)) == len(flattened_tasks) diff --git a/execution_engine/task/runner.py b/execution_engine/task/runner.py index 29510951..af0ea25d 100644 --- a/execution_engine/task/runner.py +++ b/execution_engine/task/runner.py @@ -310,6 +310,8 @@ def task_executor_worker() -> None: self.start_workers(task_executor_worker) + task_names = {task.name() for task in self.tasks} + try: while len(self.completed_tasks) < len(self.tasks): if self.stop_event.is_set(): @@ -335,6 +337,11 @@ def task_executor_worker() -> None: logging.info( f"Completed {len(self.completed_tasks)} of {len(self.tasks)} tasks" ) + if not all(task in task_names for task in self.completed_tasks): + raise TaskError( + "Completed tasks differ from actual tasks " + "- problem with pickling/unpickling in multiprocessing?" + ) except Exception as e: logging.error(f"An error occurred: {e}") diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 107e48ef..daf28fa1 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -636,7 +636,7 @@ def id(self) -> str: """ Returns the id of the Task object. """ - hash_value = hash((str(self.expr), json.dumps(self.bind_params))) + hash_value = hash((self.expr, json.dumps(self.bind_params))) # Determine the number of bytes needed. Python's hash returns a value based on the platform's pointer size. # It's 8 bytes for 64-bit systems and 4 bytes for 32-bit systems. diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index f28ff219..e30da2a6 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -620,6 +620,19 @@ def dict(self, include_id: bool = False) -> dict: Get a dictionary representation of the object. 
""" data = super().dict(include_id=include_id) + + if self.interval_criterion: + args, pop = self.args[:-1], self.args[-1] + + if pop != self.interval_criterion: + raise ValueError( + f"Expected last argument to be the interval_criterion, got {str(pop)}" + ) + + data["data"]["args"] = [ + arg_to_dict(arg, include_id=include_id) for arg in args + ] + data["data"].update( { "threshold": self.count_min, diff --git a/execution_engine/util/serializable.py b/execution_engine/util/serializable.py index 239c8030..9089aeff 100644 --- a/execution_engine/util/serializable.py +++ b/execution_engine/util/serializable.py @@ -282,7 +282,7 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return self._hash == other._hash + return hash(self) == hash(other) def rehash(self) -> None: """ From f0b29f6d4bdc7cbb4b0e32d2058882632f73353a Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 17 Mar 2025 23:50:35 +0100 Subject: [PATCH 12/16] fix: Count operator should return N/A intervals --- execution_engine/task/process/rectangle.py | 46 +++++++++++-------- .../test_recommendation_base_v2.py | 2 - 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index b3546180..91bef1e8 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -479,6 +479,7 @@ def filter_count_intervals( min_count: int | None, max_count: int | None, keep_no_data: bool = True, + keep_not_applicable: bool = True, ) -> PersonIntervals: """ Filters the intervals per dict key in the list by count. @@ -487,14 +488,18 @@ def filter_count_intervals( :param min_count: The minimum count of the intervals. :param max_count: The maximum count of the intervals. :param keep_no_data: Whether to keep NO_DATA intervals (irrespective of the count). 
+ :param keep_not_applicable: Whether to keep NOT_APPLICABLE intervals (irrespective of the count). :return: A dict with the unioned intervals. """ result: PersonIntervals = {} interval_filter = [] + if keep_no_data: interval_filter.append(IntervalType.NO_DATA) + if keep_not_applicable: + interval_filter.append(IntervalType.NOT_APPLICABLE) if min_count is None and max_count is None: raise ValueError("min_count and max_count cannot both be None") @@ -668,6 +673,7 @@ def create_time_intervals( # Prepare to collect intervals intervals = [] previous_end = None + def add_interval(interval_start, interval_end, interval_type): nonlocal previous_end effective_start = max(interval_start, start_datetime) @@ -681,11 +687,13 @@ def add_interval(interval_start, interval_end, interval_type): # touching intervals. if previous_end is not None: assert previous_end < effective_start - intervals.append(Interval( - lower=effective_start.timestamp(), - upper=effective_end.timestamp(), - type=interval_type, - )) + intervals.append( + Interval( + lower=effective_start.timestamp(), + upper=effective_end.timestamp(), + type=interval_type, + ) + ) previous_end = effective_end # Current date to process @@ -714,25 +722,22 @@ def add_interval(interval_start, interval_end, interval_type): # overlaps the main datetime range, otherwise fill the day # with an interval of type "not applicable". # TODO: what about intervals "before" the main datetime range? 
- if end_interval < start_datetime: # completely before datetime range + if end_interval < start_datetime: # completely before datetime range day_start = timezone.localize( - datetime.datetime.combine( - current_date, datetime.time(0, 0, 0) - )) + datetime.datetime.combine(current_date, datetime.time(0, 0, 0)) + ) day_end = timezone.localize( - datetime.datetime.combine( - current_date, datetime.time(23, 59, 59) - )) + datetime.datetime.combine(current_date, datetime.time(23, 59, 59)) + ) if (previous_end is not None) and day_start <= previous_end: start = previous_end + datetime.timedelta(seconds=1) else: start = day_start add_interval(start, day_end, IntervalType.NOT_APPLICABLE) - elif end_datetime < start_interval: # completely after datetime range + elif end_datetime < start_interval: # completely after datetime range day_start = timezone.localize( - datetime.datetime.combine( - current_date, datetime.time(0, 0, 0) - )) + datetime.datetime.combine(current_date, datetime.time(0, 0, 0)) + ) if (previous_end is not None) and day_start <= previous_end: start = previous_end + datetime.timedelta(seconds=1) else: @@ -799,10 +804,15 @@ def find_overlapping_personal_windows( return result + def find_rectangles_with_count(data: list[PersonIntervals]) -> PersonIntervals: if len(data) == 0: return {} else: keys = data[0].keys() - return {key: _impl.find_rectangles_with_count([ intervals[key] for intervals in data ]) - for key in keys} + return { + key: _impl.find_rectangles_with_count( + [intervals[key] for intervals in data] + ) + for key in keys + } diff --git a/tests/recommendation/test_recommendation_base_v2.py b/tests/recommendation/test_recommendation_base_v2.py index c87c92a7..5bc1a584 100644 --- a/tests/recommendation/test_recommendation_base_v2.py +++ b/tests/recommendation/test_recommendation_base_v2.py @@ -553,8 +553,6 @@ def setup_testdata(self, db_session, run_slow_tests): for item in generate_combinations(c, self.invalid_combinations) ] - # combinations = 
[combinations[0]] - self.insert_criteria_into_database(db_session, combinations) df_criterion_entries = self.generate_criterion_entries(combinations) From eeea0213ac0cca9c60a5130f6176372f19d84c30 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 21 Mar 2025 09:34:04 +0100 Subject: [PATCH 13/16] feat: implement relativeTime extension --- .pre-commit-config.yaml | 12 +- converter/parser/util.py | 84 +++++++ execution_engine/builder.py | 56 ++++- execution_engine/constants.py | 1 + execution_engine/converter/action/abstract.py | 11 +- .../converter/action/procedure.py | 16 +- execution_engine/converter/parser/base.py | 37 +++- execution_engine/converter/parser/factory.py | 5 + .../converter/parser/fhir_parser_v1.py | 208 ++++++++++++++++-- .../converter/relative_time/__init__.py | 0 .../converter/relative_time/abstract.py | 90 ++++++++ execution_engine/converter/temporal.py | 2 +- .../converter/temporal_indicator.py | 46 ++++ .../converter/time_from_event/abstract.py | 90 ++++---- execution_engine/fhir/util.py | 71 +++++- .../omop/criterion/device_exposure.py | 7 + .../converter/test_converter.py | 7 +- 17 files changed, 641 insertions(+), 102 deletions(-) create mode 100644 converter/parser/util.py create mode 100644 execution_engine/converter/relative_time/__init__.py create mode 100644 execution_engine/converter/relative_time/abstract.py create mode 100644 execution_engine/converter/temporal_indicator.py create mode 100644 execution_engine/omop/criterion/device_exposure.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b4596d3c..0e450f6a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: exclude: '^docs/' repos: - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 6.0.1 hooks: - id: isort name: isort (python) @@ -17,13 +17,13 @@ repos: - id: check-json - id: pretty-format-json args: ['--autofix', '--no-sort-keys'] -- repo: https://github.com/ambv/black - rev: 24.10.0 +- 
repo: https://github.com/psf/black + rev: 25.1.0 hooks: - id: black language_version: python3.11 - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.13.0' + rev: 'v1.15.0' hooks: - id: mypy name: mypy @@ -36,12 +36,12 @@ repos: exclude: tests/ args: [--select, "D101,D102,D103,D105,D106"] - repo: https://github.com/PyCQA/bandit - rev: '1.8.0' + rev: '1.8.3' hooks: - id: bandit args: [--skip, "B101,B303,B110,B311"] - repo: https://github.com/PyCQA/flake8 - rev: '7.1.1' + rev: '7.1.2' hooks: - id: flake8 - repo: https://github.com/myint/autoflake diff --git a/converter/parser/util.py b/converter/parser/util.py new file mode 100644 index 00000000..0dcac482 --- /dev/null +++ b/converter/parser/util.py @@ -0,0 +1,84 @@ +import logging +from typing import Callable, cast + +from execution_engine.omop.criterion.abstract import Criterion +from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence +from execution_engine.util import logic +from execution_engine.util.temporal_logic_util import Presence + + +def _wrap_criteria_with_factory( + expr: logic.BaseExpr, + factory: Callable[[logic.BaseExpr], logic.TemporalCount], +) -> logic.Expr: + """ + Recursively wraps all Criterion instances within a combination using the specified factory. + + :param expr: A single Criterion or an expression to be processed. + :param factory: A callable that takes a Criterion or expression and returns a TemporalCount. + :return: A new TemporalCount where all Criterion instances have been wrapped using the factory. + :raises ValueError: If an unexpected element type is encountered. 
+ """ + + from digipod.concepts import OMOP_SURGICAL_PROCEDURE + + new_expr: logic.Expr + + if isinstance(expr, Criterion): + new_expr = factory(expr) + elif isinstance(expr, logic.Expr): + + # Create a new combination of the same type with the same operator + args = [] + + # Loop through all elements + for element in expr.args: + if isinstance(element, logic.Expr): + # Recursively wrap nested combinations + args.append(_wrap_criteria_with_factory(element, factory)) + elif isinstance(element, Criterion): + # Wrap individual criteria with the factory + + if ( + isinstance(element, ProcedureOccurrence) + and element.concept.concept_id == OMOP_SURGICAL_PROCEDURE + and element.concept.vocabulary_id == "SNOMED" + ): + logging.warning( + "Removing Surgical Procedure Criterion in TimeFromEvent-SurgicalOperationDate" + ) + continue + + args.append(factory(element)) + + else: + raise ValueError(f"Unexpected element type: {type(element)}") + + new_expr = expr.__class__(*args) + else: + raise ValueError(f"Unexpected element type: {type(expr)}") + + return new_expr + + +def wrap_criteria_with_temporal_indicator( + expr: logic.BaseExpr, + interval_criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Wraps all Criterion instances in a combination with a TemporalCount (with interval_criterion). + + :param expr: A single Criterion or an expression to be wrapped. + :param interval_criterion: A Criterion or expression that defines the temporal interval. + :return: A new expression where all Criterion instances are wrapped with a TemporalCount (with interval_criterion). 
+ """ + temporal_combo_factory = lambda criterion: Presence( + criterion=criterion, interval_criterion=interval_criterion + ) + + new_combo = cast( + logic.TemporalMinCount, + _wrap_criteria_with_factory(expr, temporal_combo_factory), + ) + + return new_combo diff --git a/execution_engine/builder.py b/execution_engine/builder.py index 8251f35c..7999ced7 100644 --- a/execution_engine/builder.py +++ b/execution_engine/builder.py @@ -28,7 +28,8 @@ from execution_engine.converter.goal.ventilator_management import ( VentilatorManagementGoal, ) -from execution_engine.converter.time_from_event.abstract import TemporalIndicator +from execution_engine.converter.relative_time.abstract import RelativeTime +from execution_engine.converter.time_from_event.abstract import TimeFromEvent if TYPE_CHECKING: from execution_engine.execution_engine import ExecutionEngine @@ -42,7 +43,8 @@ class CriterionConverterType(TypedDict): characteristic: list[type[CriterionConverter]] action: list[type[CriterionConverter]] goal: list[type[CriterionConverter]] - time_from_event: list[type[TemporalIndicator]] + time_from_event: list[type[TimeFromEvent]] + relative_time: list[type[RelativeTime]] _default_converters: CriterionConverterType = { @@ -63,6 +65,7 @@ class CriterionConverterType(TypedDict): ], "goal": [LaboratoryValueGoal, VentilatorManagementGoal, AssessmentScaleGoal], "time_from_event": [], + "relative_time": [], } @@ -76,6 +79,7 @@ def default_execution_engine_builder() -> "ExecutionEngineBuilder": builder.set_action_converters(_default_converters["action"]) builder.set_goal_converters(_default_converters["goal"]) builder.set_time_from_event_converters(_default_converters["time_from_event"]) + builder.set_relative_time_converters(_default_converters["relative_time"]) return builder @@ -92,7 +96,8 @@ def __init__(self) -> None: self.characteristic_converters: list[type[CriterionConverter]] = [] self.action_converters: list[type[CriterionConverter]] = [] self.goal_converters: 
list[type[CriterionConverter]] = [] - self.time_from_event_converters: list[type[TemporalIndicator]] = [] + self.time_from_event_converters: list[type[TimeFromEvent]] = [] + self.relative_time_converters: list[type[RelativeTime]] = [] def set_characteristic_converters( self, converters: list[type[CriterionConverter]] @@ -128,7 +133,7 @@ def set_goal_converters( return self def set_time_from_event_converters( - self, converters: list[type[TemporalIndicator]] + self, converters: list[type[TimeFromEvent]] ) -> "ExecutionEngineBuilder": """ Sets (overwrites) the time from event converters for this builder. @@ -140,6 +145,19 @@ def set_time_from_event_converters( return self + def set_relative_time_converters( + self, converters: list[type[RelativeTime]] + ) -> "ExecutionEngineBuilder": + """ + Sets (overwrites) the time from event converters for this builder. + """ + self.relative_time_converters.clear() + + for converter_type in converters: + self.append_relative_time_converter(converter_type) + + return self + def append_characteristic_converter( self, converter_type: type[CriterionConverter] ) -> "ExecutionEngineBuilder": @@ -207,27 +225,49 @@ def prepend_goal_converter( return self def append_time_from_event_converter( - self, converter_type: type[TemporalIndicator] + self, converter_type: type[TimeFromEvent] ) -> "ExecutionEngineBuilder": """ Appends a single time_from_event converter at the end of the list. """ - if not issubclass(converter_type, TemporalIndicator): + if not issubclass(converter_type, TimeFromEvent): raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") self.time_from_event_converters.append(converter_type) return self def prepend_time_from_event_converter( - self, converter_type: type[TemporalIndicator] + self, converter_type: type[TimeFromEvent] ) -> "ExecutionEngineBuilder": """ Inserts a single time_from_event converter at the front of the list. 
""" - if not issubclass(converter_type, TemporalIndicator): + if not issubclass(converter_type, TimeFromEvent): raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") self.time_from_event_converters.insert(0, converter_type) return self + def append_relative_time_converter( + self, converter_type: type[RelativeTime] + ) -> "ExecutionEngineBuilder": + """ + Appends a single relative_time converter at the end of the list. + """ + if not issubclass(converter_type, RelativeTime): + raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") + self.relative_time_converters.append(converter_type) + return self + + def prepend_relative_time_converter( + self, converter_type: type[RelativeTime] + ) -> "ExecutionEngineBuilder": + """ + Inserts a single relative_time converter at the front of the list. + """ + if not issubclass(converter_type, RelativeTime): + raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") + self.relative_time_converters.insert(0, converter_type) + return self + def build(self, verbose: bool = False) -> "ExecutionEngine": """ Builds an ExecutionEngine with the specified converters. 
diff --git a/execution_engine/constants.py b/execution_engine/constants.py index 2deec5e9..eed6654c 100644 --- a/execution_engine/constants.py +++ b/execution_engine/constants.py @@ -26,6 +26,7 @@ EXT_DOSAGE_CONDITION = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/StructureDefinition/ext-dosage-condition" EXT_ACTION_COMBINATION_METHOD = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/StructureDefinition/ext-action-combination-method" EXT_CPG_PARTOF = "http://hl7.org/fhir/uv/cpg/StructureDefinition/cpg-partOf" +EXT_RELATIVE_TIME = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/StructureDefinition/relative-time" CS_ACTION_COMBINATION_METHOD = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/CodeSystem/cs-action-combination-method" diff --git a/execution_engine/converter/action/abstract.py b/execution_engine/converter/action/abstract.py index eaf4ae19..cd91c000 100644 --- a/execution_engine/converter/action/abstract.py +++ b/execution_engine/converter/action/abstract.py @@ -3,10 +3,11 @@ from fhir.resources.timing import Timing as FHIRTiming +from execution_engine import constants from execution_engine.converter.criterion import CriterionConverter, parse_code from execution_engine.converter.goal.abstract import Goal from execution_engine.fhir.recommendation import RecommendationPlan -from execution_engine.fhir.util import get_coding +from execution_engine.fhir.util import get_coding, get_extensions from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.vocabulary import AbstractVocabulary from execution_engine.util import AbstractPrivateMethods, logic @@ -136,6 +137,14 @@ def process_timing(cls, timing: FHIRTiming) -> Timing: if rep.offset is not None: raise NotImplementedError("offset has not been implemented") + relative_time = get_extensions(timing, constants.EXT_RELATIVE_TIME) + + if relative_time: + raise NotImplementedError( + 
"RelativeTime processing within AbstractAction not implemented - " + "should be performed in the parser" + ) + return Timing( count=count, duration=duration, frequency=frequency, interval=interval ) diff --git a/execution_engine/converter/action/procedure.py b/execution_engine/converter/action/procedure.py index bf392d2f..129fad22 100644 --- a/execution_engine/converter/action/procedure.py +++ b/execution_engine/converter/action/procedure.py @@ -3,6 +3,8 @@ from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion +from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence +from execution_engine.omop.criterion.device_exposure import DeviceExposure from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.criterion.observation import Observation from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence @@ -15,7 +17,7 @@ class ProcedureAction(AbstractAction): """ An ProcedureAction is an action that describes a procedure to be performed - This action tests whether the procedure has been performed by determining whether it is + This action tests whether the procedure has been performed by determining whether it is present in the respective OMOP CDM table. 
""" @@ -82,6 +84,18 @@ def _to_expression(self) -> logic.Symbol: value_required=False, timing=self._timing, ) + case "Device": + criterion = DeviceExposure( + concept=self._code, + value_required=False, + timing=self._timing, + ) + case "Condition": + criterion = ConditionOccurrence( + concept=self._code, + value_required=False, + timing=self._timing, + ) case _: raise ValueError( f"Concept domain {self._code.domain_id} is not supported for {self.__class__.__name__}]" diff --git a/execution_engine/converter/parser/base.py b/execution_engine/converter/parser/base.py index 0c55f1c0..06c9c302 100644 --- a/execution_engine/converter/parser/base.py +++ b/execution_engine/converter/parser/base.py @@ -1,7 +1,12 @@ from abc import ABC, abstractmethod from typing import Callable, Type -from fhir.resources.evidencevariable import EvidenceVariable +from fhir.resources.evidencevariable import ( + EvidenceVariable, + EvidenceVariableCharacteristic, + EvidenceVariableCharacteristicTimeFromEvent, +) +from fhir.resources.extension import Extension from fhir.resources.plandefinition import PlanDefinition, PlanDefinitionAction from execution_engine import fhir @@ -19,11 +24,13 @@ def __init__( action_converters: CriterionConverterFactory, goal_converters: CriterionConverterFactory, time_from_event_converters: TemporalIndicatorConverterFactory, + relative_time_converters: TemporalIndicatorConverterFactory, ): self.characteristics_converters = characteristic_converters self.action_converters = action_converters self.goal_converters = goal_converters self.time_from_event_converters = time_from_event_converters + self.relative_time_converters = relative_time_converters @abstractmethod def parse_characteristics(self, ev: EvidenceVariable) -> logic.BooleanFunction: @@ -53,3 +60,31 @@ def parse_action_combination_method( combination method in form of a logical expression. 
""" raise NotImplementedError() + + def parse_time_from_event( + self, + tfes: list[EvidenceVariableCharacteristicTimeFromEvent], + ) -> list[logic.BaseExpr]: + """ + Parses `timeFromEvent` elements and converts them into interval-based logical criteria. + """ + raise NotImplementedError() + + @abstractmethod + def parse_relative_time( + self, + relative_time: list[Extension], + ) -> list[logic.BaseExpr]: + """ + Parses `extension[relativeTime]` elements and converts them into interval-based logical criteria. + """ + raise NotImplementedError() + + def parse_timing( + self, characteristic: EvidenceVariableCharacteristic, expr: logic.BaseExpr + ) -> logic.BaseExpr: + """ + Applies temporal constraints to a given criterion expression based on `timeFromEvent` and + the relativeTime extension elements. + """ + raise NotImplementedError() diff --git a/execution_engine/converter/parser/factory.py b/execution_engine/converter/parser/factory.py index 7a0196d4..8e0ed3be 100644 --- a/execution_engine/converter/parser/factory.py +++ b/execution_engine/converter/parser/factory.py @@ -17,6 +17,7 @@ def __init__(self, builder: ExecutionEngineBuilder | None = None): self.action_converters = CriterionConverterFactory() self.goal_converters = CriterionConverterFactory() self.time_from_event_converters = TemporalIndicatorConverterFactory() + self.relative_time_converters = TemporalIndicatorConverterFactory() if builder is None: builder = default_execution_engine_builder() @@ -33,6 +34,9 @@ def __init__(self, builder: ExecutionEngineBuilder | None = None): for temporal_converter in builder.time_from_event_converters: self.time_from_event_converters.register(temporal_converter) + for relative_time_converter in builder.relative_time_converters: + self.relative_time_converters.register(relative_time_converter) + def get_parser(self, parser_version: int) -> FhirRecommendationParserInterface: """ Return the correct FhirParser based on the version string. 
@@ -52,4 +56,5 @@ def get_parser(self, parser_version: int) -> FhirRecommendationParserInterface: action_converters=self.action_converters, goal_converters=self.goal_converters, time_from_event_converters=self.time_from_event_converters, + relative_time_converters=self.relative_time_converters, ) diff --git a/execution_engine/converter/parser/fhir_parser_v1.py b/execution_engine/converter/parser/fhir_parser_v1.py index 43a145f8..ab6ab62a 100644 --- a/execution_engine/converter/parser/fhir_parser_v1.py +++ b/execution_engine/converter/parser/fhir_parser_v1.py @@ -5,13 +5,17 @@ EvidenceVariableCharacteristic, EvidenceVariableCharacteristicTimeFromEvent, ) +from fhir.resources.extension import Extension from fhir.resources.plandefinition import PlanDefinition, PlanDefinitionAction from execution_engine import fhir +from execution_engine.constants import EXT_RELATIVE_TIME from execution_engine.converter.action.abstract import AbstractAction from execution_engine.converter.characteristic.abstract import AbstractCharacteristic from execution_engine.converter.goal.abstract import Goal from execution_engine.converter.parser.base import FhirRecommendationParserInterface +from execution_engine.converter.parser.util import wrap_criteria_with_temporal_indicator +from execution_engine.fhir.util import get_extensions, pop_extensions from execution_engine.util import logic as logic @@ -59,34 +63,183 @@ class FhirRecommendationParserV1(FhirRecommendationParserInterface): def parse_time_from_event( self, tfes: list[EvidenceVariableCharacteristicTimeFromEvent], - combo: logic.BooleanFunction, - ) -> logic.BooleanFunction: + ) -> list[logic.BaseExpr]: """ - Parses the timeFromEvent elements and updates the logic.BooleanFunction. + Parses `timeFromEvent` elements and converts them into interval-based logical criteria. - Root element timeFromEvent specifies the time within which the population is valid. 
- Non-root elements act as filters for the criteria they are attached to, meaning only criteria within the timeFromEvent timeframe will be observed to determine if the patient is part of the population. + The `timeFromEvent` elements specify time intervals that constrain when the associated + criteria must occur. These intervals serve as temporal filters, ensuring that only criteria + occurring within the defined time window contribute to determining population membership. + + This function processes each `timeFromEvent` element by converting it into an interval-based + logical criterion. These criteria are later used to enforce temporal constraints on individual + criteria within a characteristic. + + Args: + tfes (list[EvidenceVariableCharacteristicTimeFromEvent]): + A list of `timeFromEvent` elements defining temporal constraints. + + Returns: + list[logic.BaseExpr]: + A list of interval-based logical expressions representing the extracted temporal constraints. + + Raises: + ValueError: + If any `timeFromEvent` element does not yield a valid `logic.BaseExpr`. + + Notes: + - Each `timeFromEvent` element is processed using a registered converter that transforms + it into an interval-based logical expression. + - These interval-based criteria are meant to be combined later (typically with AND) and + applied to individual criteria rather than an entire combination. + """ + interval_criteria = [] + + for tfe in tfes: + converter = self.time_from_event_converters.get(tfe) + + interval_criterion = converter.to_interval_criterion() + + if not isinstance(interval_criterion, logic.BaseExpr): + raise ValueError( + f"Expected instance of BaseExpr, got {type(interval_criterion)}" + ) + + interval_criteria.append(interval_criterion) + + return interval_criteria + + def parse_relative_time( + self, + relative_time: list[Extension], + ) -> list[logic.BaseExpr]: + """ + Parses `extension[relativeTime]` elements and converts them into interval-based logical criteria. 
+ + The `extension[relativeTime]` elements specify time intervals that constrain when the associated + criteria must occur. These intervals serve as temporal filters, ensuring that only criteria + occurring within the defined time window contribute to determining population membership. + + This function processes each `extension[relativeTime]` element by converting it into an interval-based + logical criterion. These criteria are later used to enforce temporal constraints on individual + criteria within a characteristic. Args: - tfes (list[EvidenceVariableCharacteristicTimeFromEvent]): List of timeFromEvent elements. - combo (logic.BooleanFunction): The criterion combination to update. + relative_time (list[Extension]): + A list of `extension[relativeTime]` elements defining temporal constraints. Returns: - TemporalIndicatorCombination: Updated criterion combination. + list[logic.BaseExpr]: + A list of interval-based logical expressions representing the extracted temporal constraints. + + Raises: + ValueError: + If any `extension[relativeTime]` element does not yield a valid `logic.BaseExpr`. + + Notes: + - Each `extension[relativeTime]` element is processed using a registered converter that transforms + it into an interval-based logical expression. + - These interval-based criteria are meant to be combined later (typically with AND) and + applied to individual criteria rather than an entire combination. 
""" - if len(tfes) != 1: - raise ValueError(f"Expected exactly 1 timeFromEvent, got {len(tfes)}") + interval_criteria = [] + + for ext in relative_time: + + converter = self.relative_time_converters.get(ext) - tfe = tfes[0] + interval_criterion = converter.to_interval_criterion() - converter = self.time_from_event_converters.get(tfe) + if not isinstance(interval_criterion, logic.BaseExpr): + raise ValueError( + f"Expected instance of BaseExpr, got {type(interval_criterion)}" + ) + + interval_criteria.append(interval_criterion) - new_combo = converter.to_temporal_combination(combo) + return interval_criteria + + def parse_timing( + self, characteristic: EvidenceVariableCharacteristic, expr: logic.BaseExpr + ) -> logic.BaseExpr: + """ + Applies temporal constraints to a given criterion expression based on `timeFromEvent` and + the relativeTime extension elements. - if not isinstance(new_combo, logic.BooleanFunction): - raise ValueError(f"Expected BooleanFunction, got {type(new_combo)}") + This function aggregates all applicable temporal constraints associated with a characteristic, + including `timeFromEvent` elements and relative timing extensions. These constraints define + the time intervals within which the given criterion must be evaluated. - return new_combo + The extracted interval criteria are AND-combined and used to wrap **each individual criterion** + rather than the entire criterion combination. This prevents unintended constraints that could + arise if multiple criteria were required to occur simultaneously. + + Args: + characteristic (EvidenceVariableCharacteristic): + The characteristic whose timing constraints should be applied. + expr (logic.BaseExpr): + The logical expression representing the criterion to be updated. + + Returns: + logic.BaseExpr: + The criterion expression with the applied temporal constraints. + + Notes: + - If `timeFromEvent` elements are present, they are processed using `parse_time_from_event`. 
+ - If a relative timing extension is present, it is processed separately. + - All extracted interval criteria are AND-combined and used to wrap each individual + criterion within the logical expression to ensure correct temporal evaluation. + """ + interval_criteria: list[logic.BaseExpr] = [] + + if characteristic.timeFromEvent is not None: + interval_criteria.extend( + self.parse_time_from_event(characteristic.timeFromEvent) + ) + + relative_time = get_extensions(characteristic, EXT_RELATIVE_TIME) + + if relative_time: + interval_criteria.extend(self.parse_relative_time(relative_time)) + + if interval_criteria: + interval_criterion_combo = logic.And(*interval_criteria) + + expr = wrap_criteria_with_temporal_indicator(expr, interval_criterion_combo) + + return expr + + def process_action_relative_time( + self, action_def: fhir.RecommendationPlan.Action + ) -> list[logic.BaseExpr]: + """ + Processes the relativeTime extension in the given Action and removes it from + the ActivityDefinition's timing field so that subsequent calls (e.g. process_timing) + do not re-process it. + + Args: + action_def (fhir.RecommendationPlan.Action): + The FHIR action definition that potentially contains relativeTime extensions + in its ActivityDefinition. + + Returns: + list[logic.BaseExpr]: A list of expressions derived from the relativeTime extensions, + or an empty list if none were found. + """ + if ( + action_def.activity_definition_fhir is None + or action_def.activity_definition_fhir.timingTiming is None + ): + return [] + + relative_time = pop_extensions( + action_def.activity_definition_fhir.timingTiming, EXT_RELATIVE_TIME + ) + + if not relative_time: + return [] + + return self.parse_relative_time(relative_time) def parse_characteristics(self, ev: EvidenceVariable) -> logic.BooleanFunction: """ @@ -118,7 +271,7 @@ def build_criterion( Union[Symbol, BooleanFunction]: The built criterion or criterion combination. 
""" - combo: logic.BooleanFunction + combo: logic.BaseExpr # If this characteristic is itself a combination if characteristic.definitionByCombination is not None: @@ -139,10 +292,7 @@ def build_criterion( combo, ) - if characteristic.timeFromEvent is not None: - combo = self.parse_time_from_event( - characteristic.timeFromEvent, combo - ) + combo = self.parse_timing(characteristic, combo) return combo @@ -196,12 +346,17 @@ def action_to_combination( for action_def in actions_def: # check if is combination of actions if action_def.nested_actions: - # todo: make sure code is used correctly and we don't have a definition? action_combination = action_to_combination( action_def.nested_actions, action_def ) actions.append(action_combination) else: + + # first process a possible relativeTime extension in timingTiming + # if there is one, a corresponding logic.Presence (=logic.TemporalMinCount(*args, threshold=1)) + # expression is constructed and later wrapped around the actual criterion + interval_criteria = self.process_action_relative_time(action_def) + action_conv = self.action_converters.get(action_def) action_conv = cast( AbstractAction, action_conv @@ -212,7 +367,14 @@ def action_to_combination( goal = cast(Goal, goal) action_conv.goals.append(goal) - actions.append(action_conv.to_expression()) + expr = action_conv.to_expression() + + if interval_criteria: + expr = wrap_criteria_with_temporal_indicator( + expr, logic.And(*interval_criteria) + ) + + actions.append(expr) action_combination_expr = self.parse_action_combination_method( parent.fhir() diff --git a/execution_engine/converter/relative_time/__init__.py b/execution_engine/converter/relative_time/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/execution_engine/converter/relative_time/abstract.py b/execution_engine/converter/relative_time/abstract.py new file mode 100644 index 00000000..0a6c96d4 --- /dev/null +++ b/execution_engine/converter/relative_time/abstract.py @@ -0,0 +1,90 @@ 
+from abc import abstractmethod +from typing import cast + +from fhir.resources.extension import Extension + +from execution_engine.converter.criterion import parse_value +from execution_engine.converter.temporal_indicator import TemporalIndicator +from execution_engine.fhir.util import get_coding, get_extension +from execution_engine.omop.vocabulary import AbstractVocabulary +from execution_engine.util import logic +from execution_engine.util.value import ValueNumeric + + +class RelativeTime(TemporalIndicator): + """ + extension[relativeTime] in the context of CPG-on-EBM-on-FHIR. + """ + + _event_vocabulary: type[AbstractVocabulary] + _event_code: str + _value: ValueNumeric | None + + def __init__( + self, + value: ValueNumeric | None, + ) -> None: + """ + Initialize the drug administration action. + """ + super().__init__() + self._value = value + + @classmethod + def valid(cls, fhir: Extension) -> bool: + """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" + + assert isinstance( + fhir, Extension + ), f"Expected Extension type, got {fhir.__class__.__name__}" + + context = get_extension(fhir, "contextCode") + + if not context: + raise ValueError("Required relativeTime:contextCode not found") + + cc = get_coding(context.valueCodeableConcept) + + return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code + + @classmethod + def from_fhir(cls, fhir: Extension) -> "TemporalIndicator": + """ + Creates a new TemporalIndicator from a FHIR PlanDefinition. 
+ """ + assert isinstance( + fhir, Extension + ), f"Expected Extension type, got {fhir.__class__.__name__}" + + value = None + + offset = get_extension(fhir, "offset") + + if offset: + value = cast(ValueNumeric, parse_value(offset.valueDuration)) + + return cls(value) + + @abstractmethod + def to_interval_criterion(self) -> logic.BaseExpr: + """ + Returns the criterion that returns the intervals during the enclosed criterion/combination is evaluated. + + This criterion comes from a "timeFromEvent" field in EvidenceVariable.characteristic and therefore + specifies some time window (a.k.a. interval) during which the actual characteristic is supposed to happen. + For example, the characteristic could be some kind of measurement to be performed, and the timeFromEvent + could be "post surgical". + + The interval criterion returned by the TimeFromEvent class is later AND-combined (if there are more than one + timeFromEvent requirements defined - they are always supposed to be simultaneously fulfilled, i.e. AND-combined) + - and a logic.Presence TemporalIndicator is instantiated, with the AND-combination of interval criteria, and + each single criterion contained in the characteristic to which this timeFromEvent belongs is wrapped with that + logic.Presence( *args, interval_criterion=interval_criterion). + + + Note that it is not the (potential) combination of single criteria in characteristic (if characteristic + contains a definitionByCombination element) that is wrapped with logic.Presence, but the single criteria, + because e.g. AND-combining single measurements to be performed would likely result in no positive intervals, + because measurements are not performed simultaneously. 
+ """ + raise NotImplementedError("must be implemented by class") diff --git a/execution_engine/converter/temporal.py b/execution_engine/converter/temporal.py index f06741a5..3a654cd0 100644 --- a/execution_engine/converter/temporal.py +++ b/execution_engine/converter/temporal.py @@ -3,7 +3,7 @@ from fhir.resources.element import Element -from execution_engine.converter.time_from_event.abstract import TemporalIndicator +from execution_engine.converter.temporal_indicator import TemporalIndicator class TemporalIndicatorConverterFactory: diff --git a/execution_engine/converter/temporal_indicator.py b/execution_engine/converter/temporal_indicator.py new file mode 100644 index 00000000..9b747786 --- /dev/null +++ b/execution_engine/converter/temporal_indicator.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod + +from fhir.resources.element import Element + +from execution_engine.util import logic + + +class TemporalIndicator(ABC): + """ + EvidenceVariable.characteristic.timeFromEvent in the context of CPG-on-EBM-on-FHIR. + """ + + @classmethod + @abstractmethod + def valid(cls, fhir: Element) -> bool: + """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" + raise NotImplementedError("must be implemented by class") + + @classmethod + @abstractmethod + def from_fhir(cls, fhir: Element) -> "TemporalIndicator": + """ + Creates a new TemporalIndicator from a FHIR PlanDefinition. + """ + raise NotImplementedError("must be implemented by class") + + @abstractmethod + def to_interval_criterion(self) -> logic.BaseExpr: + """ + Returns the criterion that returns the intervals during the enclosed criterion/combination is evaluated. + + This criterion (in FHIR) specifies some time window (a.k.a. interval) during which the actual criterion or + combination of criteria are supposed to happen. 
For example, the criterion could be some kind of measurement + to be performed, and the temporal indicator could be "post surgical". + + The interval criterion returned by this class is later AND-combined (if there are more than one + temporal requirements defined - they are always supposed to be simultaneously fulfilled, i.e. AND-combined) + - and a logic.Presence TemporalIndicator is instantiated, with the AND-combination of interval criteria, and + each single criterion contained in the characteristic to which this timeFromEvent belongs is wrapped with that + logic.Presence( *args, interval_criterion=interval_criterion). + + Note that it is not the (potential) combination of single criteria that is wrapped with logic.Presence, but the + single criteria, because e.g. AND-combining single measurements to be performed would likely result in no + positive intervals, because measurements are not performed simultaneously. + """ + raise NotImplementedError("must be implemented by class") diff --git a/execution_engine/converter/time_from_event/abstract.py b/execution_engine/converter/time_from_event/abstract.py index fe0bfa15..fcba5392 100644 --- a/execution_engine/converter/time_from_event/abstract.py +++ b/execution_engine/converter/time_from_event/abstract.py @@ -1,10 +1,10 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Callable, cast -from fhir.resources.element import Element from fhir.resources.evidencevariable import EvidenceVariableCharacteristicTimeFromEvent from execution_engine.converter.criterion import parse_value +from execution_engine.converter.temporal_indicator import TemporalIndicator from execution_engine.fhir.util import get_coding from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.vocabulary import AbstractVocabulary @@ -35,33 +35,6 @@ def _wrap_criteria_with_factory( return combo.__class__(children) -class TemporalIndicator(ABC): - """ - 
EvidenceVariable.characteristic.timeFromEvent in the context of CPG-on-EBM-on-FHIR. - """ - - @classmethod - @abstractmethod - def from_fhir(cls, fhir: Element) -> "TemporalIndicator": - """ - Creates a new TemporalIndicator from a FHIR PlanDefinition. - """ - raise NotImplementedError("must be implemented by class") - - @classmethod - @abstractmethod - def valid(cls, fhir: Element) -> bool: - """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" - raise NotImplementedError("must be implemented by class") - - @abstractmethod - def to_temporal_combination(self, expr: logic.BaseExpr) -> logic.Expr: - """ - Wraps Criterion/CriterionCombination with a TemporalIndicatorCombination - """ - raise NotImplementedError("must be implemented by class") - - class TimeFromEvent(TemporalIndicator): """ EvidenceVariable.characteristic.timeFromEvent in the context of CPG-on-EBM-on-FHIR. @@ -82,7 +55,21 @@ def __init__( self._value = value @classmethod - def from_fhir(cls, fhir: Element) -> "TemporalIndicator": + def valid(cls, fhir: EvidenceVariableCharacteristicTimeFromEvent) -> bool: + """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" + + assert isinstance( + fhir, EvidenceVariableCharacteristicTimeFromEvent + ), f"Expected timeFromEvent type, got {fhir.__class__.__name__}" + + cc = get_coding(fhir.eventCodeableConcept) + + return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code + + @classmethod + def from_fhir( + cls, fhir: EvidenceVariableCharacteristicTimeFromEvent + ) -> "TemporalIndicator": """ Creates a new TemporalIndicator from a FHIR PlanDefinition. 
""" @@ -90,40 +77,41 @@ def from_fhir(cls, fhir: Element) -> "TemporalIndicator": fhir, EvidenceVariableCharacteristicTimeFromEvent ), f"Expected timeFromEvent type, got {fhir.__class__.__name__}" - tfe: EvidenceVariableCharacteristicTimeFromEvent = fhir - value = None - if tfe.range is not None and tfe.quantity is not None: + if fhir.range is not None and fhir.quantity is not None: raise ValueError( "Must specify either Range or Quantity in characteristic.timeFromEvent, not both." ) - if tfe.range is not None: - value = cast(ValueNumeric, parse_value(tfe.range)) + if fhir.range is not None: + value = cast(ValueNumeric, parse_value(fhir.range)) - if tfe.quantity is not None: - value = cast(ValueNumeric, parse_value(tfe.quantity)) + if fhir.quantity is not None: + value = cast(ValueNumeric, parse_value(fhir.quantity)) return cls(value) - @classmethod - def valid(cls, fhir: Element) -> bool: - """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" - - assert isinstance( - fhir, EvidenceVariableCharacteristicTimeFromEvent - ), f"Expected timeFromEvent type, got {fhir.__class__.__name__}" + @abstractmethod + def to_interval_criterion(self) -> logic.BaseExpr: + """ + Returns the criterion that returns the intervals during the enclosed criterion/combination is evaluated. - tfe: EvidenceVariableCharacteristicTimeFromEvent = fhir + This criterion comes from a "timeFromEvent" field in EvidenceVariable.characteristic and therefore + specifies some time window (a.k.a. interval) during which the actual characteristic is supposed to happen. + For example, the characteristic could be some kind of measurement to be performed, and the timeFromEvent + could be "post surgical". - cc = get_coding(tfe.eventCodeableConcept) + The interval criterion returned by the TimeFromEvent class is later AND-combined (if there are more than one + timeFromEvent requirements defined - they are always supposed to be simultaneously fulfilled, i.e. 
AND-combined) + - and a logic.Presence TemporalIndicator is instantiated, with the AND-combination of interval criteria, and + each single criterion contained in the characteristic to which this timeFromEvent belongs is wrapped with that + logic.Presence( *args, interval_criterion=interval_criterion). - return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code - @abstractmethod - def to_temporal_combination(self, expr: logic.BaseExpr) -> logic.Expr: - """ - Wraps expression with a TemporalIndicatorCombination + Note that it is not the (potential) combination of single criteria in characteristic (if characteristic + contains a definitionByCombination element) that is wrapped with logic.Presence, but the single criteria, + because e.g. AND-combining single measurements to be performed would likely result in no positive intervals, + because measurements are not performed simultaneously. """ raise NotImplementedError("must be implemented by class") diff --git a/execution_engine/fhir/util.py b/execution_engine/fhir/util.py index d1a32981..0c9699cb 100644 --- a/execution_engine/fhir/util.py +++ b/execution_engine/fhir/util.py @@ -1,6 +1,7 @@ from fhir.resources.codeableconcept import CodeableConcept from fhir.resources.coding import Coding from fhir.resources.element import Element +from fhir.resources.extension import Extension def get_coding(cc: CodeableConcept, system_uri: str | None = None) -> Coding: @@ -20,15 +21,71 @@ def get_coding(cc: CodeableConcept, system_uri: str | None = None) -> Coding: return coding[0] -def get_extension(base: Element, extension_url: str) -> Element | None: +def get_extensions(base: Element, extension_url: str) -> list[Extension]: """ - Get the extension with the given URL from the given element. + Retrieves all extensions with the given URL from the specified element. + + This function returns a list of extensions that match the given URL, + allowing for multiple occurrences. 
+ + Args: + base (Element): The element containing extensions. + extension_url (str): The URL of the extensions to retrieve. + + Returns: + list[Extension]: A list of matching extensions (empty if none are found). """ if base.extension is None: - return None + return [] + + return [ext for ext in base.extension if ext.url == extension_url] + + +def pop_extensions(base: Element, extension_url: str) -> list[Extension]: + """ + Retrieves and removes all extensions with the given URL from the specified element, + returning them as a list in the order they appeared. + + Args: + base (Element): The element containing extensions. + extension_url (str): The URL of the extensions to retrieve and remove. + + Returns: + list[Extension]: A list of matching extensions (empty if none are found). + """ + if not base.extension: + return [] + + matches = [ext for ext in base.extension if ext.url == extension_url] + keepers = [ext for ext in base.extension if ext.url != extension_url] + + base.extension = keepers - for ext in base.extension: - if ext.url == extension_url: - return ext + return matches + + +def get_extension(base: Element, extension_url: str) -> Extension | None: + """ + Retrieves a single extension with the given URL from the specified element. + + If multiple extensions with the same URL exist, an error is raised to ensure + that only unique extensions are retrieved. + + Args: + base (Element): The element containing extensions. + extension_url (str): The URL of the extension to retrieve. + + Returns: + Extension | None: The matching extension if found, otherwise None. + + Raises: + ValueError: If multiple extensions with the same URL exist. + """ + matching_extensions = get_extensions(base, extension_url) + + if len(matching_extensions) > 1: + raise ValueError( + f"Multiple extensions found with URL '{extension_url}', but only one was expected." 
+ ) - return None + return matching_extensions[0] if matching_extensions else None diff --git a/execution_engine/omop/criterion/device_exposure.py b/execution_engine/omop/criterion/device_exposure.py new file mode 100644 index 00000000..7bc55222 --- /dev/null +++ b/execution_engine/omop/criterion/device_exposure.py @@ -0,0 +1,7 @@ +__all__ = ["DeviceExposure"] + +from execution_engine.omop.criterion.continuous import ContinuousCriterion + + +class DeviceExposure(ContinuousCriterion): + """A device_exposure criterion in a recommendation.""" diff --git a/tests/execution_engine/converter/test_converter.py b/tests/execution_engine/converter/test_converter.py index 88820367..fc2ed131 100644 --- a/tests/execution_engine/converter/test_converter.py +++ b/tests/execution_engine/converter/test_converter.py @@ -146,14 +146,15 @@ class TestCriterionConverter: # Continuing the test classes for CriterionConverter and CriterionConverterFactory class MockCriterionConverter(CriterionConverter): - @classmethod - def from_fhir(cls, fhir_definition: Element) -> "CriterionConverter": - return cls(exclude=False) @classmethod def valid(cls, fhir_definition: Element) -> bool: return fhir_definition.id == "valid" + @classmethod + def from_fhir(cls, fhir_definition: Element) -> "CriterionConverter": + return cls(exclude=False) + def to_positive_expression(self) -> logic.Symbol: raise NotImplementedError() From 6d3bdc6dbc3a12ca927fec95abb5dc8acce7ba15 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 21 Mar 2025 10:01:50 +0100 Subject: [PATCH 14/16] feat: implement relativeTime extension --- .../converter}/parser/util.py | 11 +++++++++-- execution_engine/omop/cohort/graph_builder.py | 14 +++++++++++++- execution_engine/omop/vocabulary.py | 1 + execution_engine/util/logic.py | 6 ++++++ .../combination/test_temporal_combination.py | 3 +-- 5 files changed, 30 insertions(+), 5 deletions(-) rename {converter => execution_engine/converter}/parser/util.py (89%) diff --git 
a/converter/parser/util.py b/execution_engine/converter/parser/util.py similarity index 89% rename from converter/parser/util.py rename to execution_engine/converter/parser/util.py index 0dcac482..c9fa4113 100644 --- a/converter/parser/util.py +++ b/execution_engine/converter/parser/util.py @@ -3,6 +3,7 @@ from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence +from execution_engine.omop.vocabulary import OMOP_SURGICAL_PROCEDURE from execution_engine.util import logic from execution_engine.util.temporal_logic_util import Presence @@ -20,8 +21,6 @@ def _wrap_criteria_with_factory( :raises ValueError: If an unexpected element type is encountered. """ - from digipod.concepts import OMOP_SURGICAL_PROCEDURE - new_expr: logic.Expr if isinstance(expr, Criterion): @@ -31,8 +30,16 @@ def _wrap_criteria_with_factory( # Create a new combination of the same type with the same operator args = [] + interval_criterion = ( + expr.interval_criterion if hasattr(expr, "interval_criterion") else None + ) + # Loop through all elements for element in expr.args: + + if element == interval_criterion: + # interval_criterion must not be wrapped! 
+ args.append(element) if isinstance(element, logic.Expr): # Recursively wrap nested combinations args.append(_wrap_criteria_with_factory(element, factory)) diff --git a/execution_engine/omop/cohort/graph_builder.py b/execution_engine/omop/cohort/graph_builder.py index 10926d46..5f305734 100644 --- a/execution_engine/omop/cohort/graph_builder.py +++ b/execution_engine/omop/cohort/graph_builder.py @@ -37,7 +37,16 @@ def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: if isinstance(node, logic.Symbol): return logic.LeftDependentToggle(left=filter_, right=node) elif isinstance(node, logic.Expr): - converted_args = [cls.filter_symbols(a, filter_) for a in node.args] + if hasattr(node, "interval_criterion") and node.interval_criterion: + # we must not wrap the interval_criterion + interval_criterion = node.interval_criterion + converted_args = [ + cls.filter_symbols(a, filter_) + for a in node.args + if not a == interval_criterion + ] + [interval_criterion] + else: + converted_args = [cls.filter_symbols(a, filter_) for a in node.args] if any(a is not b for a, b in zip(node.args, converted_args)): node.update_args(*converted_args) @@ -64,6 +73,9 @@ def filter_intervention_criteria_by_population(cls, expr: logic.Expr) -> logic.E def traverse( expr: logic.Expr, ) -> None: + if expr is None: + pass + if isinstance(expr, PopulationInterventionPairExpr): p, i = expr.left, expr.right diff --git a/execution_engine/omop/vocabulary.py b/execution_engine/omop/vocabulary.py index 5d0c17c8..baa7d1b3 100644 --- a/execution_engine/omop/vocabulary.py +++ b/execution_engine/omop/vocabulary.py @@ -7,6 +7,7 @@ OMOP_INTENSIVE_CARE = 32037 OMOP_INPATIENT_VISIT = 9201 OMOP_OUTPATIENT_VISIT = 9202 +OMOP_SURGICAL_PROCEDURE = 4301351 # OMOP surgical procedure class VocabularyNotFoundError(Exception): diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index e30da2a6..ee3dcc9f 100644 --- a/execution_engine/util/logic.py +++ 
b/execution_engine/util/logic.py @@ -622,6 +622,12 @@ def dict(self, include_id: bool = False) -> dict: data = super().dict(include_id=include_id) if self.interval_criterion: + + if len(self.args) <= 1: + raise ValueError( + "More than one argument required if interval_criterion is set" + ) + args, pop = self.args[:-1], self.args[-1] if pop != self.interval_criterion: diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index b5c7a4f9..4bd4d485 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -13,6 +13,7 @@ from execution_engine.omop.criterion.noop import NoopCriterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.omop.db.omop import tables +from execution_engine.omop.vocabulary import OMOP_SURGICAL_PROCEDURE from execution_engine.task.process import get_processing_module from execution_engine.util import logic, temporal_logic_util from execution_engine.util.enum import TimeIntervalType @@ -43,8 +44,6 @@ process = get_processing_module() -OMOP_SURGICAL_PROCEDURE = 4301351 # OMOP surgical procedure - def intervals_to_df(result, by=None): df = intervals_to_df_orig(result, by, process.normalize_interval) From b10c40542a1d90e458b2fea96f641d6f859d4b08 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 21 Mar 2025 12:00:41 +0100 Subject: [PATCH 15/16] perf: reduce calls to dict, hash --- execution_engine/omop/cohort/graph_builder.py | 11 +- .../cohort/population_intervention_pair.py | 23 +- execution_engine/util/logic.py | 293 ++++++++++++++---- execution_engine/util/serializable.py | 12 +- 4 files changed, 270 insertions(+), 69 deletions(-) diff --git a/execution_engine/omop/cohort/graph_builder.py b/execution_engine/omop/cohort/graph_builder.py index 
5f305734..0d639a06 100644 --- a/execution_engine/omop/cohort/graph_builder.py +++ b/execution_engine/omop/cohort/graph_builder.py @@ -1,5 +1,7 @@ import copy +import networkx as nx + import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.execution_graph import ExecutionGraph @@ -73,9 +75,6 @@ def filter_intervention_criteria_by_population(cls, expr: logic.Expr) -> logic.E def traverse( expr: logic.Expr, ) -> None: - if expr is None: - pass - if isinstance(expr, PopulationInterventionPairExpr): p, i = expr.left, expr.right @@ -96,6 +95,9 @@ def traverse( traverse(expr) + # we need to rehash as the structure has been changed due to insertion of additional nodes + expr.rehash(recursive=True) + return expr @classmethod @@ -138,4 +140,7 @@ def build(cls, expr: logic.Expr, base_criterion: Criterion) -> ExecutionGraph: ) graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes) + if not nx.is_directed_acyclic_graph(graph): + raise ValueError("Graph is not acyclic") + return graph diff --git a/execution_engine/omop/cohort/population_intervention_pair.py b/execution_engine/omop/cohort/population_intervention_pair.py index ad2095bb..457ec6f5 100644 --- a/execution_engine/omop/cohort/population_intervention_pair.py +++ b/execution_engine/omop/cohort/population_intervention_pair.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Callable, cast import execution_engine.util.logic as logic from execution_engine.omop.criterion.abstract import Criterion @@ -70,6 +70,27 @@ def base_criterion(self) -> Criterion: """ return self._base_criterion + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, arguments, and category. 
+ """ + return ( + self._recreate, + ( + self.args, + { + "name": self.name, + "url": self.url, + "base_criterion": self.base_criterion, + } + | {"_id": self._id}, + ), + ) + def dict(self, include_id: bool = False) -> dict: """ Get a dictionary representation of the object. diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index ee3dcc9f..71f9e72d 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -1,6 +1,5 @@ -from abc import abstractmethod from datetime import time -from typing import Any, Dict, Iterator, Self, cast +from typing import Any, Callable, Dict, Iterator, Self, cast from execution_engine.util.enum import TimeIntervalType from execution_engine.util.serializable import Serializable, SerializableABC @@ -48,7 +47,47 @@ class Expr(BaseExpr): Class for expressions that are not Symbols """ - # todo: isn't this now a bit redundant with the BaseExpr class? (because we've removed category) + @classmethod + def _recreate(cls, args: Any, kwargs: dict) -> "Expr": + """ + Recreate an expression from its arguments and category. + """ + _id = kwargs.pop("_id") + + self = cast(Expr, cls(*args, **kwargs)) + self.set_id(_id) + return self + + def _reduce_helper(self, ivars_map: dict | None = None) -> tuple[Callable, tuple]: + """ + Return a picklable tuple that calls self._recreate and passes in + (self.args, combined_ivars). + + :param ivars_map: A dictionary for renaming keys in the instance variables. + For each old_key -> new_key in ivars_map, if old_key exists + in data, it will be removed and stored under new_key instead. 
+ """ + data = dict(self.get_instance_variables()) + data["_id"] = self._id + + # Apply key renaming if ivars_map is provided + if ivars_map: + for old_key, new_key in ivars_map.items(): + if old_key in data: + data[new_key] = data.pop(old_key) + + return (self._recreate, (self.args, data)) + + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, arguments, and category. + """ + # return self._recreate, (self.args, self.get_instance_variables() | {"_id": self._id}) + return self._reduce_helper() def get_instance_variables(self, immutable: bool = False) -> dict | tuple: """ @@ -75,20 +114,20 @@ def __setattr__(self, name: str, value: Any) -> None: This is overridden to prevent setting attributes on the object. """ - if name in self.__dict__ and name not in ["args", "_init_args"]: + if name in self.__dict__ and name not in ["args", "_hash"]: raise AttributeError( f"Cannot update attributes on {self.__class__.__name__}" ) super().__setattr__(name, value) - @abstractmethod def update_args(self, *args: Any) -> None: """ Update the arguments of the expression. :param args: The new arguments. """ - raise NotImplementedError("update_args must be implemented by subclasses") + self.args = args + self.rehash() def __new__(cls, *args: Any, **kwargs: Any) -> "Expr": """ @@ -115,13 +154,19 @@ def __str__(self) -> str: """ return f"{self.__class__.__name__}({', '.join(map(str, self.args))})" - def __hash__(self) -> int: + def rehash(self, recursive: bool = False) -> None: """ - Get the hash of this expression. - - :return: Hash of the expression. + Recalculate the hash of the object. 
""" - return hash( + + if recursive: + for arg in self.args: + if isinstance(arg, Expr): + arg.rehash(recursive=True) + else: + arg.rehash() + + self._hash = hash( (self.__class__, self.args, self.get_instance_variables(immutable=True)) ) @@ -267,28 +312,12 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "UnaryOperator": return cast(UnaryOperator, super().__new__(cls, *args, **kwargs)) - def update_args(self, *args: Any) -> None: - """ - Update the arguments of the expression. - - :param args: The new arguments. - """ - self.args = args - class CommutativeOperator(BooleanFunction, SerializableABC): """ Base class for commutative operators. """ - def update_args(self, *args: Any) -> None: - """ - Update the arguments of the expression. - - :param args: The new arguments. - """ - self.args = args - def __new__(cls, *args: Any, **kwargs: Any) -> "CommutativeOperator": """ Create a new CommutativeOperator object. @@ -351,17 +380,45 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Not": return cast(Not, super().__new__(cls, *args, **kwargs)) -class Count(CommutativeOperator, SerializableABC): +class CountOperator(CommutativeOperator, SerializableABC): """ - Class representing a logical COUNT operation. + Base class for count operators - Adds a "threshold" parameter of type int. - - This class should not be instantiated directly, but rather through one of its subclasses. + This is the BaseClass for Count, TemporalCount and CappedCount - while these three classes + may not define any additional code, we need them to be able to use isinstance on the different subclasses and + distinguish between subclasses from Count, TemporalCount or CappedCount """ - count_min: int | None = None - count_max: int | None = None + count_min: int | None + count_max: int | None + + def __new__( + cls, *args: Any, min_count: int | None, max_count: int | None, **kwargs: Any + ) -> "CountOperator": + """ + Create a new MinCount object. 
+ """ + self = cast(MinCount, super().__new__(cls, *args, **kwargs)) + self.count_min = min_count + self.count_max = max_count + + return self + + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, arguments, and category. + """ + return self._reduce_helper({"count_min": "threshold", "count_max": "threshold"}) + + +class Count(CountOperator): + """ + Class representing a logical COUNT operation. + """ class MinCount(Count): @@ -373,8 +430,10 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MinCount" """ Create a new MinCount object. """ - self = cast(MinCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold + self = cast( + MinCount, + super().__new__(cls, *args, min_count=threshold, max_count=None, **kwargs), + ) return self def dict(self, include_id: bool = False) -> dict: @@ -391,7 +450,7 @@ def __str__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" + return f"{self.__class__.__name__}(threshold={self.count_min})" class MaxCount(Count): @@ -403,8 +462,10 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MaxCount" """ Create a new MaxCount object. """ - self = cast(MaxCount, super().__new__(cls, *args, **kwargs)) - self.count_max = threshold + self = cast( + MaxCount, + super().__new__(cls, *args, min_count=None, max_count=threshold, **kwargs), + ) return self def dict(self, include_id: bool = False) -> dict: @@ -419,7 +480,7 @@ def __str__(self) -> str: """ Represent the expression in a readable format. 
""" - return f"{self.__class__.__name__}(threshold={self.count_max}; {', '.join(map(repr, self.args))})" + return f"{self.__class__.__name__}(threshold={self.count_max})" class ExactCount(Count): @@ -431,9 +492,12 @@ def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "ExactCoun """ Create a new ExactCount object. """ - self = cast(ExactCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - self.count_max = threshold + self = cast( + ExactCount, + super().__new__( + cls, *args, min_count=threshold, max_count=threshold, **kwargs + ), + ) return self def dict(self, include_id: bool = False) -> dict: @@ -448,10 +512,10 @@ def __str__(self) -> str: """ Represent the expression in a readable format. """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" + return f"{self.__class__.__name__}(threshold={self.count_min})" -class CappedCount(CommutativeOperator, SerializableABC): +class CappedCount(CountOperator, SerializableABC): """ Base class representing a COUNT operation with an upper cap. @@ -467,9 +531,6 @@ class CappedCount(CommutativeOperator, SerializableABC): capped count operations like CappedMinCount. """ - count_min: int | None = None - count_max: int | None = None - class CappedMinCount(CappedCount): """ @@ -493,9 +554,10 @@ def __new__( """ Create a new CappedMinCount object. """ - self = cast(CappedMinCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - return self + return cast( + CappedMinCount, + super().__new__(cls, *args, min_count=threshold, max_count=None, **kwargs), + ) def dict(self, include_id: bool = False) -> dict: """ @@ -509,7 +571,7 @@ def __str__(self) -> str: """ Represent the expression in a readable format. 
""" - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))})" + return f"{self.__class__.__name__}(threshold={self.count_min})" class AllOrNone(CommutativeOperator): @@ -518,7 +580,7 @@ class AllOrNone(CommutativeOperator): """ -class TemporalCount(CommutativeOperator, SerializableABC): +class TemporalCount(CountOperator, SerializableABC): """ Class representing a logical COUNT operation. @@ -537,7 +599,8 @@ class TemporalCount(CommutativeOperator, SerializableABC): def __new__( cls, *args: Any, - threshold: int | None, + min_count: int | None, + max_count: int | None, start_time: time | None = None, end_time: time | None = None, interval_type: TimeIntervalType | None = None, @@ -547,6 +610,7 @@ def __new__( """ Create a new TemporalCount object. """ + TemporalCount._validate_time_inputs( start_time, end_time, interval_type, interval_criterion ) @@ -556,9 +620,13 @@ def __new__( # it properly processed args += (interval_criterion,) - self = cast(Self, super().__new__(cls, *args, **kwargs)) + self = cast( + Self, + super().__new__( + cls, *args, min_count=min_count, max_count=max_count, **kwargs + ), + ) - self.count_min = threshold self.start_time = ( time.fromisoformat(start_time) # type: ignore[arg-type] if isinstance(start_time, str) @@ -641,7 +709,6 @@ def dict(self, include_id: bool = False) -> dict: data["data"].update( { - "threshold": self.count_min, "start_time": self.start_time.isoformat() if self.start_time else None, "end_time": self.end_time.isoformat() if self.end_time else None, "interval_type": self.interval_type, @@ -668,7 +735,7 @@ def __str__(self) -> str: else: interval = "None" - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))})" + return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min})" class TemporalMinCount(TemporalCount): @@ -676,18 +743,126 @@ class TemporalMinCount(TemporalCount): Class representing a 
logical temporal MIN_COUNT operation. """ + def __new__( + cls, + *args: Any, + threshold: int | None, + start_time: time | None, + end_time: time | None, + interval_type: TimeIntervalType | None, + interval_criterion: BaseExpr | None, + **kwargs: Any, + ) -> "TemporalMinCount": + """ + Create a new MinCount object. + """ + self = cast( + TemporalMinCount, + super().__new__( + cls, + *args, + min_count=threshold, + max_count=None, + start_time=start_time, + end_time=end_time, + interval_type=interval_type, + interval_criterion=interval_criterion, + ), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data.update({"threshold": self.count_min}) + return data + class TemporalMaxCount(TemporalCount): """ Class representing a logical MAX_COUNT operation. """ + def __new__( + cls, + *args: Any, + threshold: int | None, + start_time: time | None, + end_time: time | None, + interval_type: TimeIntervalType | None, + interval_criterion: BaseExpr | None, + **kwargs: Any, + ) -> "TemporalMaxCount": + """ + Create a new MaxCount object. + """ + self = cast( + TemporalMaxCount, + super().__new__( + cls, + *args, + min_count=None, + max_count=threshold, + start_time=start_time, + end_time=end_time, + interval_type=interval_type, + interval_criterion=interval_criterion, + ), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data.update({"threshold": self.count_max}) + return data + class TemporalExactCount(TemporalCount): """ Class representing a logical EXACT_COUNT operation. 
""" + def __new__( + cls, + *args: Any, + threshold: int | None, + start_time: time | None, + end_time: time | None, + interval_type: TimeIntervalType | None, + interval_criterion: BaseExpr | None, + **kwargs: Any, + ) -> "TemporalExactCount": + """ + Create a new ExactCount object. + """ + self = cast( + TemporalExactCount, + super().__new__( + cls, + *args, + min_count=threshold, + max_count=threshold, + start_time=start_time, + end_time=end_time, + interval_type=interval_type, + interval_criterion=interval_criterion, + ), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data.update({"threshold": self.count_min}) + return data + class NonSimplifiableAnd(CommutativeOperator): """ @@ -749,7 +924,7 @@ def update_args(self, *args: Any) -> None: raise ValueError( f"{self.__class__.__name__} requires exactly two arguments" ) - self.args = args + super().update_args(*args) def __new__( cls, left: BaseExpr, right: BaseExpr, **kwargs: Any diff --git a/execution_engine/util/serializable.py b/execution_engine/util/serializable.py index 9089aeff..c655ef7d 100644 --- a/execution_engine/util/serializable.py +++ b/execution_engine/util/serializable.py @@ -1,7 +1,7 @@ import abc import inspect import json -from typing import Any, Callable, Dict, Self, Tuple, final +from typing import Any, Dict, Self, final from pydantic import BaseModel @@ -296,11 +296,11 @@ def __hash__(self) -> int: """ return self._hash - def __reduce__(self) -> Tuple[Callable, tuple]: - """ - Support pickling of the object. - """ - return self.__class__.from_dict, (self.dict(include_id=True),) + # def __reduce__(self) -> Tuple[Callable, tuple]: + # """ + # Support pickling of the object. 
+ # """ + # return self.__class__.from_dict, (self.dict(include_id=True),) @final def __repr__(self) -> str: From b05950d78ea8e690517c39197cc2276987b485e7 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 21 Mar 2025 13:53:38 +0100 Subject: [PATCH 16/16] fix: deserialization, base criterion in graph --- execution_engine/execution_graph/graph.py | 6 +- execution_engine/omop/cohort/graph_builder.py | 7 +- execution_engine/task/creator.py | 98 +++++++++++++++++-- execution_engine/util/logic.py | 96 +++++++++++++----- execution_engine/util/serializable.py | 23 ++++- .../combination/test_temporal_combination.py | 10 +- 6 files changed, 197 insertions(+), 43 deletions(-) diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index 8c5d5dc3..792a034c 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -90,8 +90,12 @@ def traverse( ) subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr.id)) + elif expr == base_node: + # don't need to do anything - only non-base criteria are connected to the base criterion, + # otherwise we get a cyclic graph + pass elif expr.is_Atom: - assert expr in graph.nodes + assert expr in graph.nodes, "Node not found in graph" graph.nodes[expr]["store_result"] = True graph.add_edge(base_node, expr) else: diff --git a/execution_engine/omop/cohort/graph_builder.py b/execution_engine/omop/cohort/graph_builder.py index 0d639a06..9329bd26 100644 --- a/execution_engine/omop/cohort/graph_builder.py +++ b/execution_engine/omop/cohort/graph_builder.py @@ -36,6 +36,8 @@ def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: :rtype: logic.Expr """ + node = copy.copy(node) + if isinstance(node, logic.Symbol): return logic.LeftDependentToggle(left=filter_, right=node) elif isinstance(node, logic.Expr): @@ -140,7 +142,10 @@ def build(cls, expr: logic.Expr, base_criterion: Criterion) -> ExecutionGraph: ) 
graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes) + if graph.in_degree(base_criterion) != 0: + raise AssertionError("Base criterion must not have incoming edges") + if not nx.is_directed_acyclic_graph(graph): - raise ValueError("Graph is not acyclic") + raise AssertionError("Graph is not acyclic") return graph diff --git a/execution_engine/task/creator.py b/execution_engine/task/creator.py index 87b23809..72a1f223 100644 --- a/execution_engine/task/creator.py +++ b/execution_engine/task/creator.py @@ -1,9 +1,96 @@ +import pickle # nosec + import networkx as nx +from typing_extensions import Any import execution_engine.util.logic as logic from execution_engine.task.task import Task +def assert_pickle_roundtrip(obj: logic.BaseExpr) -> None: + """ + Serializes 'obj' via pickle (the same method multiprocessing would use), + then deserializes it, and finally compares the original object to the result. + + :param obj: The object to serialize/deserialize. + :raises AssertionError: If the object does not match its clone. + :return: The deserialized clone (for further inspection if needed). + """ + pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) # nosec + clone = pickle.loads(pickled) # nosec + + if isinstance(obj, logic.CountOperator): + if obj.dict()["data"]["threshold"] is None: + raise AssertionError("Threshold must be set") + + if obj == clone: + return # They are considered equal, so nothing to do. + + # If they're unequal, compare their dict() representations to find differences. + d1 = obj.dict() + d2 = clone.dict() + + if d1 == d2: + # If they're unequal but dicts are the same, there's some internal difference + # not visible via .dict(). Just alert that we can't show details. + raise AssertionError( + f"Objects differ in __eq__, but their dict() representations are identical.\n" + f"obj: {obj}\n" + f"clone: {clone}" + ) + + # Otherwise, gather all leaf-level differences in d1 vs d2. 
+ diffs = _compare_dicts_leaf_level(d1, d2) + diff_msg = "\n".join(diffs) + + raise AssertionError( + f"Object does not match its clone after round-trip!\n\n" + f"Differences at leaf level in .dict() representations:\n{diff_msg}" + ) + + +def _compare_dicts_leaf_level(d1: Any, d2: Any, path: str = "") -> list[str]: + """ + Recursively compare two dict/list/tuple/scalar structures and return + a list of strings describing differences at the leaf level. + + :param d1, d2: Potentially nested structures (dict, list, tuple, scalar). + :param path: Path string to locate the current point in the structure. + :return: List of difference descriptions. + """ + differences = [] + + # If both are dicts, recurse into matching keys + if isinstance(d1, dict) and isinstance(d2, dict): + all_keys = set(d1.keys()) | set(d2.keys()) + for key in sorted(all_keys): + sub_path = f"{path}.{key}" if path else str(key) + if key not in d1: + differences.append(f"[MISSING IN ORIGINAL] {sub_path} => {d2[key]!r}") + elif key not in d2: + differences.append(f"[MISSING IN CLONE] {sub_path} => {d1[key]!r}") + else: + differences.extend( + _compare_dicts_leaf_level(d1[key], d2[key], sub_path) + ) + + # If both are lists/tuples, compare element by element + elif isinstance(d1, (list, tuple)) and isinstance(d2, (list, tuple)): + if len(d1) != len(d2): + differences.append(f"[LEN MISMATCH] {path} => {len(d1)} vs {len(d2)}") + else: + for i, (item1, item2) in enumerate(zip(d1, d2)): + sub_path = f"{path}[{i}]" + differences.extend(_compare_dicts_leaf_level(item1, item2, sub_path)) + + # Otherwise, treat them as leaf values and compare directly + else: + if d1 != d2: + differences.append(f"[VALUE MISMATCH] {path} => {d1!r} vs {d2!r}") + + return differences + + class TaskCreator: """ A TaskCreator object creates a Task tree for an expression and its dependencies. 
@@ -43,12 +130,11 @@ def node_to_task(expr: logic.Expr, attr: dict) -> Task: flattened_tasks = list(tasks.values()) - # we will make sure all tasks are depickled correctly - for i, node in enumerate(tasks): - if logic.Expr.from_dict(node.dict(include_id=True)) != node: - raise RuntimeError( - "Expected depickled node to be the same as initial node." - ) + # we will make sure all tasks are depickled correctly [commented out for performance reasons] + # from tqdm import tqdm + # + # for node in tqdm(tasks): + # assert_pickle_roundtrip(node) assert ( len(set(flattened_tasks)) diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index 71f9e72d..0d11f587 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -58,7 +58,9 @@ def _recreate(cls, args: Any, kwargs: dict) -> "Expr": self.set_id(_id) return self - def _reduce_helper(self, ivars_map: dict | None = None) -> tuple[Callable, tuple]: + def _reduce_helper( + self, ivars_map: dict | None = None, args: tuple | None = None + ) -> tuple[Callable, tuple]: """ Return a picklable tuple that calls self._recreate and passes in (self.args, combined_ivars). @@ -66,17 +68,21 @@ def _reduce_helper(self, ivars_map: dict | None = None) -> tuple[Callable, tuple :param ivars_map: A dictionary for renaming keys in the instance variables. For each old_key -> new_key in ivars_map, if old_key exists in data, it will be removed and stored under new_key instead. + :param args: Optionally, the args parameter can be specified. If it is None, self.args is used. 
""" - data = dict(self.get_instance_variables()) - data["_id"] = self._id + data = self.get_instance_variables() + data["_id"] = self._id # type: ignore[index] # Apply key renaming if ivars_map is provided if ivars_map: for old_key, new_key in ivars_map.items(): if old_key in data: - data[new_key] = data.pop(old_key) + data[new_key] = data.pop(old_key) # type: ignore[index,union-attr] - return (self._recreate, (self.args, data)) + if args is None: + args = self.args + + return (self._recreate, (args, data)) def __reduce__(self) -> tuple[Callable, tuple]: """ @@ -404,15 +410,30 @@ def __new__( return self + def _replace_map(self) -> dict: + replace = {} + + if self.count_min and self.count_max: + assert self.count_min == self.count_max + replace["count_min"] = "threshold" + elif self.count_min: + replace["count_min"] = "threshold" + elif self.count_max: + replace["count_max"] = "threshold" + else: + raise AttributeError("At least one of count_min or count_max must be set") + + return replace + def __reduce__(self) -> tuple[Callable, tuple]: """ Reduce the expression to its arguments and category. Required for pickling (e.g. when using multiprocessing). - :return: Tuple of the class, arguments, and category. + :return: Tuple of the class, argument. """ - return self._reduce_helper({"count_min": "threshold", "count_max": "threshold"}) + return self._reduce_helper(self._replace_map()) class Count(CountOperator): @@ -683,6 +704,31 @@ def _validate_time_inputs( f"Invalid criterion - expected Criterion or CriterionCombination, got {type(interval_criterion)}" ) + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, argument. 
+ """ + if self.interval_criterion: + + if len(self.args) <= 1: + raise ValueError( + "More than one argument required if interval_criterion is set" + ) + + args, pop = self.args[:-1], self.args[-1] + + if pop != self.interval_criterion: + raise ValueError( + f"Expected last argument to be the interval_criterion, got {str(pop)}" + ) + return self._reduce_helper(self._replace_map(), args=args) + + return super().__reduce__() + def dict(self, include_id: bool = False) -> dict: """ Get a dictionary representation of the object. @@ -746,11 +792,11 @@ class TemporalMinCount(TemporalCount): def __new__( cls, *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, + threshold: int, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, **kwargs: Any, ) -> "TemporalMinCount": """ @@ -776,7 +822,7 @@ def dict(self, include_id: bool = False) -> dict: Get a dictionary representation of the object. """ data = super().dict(include_id=include_id) - data.update({"threshold": self.count_min}) + data["data"].update({"threshold": self.count_min}) return data @@ -788,11 +834,11 @@ class TemporalMaxCount(TemporalCount): def __new__( cls, *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, + threshold: int, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, **kwargs: Any, ) -> "TemporalMaxCount": """ @@ -818,7 +864,7 @@ def dict(self, include_id: bool = False) -> dict: Get a dictionary representation of the object. 
""" data = super().dict(include_id=include_id) - data.update({"threshold": self.count_max}) + data["data"].update({"threshold": self.count_max}) return data @@ -830,11 +876,11 @@ class TemporalExactCount(TemporalCount): def __new__( cls, *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, + threshold: int, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, **kwargs: Any, ) -> "TemporalExactCount": """ @@ -860,7 +906,7 @@ def dict(self, include_id: bool = False) -> dict: Get a dictionary representation of the object. """ data = super().dict(include_id=include_id) - data.update({"threshold": self.count_min}) + data["data"].update({"threshold": self.count_min}) return data diff --git a/execution_engine/util/serializable.py b/execution_engine/util/serializable.py index c655ef7d..768fe13f 100644 --- a/execution_engine/util/serializable.py +++ b/execution_engine/util/serializable.py @@ -105,6 +105,24 @@ def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> S return new_class +def immutable_setattr(self: Self, key: str, value: Any) -> None: + """ + Prevent setting attributes on an immutable object. + + This function is assigned to an instance's __setattr__ method in order to enforce + immutability after object creation. Any attempt to set an attribute on the instance + after initialization will raise an AttributeError. + + :param self: The instance on which the attribute assignment was attempted. + :param key: The name of the attribute being set. + :param value: The value being assigned. + :raises AttributeError: Always, to enforce immutability. 
+ """ + raise AttributeError( + f"Cannot set attribute {key} on immutable object {self.__class__.__name__}" + ) + + class Serializable(metaclass=RegisteredPostInitMeta): """ Base class for making objects serializable. @@ -137,11 +155,6 @@ def __post_init__(self) -> None: """ self.rehash() - def immutable_setattr(self: Self, key: str, value: Any) -> None: - raise AttributeError( - f"Cannot set attribute {key} on immutable object {self.__class__.__name__}" - ) - self.__setattr__ = immutable_setattr # type: ignore[assignment] def set_id(self, value: int, overwrite: bool = False) -> None: diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 4bd4d485..9190b680 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -143,11 +143,11 @@ def test_repr(self, mock_criteria): " MockCriterion(\n" " name='c1'\n" " ),\n" - " threshold=1,\n" " start_time=None,\n" " end_time=None,\n" " interval_type=TimeIntervalType.MORNING_SHIFT,\n" - " interval_criterion=None\n" + " interval_criterion=None,\n" + " threshold=1\n" ")" ) @@ -163,11 +163,11 @@ def test_repr(self, mock_criteria): " MockCriterion(\n" " name='c1'\n" " ),\n" - " threshold=1,\n" " start_time='08:00:00',\n" " end_time='16:00:00',\n" " interval_type=None,\n" - " interval_criterion=None\n" + " interval_criterion=None,\n" + " threshold=1\n" ")" ) @@ -178,7 +178,7 @@ def test_expr_contains_criteria(self, mock_criteria): ): expr = temporal_logic_util.MinCount(*mock_criteria) - expr = logic.TemporalCount(*mock_criteria, threshold=1) + expr = logic.TemporalMinCount(*mock_criteria, threshold=1) assert len(expr.args) == len(mock_criteria)