diff --git a/.flake8 b/.flake8 index 0314f154..21e3b235 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -ignore = E203, E266, E501, W503, F403, F401, E402, C901, F405 +ignore = E203, E266, E501, W503, F403, F401, E402, C901, F405, E731 max-line-length = 88 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b4596d3c..0e450f6a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: exclude: '^docs/' repos: - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 6.0.1 hooks: - id: isort name: isort (python) @@ -17,13 +17,13 @@ repos: - id: check-json - id: pretty-format-json args: ['--autofix', '--no-sort-keys'] -- repo: https://github.com/ambv/black - rev: 24.10.0 +- repo: https://github.com/psf/black + rev: 25.1.0 hooks: - id: black language_version: python3.11 - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.13.0' + rev: 'v1.15.0' hooks: - id: mypy name: mypy @@ -36,12 +36,12 @@ repos: exclude: tests/ args: [--select, "D101,D102,D103,D105,D106"] - repo: https://github.com/PyCQA/bandit - rev: '1.8.0' + rev: '1.8.3' hooks: - id: bandit args: [--skip, "B101,B303,B110,B311"] - repo: https://github.com/PyCQA/flake8 - rev: '7.1.1' + rev: '7.1.2' hooks: - id: flake8 - repo: https://github.com/myint/autoflake diff --git a/apps/graph/static/script.js b/apps/graph/static/script.js index 7ce35453..a6995c98 100644 --- a/apps/graph/static/script.js +++ b/apps/graph/static/script.js @@ -129,15 +129,15 @@ async function loadGraph(recommendationId) { return nodeColors[ele.data('category')] || '#666'; // Assign color based on 'type', with a default }, 'shape': function(ele) { - return nodeShapes[ele.data('type')] || 'star'; // Assign color based on 'type', with a default + return nodeShapes[ele.data("is_atom") ? 
"Symbol" : ele.data('type')] || 'star'; // Assign color based on 'type', with a default }, 'text-valign': 'center', 'color': '#000000', 'width': function(ele) { - return ele.data('type') === 'Symbol' ? '120px': '40px'; + return ele.data('is_atom') ? '120px': '40px'; }, 'height': function(ele) { - return ele.data('type') === 'Symbol' ? '80px': '40px'; + return ele.data('is_atom') ? '80px': '40px'; }, 'font-size': '10px', 'text-wrap': 'wrap', diff --git a/apps/rest_api/app/routers/recommendation.py b/apps/rest_api/app/routers/recommendation.py index 25c7bb07..4aa48051 100644 --- a/apps/rest_api/app/routers/recommendation.py +++ b/apps/rest_api/app/routers/recommendation.py @@ -46,7 +46,7 @@ async def recommendation_criteria( data = [] - for c in recommendation.flatten(): + for c in recommendation.atoms(): data.append( { "description": c.description(), diff --git a/execution_engine/builder.py b/execution_engine/builder.py index 8251f35c..7999ced7 100644 --- a/execution_engine/builder.py +++ b/execution_engine/builder.py @@ -28,7 +28,8 @@ from execution_engine.converter.goal.ventilator_management import ( VentilatorManagementGoal, ) -from execution_engine.converter.time_from_event.abstract import TemporalIndicator +from execution_engine.converter.relative_time.abstract import RelativeTime +from execution_engine.converter.time_from_event.abstract import TimeFromEvent if TYPE_CHECKING: from execution_engine.execution_engine import ExecutionEngine @@ -42,7 +43,8 @@ class CriterionConverterType(TypedDict): characteristic: list[type[CriterionConverter]] action: list[type[CriterionConverter]] goal: list[type[CriterionConverter]] - time_from_event: list[type[TemporalIndicator]] + time_from_event: list[type[TimeFromEvent]] + relative_time: list[type[RelativeTime]] _default_converters: CriterionConverterType = { @@ -63,6 +65,7 @@ class CriterionConverterType(TypedDict): ], "goal": [LaboratoryValueGoal, VentilatorManagementGoal, AssessmentScaleGoal], "time_from_event": [], + 
"relative_time": [], } @@ -76,6 +79,7 @@ def default_execution_engine_builder() -> "ExecutionEngineBuilder": builder.set_action_converters(_default_converters["action"]) builder.set_goal_converters(_default_converters["goal"]) builder.set_time_from_event_converters(_default_converters["time_from_event"]) + builder.set_relative_time_converters(_default_converters["relative_time"]) return builder @@ -92,7 +96,8 @@ def __init__(self) -> None: self.characteristic_converters: list[type[CriterionConverter]] = [] self.action_converters: list[type[CriterionConverter]] = [] self.goal_converters: list[type[CriterionConverter]] = [] - self.time_from_event_converters: list[type[TemporalIndicator]] = [] + self.time_from_event_converters: list[type[TimeFromEvent]] = [] + self.relative_time_converters: list[type[RelativeTime]] = [] def set_characteristic_converters( self, converters: list[type[CriterionConverter]] @@ -128,7 +133,7 @@ def set_goal_converters( return self def set_time_from_event_converters( - self, converters: list[type[TemporalIndicator]] + self, converters: list[type[TimeFromEvent]] ) -> "ExecutionEngineBuilder": """ Sets (overwrites) the time from event converters for this builder. @@ -140,6 +145,19 @@ def set_time_from_event_converters( return self + def set_relative_time_converters( + self, converters: list[type[RelativeTime]] + ) -> "ExecutionEngineBuilder": + """ + Sets (overwrites) the time from event converters for this builder. 
+ """ + self.relative_time_converters.clear() + + for converter_type in converters: + self.append_relative_time_converter(converter_type) + + return self + def append_characteristic_converter( self, converter_type: type[CriterionConverter] ) -> "ExecutionEngineBuilder": @@ -207,27 +225,49 @@ def prepend_goal_converter( return self def append_time_from_event_converter( - self, converter_type: type[TemporalIndicator] + self, converter_type: type[TimeFromEvent] ) -> "ExecutionEngineBuilder": """ Appends a single time_from_event converter at the end of the list. """ - if not issubclass(converter_type, TemporalIndicator): + if not issubclass(converter_type, TimeFromEvent): raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") self.time_from_event_converters.append(converter_type) return self def prepend_time_from_event_converter( - self, converter_type: type[TemporalIndicator] + self, converter_type: type[TimeFromEvent] ) -> "ExecutionEngineBuilder": """ Inserts a single time_from_event converter at the front of the list. """ - if not issubclass(converter_type, TemporalIndicator): + if not issubclass(converter_type, TimeFromEvent): raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") self.time_from_event_converters.insert(0, converter_type) return self + def append_relative_time_converter( + self, converter_type: type[RelativeTime] + ) -> "ExecutionEngineBuilder": + """ + Appends a single relative_time converter at the end of the list. + """ + if not issubclass(converter_type, RelativeTime): + raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") + self.relative_time_converters.append(converter_type) + return self + + def prepend_relative_time_converter( + self, converter_type: type[RelativeTime] + ) -> "ExecutionEngineBuilder": + """ + Inserts a single relative_time converter at the front of the list. 
+ """ + if not issubclass(converter_type, RelativeTime): + raise ValueError(f"Invalid TimeFromEvent converter type: {converter_type}") + self.relative_time_converters.insert(0, converter_type) + return self + def build(self, verbose: bool = False) -> "ExecutionEngine": """ Builds an ExecutionEngine with the specified converters. diff --git a/execution_engine/constants.py b/execution_engine/constants.py index 2deec5e9..eed6654c 100644 --- a/execution_engine/constants.py +++ b/execution_engine/constants.py @@ -26,6 +26,7 @@ EXT_DOSAGE_CONDITION = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/StructureDefinition/ext-dosage-condition" EXT_ACTION_COMBINATION_METHOD = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/StructureDefinition/ext-action-combination-method" EXT_CPG_PARTOF = "http://hl7.org/fhir/uv/cpg/StructureDefinition/cpg-partOf" +EXT_RELATIVE_TIME = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/StructureDefinition/relative-time" CS_ACTION_COMBINATION_METHOD = "https://www.netzwerk-universitaetsmedizin.de/fhir/cpg-on-ebm-on-fhir/CodeSystem/cs-action-combination-method" diff --git a/execution_engine/converter/action/abstract.py b/execution_engine/converter/action/abstract.py index b9466b8b..cd91c000 100644 --- a/execution_engine/converter/action/abstract.py +++ b/execution_engine/converter/action/abstract.py @@ -3,16 +3,14 @@ from fhir.resources.timing import Timing as FHIRTiming +from execution_engine import constants from execution_engine.converter.criterion import CriterionConverter, parse_code from execution_engine.converter.goal.abstract import Goal from execution_engine.fhir.recommendation import RecommendationPlan -from execution_engine.fhir.util import get_coding +from execution_engine.fhir.util import get_coding, get_extensions from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - 
LogicalCriterionCombination, -) from execution_engine.omop.vocabulary import AbstractVocabulary -from execution_engine.util import AbstractPrivateMethods +from execution_engine.util import AbstractPrivateMethods, logic from execution_engine.util.types import Timing from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod @@ -139,21 +137,29 @@ def process_timing(cls, timing: FHIRTiming) -> Timing: if rep.offset is not None: raise NotImplementedError("offset has not been implemented") + relative_time = get_extensions(timing, constants.EXT_RELATIVE_TIME) + + if relative_time: + raise NotImplementedError( + "RelativeTime processing within AbstractAction not implemented - " + "should be performed in the parser" + ) + return Timing( count=count, duration=duration, frequency=frequency, interval=interval ) @abstractmethod - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.BaseExpr | None: """Converts this action to a Criterion.""" raise NotImplementedError() @final - def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: + def to_positive_expression(self) -> logic.BaseExpr: """ Converts this action to a criterion. 
""" - action = self._to_criterion() + action = self._to_expression() if action is None: assert ( @@ -161,10 +167,10 @@ def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: ), "Action without explicit criterion must have at least one goal" if self.goals: - criteria = [goal.to_criterion() for goal in self.goals] + criteria = [goal.to_expression() for goal in self.goals] if action is not None: criteria.append(action) - return LogicalCriterionCombination.And(*criteria) + return logic.And(*criteria) else: return action # type: ignore diff --git a/execution_engine/converter/action/body_positioning.py b/execution_engine/converter/action/body_positioning.py index e029b1ea..0fd80ef8 100644 --- a/execution_engine/converter/action/body_positioning.py +++ b/execution_engine/converter/action/body_positioning.py @@ -4,12 +4,9 @@ from execution_engine.converter.criterion import parse_code from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.types import Timing @@ -48,7 +45,7 @@ def from_fhir(cls, action_def: RecommendationPlan.Action) -> Self: return cls(exclude=exclude, code=code, timing=timing) - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" return ProcedureOccurrence( diff --git a/execution_engine/converter/action/drug_administration.py b/execution_engine/converter/action/drug_administration.py index b25011aa..8e072b0e 100644 --- a/execution_engine/converter/action/drug_administration.py +++ 
b/execution_engine/converter/action/drug_administration.py @@ -15,12 +15,6 @@ ) from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) from execution_engine.omop.criterion.drug_exposure import DrugExposure from execution_engine.omop.criterion.point_in_time import PointInTimeCriterion from execution_engine.omop.vocabulary import ( @@ -28,6 +22,7 @@ VocabularyFactory, standard_vocabulary, ) +from execution_engine.util import logic from execution_engine.util.types import Dosage from execution_engine.util.value import Value, ValueNumber @@ -278,12 +273,12 @@ def filter_same_unit(cls, df: pd.DataFrame, unit: Concept) -> pd.DataFrame: return df_filtered - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.BaseExpr: """ Returns a criterion that represents this action. """ - drug_actions: list[Criterion | LogicalCriterionCombination] = [] + drug_actions: list[logic.BaseExpr] = [] if not self._dosages: # no dosages, just return the drug exposure @@ -325,19 +320,18 @@ def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: # rational: "conditional" extensions are some conditions for dosage, such as body weight ranges. # Thus, the actual drug administration (drug_action, "right") must only be fulfilled if the # condition (ext_criterion, "left") is fulfilled. Thus, we here add this conditional filter. 
- comb = NonCommutativeLogicalCriterionCombination.ConditionalFilter( + comb = logic.ConditionalFilter( left=ext_criterion, right=drug_action, ) drug_actions.append(comb) - result: Criterion | CriterionCombination + result: logic.BaseExpr + if len(drug_actions) == 1: result = drug_actions[0] else: - result = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator("OR"), - ) - result.add_all(drug_actions) + result = logic.Or(*drug_actions) + return result diff --git a/execution_engine/converter/action/procedure.py b/execution_engine/converter/action/procedure.py index 7fb827d5..129fad22 100644 --- a/execution_engine/converter/action/procedure.py +++ b/execution_engine/converter/action/procedure.py @@ -3,13 +3,13 @@ from execution_engine.fhir.recommendation import RecommendationPlan from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) +from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence +from execution_engine.omop.criterion.device_exposure import DeviceExposure from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.criterion.observation import Observation from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.types import Timing @@ -17,7 +17,7 @@ class ProcedureAction(AbstractAction): """ An ProcedureAction is an action that describes a procedure to be performed - This action tests whether the procedure has been performed by determining whether it is + This action tests whether the procedure has been performed by determining whether it is present in the respective OMOP CDM table. 
""" @@ -57,7 +57,7 @@ def from_fhir(cls, action_def: RecommendationPlan.Action) -> AbstractAction: return cls(exclude=exclude, code=code, timing=timing) - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" criterion: Criterion @@ -73,7 +73,7 @@ def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: # as Observation and Measurement normally expect a value. criterion = Measurement( concept=self._code, - override_value_required=False, + value_required=False, timing=self._timing, ) case "Observation": @@ -81,7 +81,19 @@ def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: # as Observation and Measurement normally expect a value. criterion = Observation( concept=self._code, - override_value_required=False, + value_required=False, + timing=self._timing, + ) + case "Device": + criterion = DeviceExposure( + concept=self._code, + value_required=False, + timing=self._timing, + ) + case "Condition": + criterion = ConditionOccurrence( + concept=self._code, + value_required=False, timing=self._timing, ) case _: diff --git a/execution_engine/converter/action/ventilator_management.py b/execution_engine/converter/action/ventilator_management.py index cebc7051..1fa6585d 100644 --- a/execution_engine/converter/action/ventilator_management.py +++ b/execution_engine/converter/action/ventilator_management.py @@ -2,10 +2,6 @@ from execution_engine.converter.action.abstract import AbstractAction from execution_engine.fhir.recommendation import RecommendationPlan -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.vocabulary import SNOMEDCT @@ -29,6 +25,6 @@ def from_fhir(cls, action_def: RecommendationPlan.Action) -> Self: exclude=False, ) # fixme: no way to exclude goals (e.g. 
"do not ventilate") - def _to_criterion(self) -> Criterion | LogicalCriterionCombination | None: + def _to_expression(self) -> None: """Converts this characteristic to a Criterion.""" return None diff --git a/execution_engine/converter/characteristic/abstract.py b/execution_engine/converter/characteristic/abstract.py index 899522ce..dd4c5594 100644 --- a/execution_engine/converter/characteristic/abstract.py +++ b/execution_engine/converter/characteristic/abstract.py @@ -8,6 +8,7 @@ from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.vocabulary import standard_vocabulary +from execution_engine.util import logic as logic from execution_engine.util.value import Value @@ -30,7 +31,7 @@ class AbstractCharacteristic(CriterionConverter, ABC): Subclasses must define the following methods: - valid: returns True if the supplied characteristic falls within the scope of the subclass - from_fhir: creates a new instance of the subclass from a FHIR EvidenceVariable.characteristic element - - to_criterion(): converts the characteristic to a Criterion + - to_expression(): converts the characteristic to a Criterion """ _criterion_class: Type[Criterion] @@ -60,9 +61,9 @@ def type(self) -> Concept: return self._type @type.setter - def type(self, type: Concept) -> None: + def type(self, type_: Concept) -> None: """Sets the type of this characteristic.""" - self._type = type + self._type = type_ @property def value(self) -> Any: @@ -82,6 +83,13 @@ def valid( """Checks if the given FHIR EvidenceVariable is a valid characteristic.""" raise NotImplementedError() + @abstractmethod + def to_positive_expression(self) -> logic.BaseExpr: + """ + Converts this characteristic to a positive expression (i.e. neglecting the exlude flag). 
+ """ + raise NotImplementedError() + @staticmethod def get_standard_concept(cc: Coding) -> Concept: """ @@ -95,13 +103,3 @@ def get_concept(cc: Coding, standard: bool = True) -> Concept: Get the OMOP Standard Vocabulary standard concept for the given code in the given vocabulary. """ return standard_vocabulary.get_concept(cc.system, cc.code, standard=standard) - - @abstractmethod - def to_positive_criterion(self) -> Criterion: - """ - Converts this characteristic to a "Positive" Criterion. - - Positive criterion means that a possible excluded flag is disregarded. Instead, the exclusion - is later introduced (in the to_criterion() method) via a LogicalCriterionCombination.Not). - """ - raise NotImplementedError() diff --git a/execution_engine/converter/characteristic/codeable_concept.py b/execution_engine/converter/characteristic/codeable_concept.py index 1dac1b58..18c96eaa 100644 --- a/execution_engine/converter/characteristic/codeable_concept.py +++ b/execution_engine/converter/characteristic/codeable_concept.py @@ -6,9 +6,9 @@ from execution_engine.converter.characteristic.abstract import AbstractCharacteristic from execution_engine.fhir.util import get_coding from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.concept import ConceptCriterion from execution_engine.omop.vocabulary import AbstractVocabulary +from execution_engine.util import logic class AbstractCodeableConceptCharacteristic(AbstractCharacteristic): @@ -78,7 +78,7 @@ def from_fhir( return c - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.BaseExpr: """Converts this characteristic to a Criterion.""" return self._criterion_class( concept=self.value, diff --git a/execution_engine/converter/characteristic/value.py b/execution_engine/converter/characteristic/value.py index 59e8b121..51fb91a0 100644 --- a/execution_engine/converter/characteristic/value.py +++ 
b/execution_engine/converter/characteristic/value.py @@ -6,6 +6,7 @@ from execution_engine.converter.characteristic.abstract import AbstractCharacteristic from execution_engine.converter.criterion import parse_code, parse_value from execution_engine.omop.criterion.concept import ConceptCriterion +from execution_engine.util import logic class AbstractValueCharacteristic(AbstractCharacteristic, ABC): @@ -21,7 +22,12 @@ def from_fhir( """Creates a new Characteristic instance from a FHIR EvidenceVariable.characteristic.""" assert cls.valid(characteristic), "Invalid characteristic definition" - type_omop_concept = parse_code(characteristic.definitionByTypeAndValue.type) + try: + type_omop_concept = parse_code(characteristic.definitionByTypeAndValue.type) + except ValueError: + type_omop_concept = parse_code( + characteristic.definitionByTypeAndValue.type, standard=False + ) value = parse_value( value_parent=characteristic.definitionByTypeAndValue, value_prefix="value" ) @@ -32,7 +38,7 @@ def from_fhir( return c - def to_positive_criterion(self) -> ConceptCriterion: + def to_positive_expression(self) -> logic.Symbol: """Converts this characteristic to a Criterion.""" return self._criterion_class( concept=self.type, diff --git a/execution_engine/converter/criterion.py b/execution_engine/converter/criterion.py index b70b7b98..9cb43e57 100644 --- a/execution_engine/converter/criterion.py +++ b/execution_engine/converter/criterion.py @@ -1,6 +1,6 @@ import json from abc import ABC, abstractmethod -from typing import Tuple, Type +from typing import Tuple, Type, final from fhir.resources.codeableconcept import CodeableConcept from fhir.resources.element import Element @@ -11,11 +11,8 @@ from execution_engine.fhir.util import get_coding from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from 
execution_engine.omop.vocabulary import standard_vocabulary +from execution_engine.util import logic as logic from execution_engine.util.value import Value, ValueConcept, ValueNumber from execution_engine.util.value.value import ValueScalar @@ -99,10 +96,14 @@ def parse_value( eps = float(value_numeric) / 1e5 match value.comparator: case "<=" | "<": - value_max = value_numeric - (eps if value.comparator == "<" else 0) + value_max = float(value_numeric) - ( + eps if value.comparator == "<" else 0 + ) value_numeric = None case ">=" | ">": - value_min = value_numeric + (eps if value.comparator == "<" else 0) + value_min = float(value_numeric) + ( + eps if value.comparator == "<" else 0 + ) value_numeric = None case _: raise ValueError(f'Unknown quantity operator: "{value.comparator}"') @@ -182,22 +183,23 @@ def valid(cls, fhir_definition: Element) -> bool: raise NotImplementedError() @abstractmethod - def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: + def to_positive_expression(self) -> logic.BaseExpr: """Converts this characteristic to a Criterion or a combination of criteria but no negation.""" raise NotImplementedError() - def to_criterion(self) -> Criterion | LogicalCriterionCombination: + @final + def to_expression(self) -> logic.BaseExpr: """ - Converts this characteristic to a Criterion or a - combination of criteria. The result may be a "negative" - criterion, that is the result of to_positive_criterion wrapped - in a LogicalCriterionCombination with operator NOT. + Converts this characteristic to an expression. The result may be a "negative" + criterion, that is the result of to_positive_expression wrapped + in a logic.Not. 
""" - positive_criterion = self.to_positive_criterion() + positive_expression = self.to_positive_expression() + if self._exclude: - return LogicalCriterionCombination.Not(positive_criterion) + return logic.Not(positive_expression) else: - return positive_criterion + return positive_expression class CriterionConverterFactory: diff --git a/execution_engine/converter/goal/assessment_scale.py b/execution_engine/converter/goal/assessment_scale.py index dd8bfb92..28e05bff 100644 --- a/execution_engine/converter/goal/assessment_scale.py +++ b/execution_engine/converter/goal/assessment_scale.py @@ -4,9 +4,9 @@ from execution_engine.converter.criterion import parse_code_value from execution_engine.converter.goal.abstract import Goal from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.value import Value @@ -47,7 +47,7 @@ def from_fhir(cls, goal: PlanDefinitionGoal) -> "AssessmentScaleGoal": return cls(code.concept_name, exclude=False, code=code, value=value) - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. 
""" diff --git a/execution_engine/converter/goal/laboratory_value.py b/execution_engine/converter/goal/laboratory_value.py index a1fd8119..ec10b8f9 100644 --- a/execution_engine/converter/goal/laboratory_value.py +++ b/execution_engine/converter/goal/laboratory_value.py @@ -4,9 +4,9 @@ from execution_engine.converter.criterion import parse_code_value from execution_engine.converter.goal.abstract import Goal from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.vocabulary import SNOMEDCT +from execution_engine.util import logic from execution_engine.util.value import Value @@ -46,7 +46,7 @@ def from_fhir(cls, goal: PlanDefinitionGoal) -> "LaboratoryValueGoal": return cls(exclude=False, code=code, value=value) - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. 
""" diff --git a/execution_engine/converter/goal/ventilator_management.py b/execution_engine/converter/goal/ventilator_management.py index 25b08e11..da8630d6 100644 --- a/execution_engine/converter/goal/ventilator_management.py +++ b/execution_engine/converter/goal/ventilator_management.py @@ -6,12 +6,12 @@ from execution_engine.converter.criterion import parse_code, parse_value from execution_engine.converter.goal.abstract import Goal from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.criterion.custom.tidal_volume import ( TidalVolumePerIdealBodyWeight, ) from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.vocabulary import CODEXCELIDA, SNOMEDCT +from execution_engine.util import logic from execution_engine.util.value import Value CUSTOM_GOALS: dict[Concept, Type] = { @@ -56,7 +56,7 @@ def from_fhir(cls, goal: PlanDefinitionGoal) -> "VentilatorManagementGoal": return cls(exclude=False, code=code, value=value) - def to_positive_criterion(self) -> Criterion: + def to_positive_expression(self) -> logic.Symbol: """ Converts the goal to a criterion. 
""" diff --git a/execution_engine/converter/parser/base.py b/execution_engine/converter/parser/base.py index d8ac4a7e..06c9c302 100644 --- a/execution_engine/converter/parser/base.py +++ b/execution_engine/converter/parser/base.py @@ -1,11 +1,18 @@ from abc import ABC, abstractmethod +from typing import Callable, Type -from fhir.resources.evidencevariable import EvidenceVariable +from fhir.resources.evidencevariable import ( + EvidenceVariable, + EvidenceVariableCharacteristic, + EvidenceVariableCharacteristicTimeFromEvent, +) +from fhir.resources.extension import Extension +from fhir.resources.plandefinition import PlanDefinition, PlanDefinitionAction from execution_engine import fhir from execution_engine.converter.criterion import CriterionConverterFactory from execution_engine.converter.temporal import TemporalIndicatorConverterFactory -from execution_engine.omop.criterion.combination.combination import CriterionCombination +from execution_engine.util import logic as logic class FhirRecommendationParserInterface(ABC): @@ -17,17 +24,18 @@ def __init__( action_converters: CriterionConverterFactory, goal_converters: CriterionConverterFactory, time_from_event_converters: TemporalIndicatorConverterFactory, + relative_time_converters: TemporalIndicatorConverterFactory, ): self.characteristics_converters = characteristic_converters self.action_converters = action_converters self.goal_converters = goal_converters self.time_from_event_converters = time_from_event_converters + self.relative_time_converters = relative_time_converters @abstractmethod - def parse_characteristics(self, ev: EvidenceVariable) -> CriterionCombination: + def parse_characteristics(self, ev: EvidenceVariable) -> logic.BooleanFunction: """ - Parses the EvidenceVariable characteristics and returns either a single Criterion - or a LogicalCriterionCombination + Parses the EvidenceVariable characteristics and returns a BooleanFunction """ raise NotImplementedError() @@ -36,9 +44,47 @@ def 
parse_actions( self, actions_def: list[fhir.RecommendationPlan.Action], rec_plan: fhir.RecommendationPlan, - ) -> CriterionCombination: + ) -> logic.BooleanFunction: """ Parses the actions of a Recommendation (PlanDefinition) and returns a list of Action objects and the corresponding action selection behavior. """ raise NotImplementedError() + + @abstractmethod + def parse_action_combination_method( + self, action_parent: PlanDefinition | PlanDefinitionAction + ) -> Type[logic.BooleanFunction] | Callable: + """ + Parses the action combination method of a Recommendation (PlanDefinition) and returns the corresponding + combination method in form of a logical expression. + """ + raise NotImplementedError() + + def parse_time_from_event( + self, + tfes: list[EvidenceVariableCharacteristicTimeFromEvent], + ) -> list[logic.BaseExpr]: + """ + Parses `timeFromEvent` elements and converts them into interval-based logical criteria. + """ + raise NotImplementedError() + + @abstractmethod + def parse_relative_time( + self, + relative_time: list[Extension], + ) -> list[logic.BaseExpr]: + """ + Parses `extension[relativeTime]` elements and converts them into interval-based logical criteria. + """ + raise NotImplementedError() + + def parse_timing( + self, characteristic: EvidenceVariableCharacteristic, expr: logic.BaseExpr + ) -> logic.BaseExpr: + """ + Applies temporal constraints to a given criterion expression based on `timeFromEvent` and + the relativeTime extension elements. 
+ """ + raise NotImplementedError() diff --git a/execution_engine/converter/parser/factory.py b/execution_engine/converter/parser/factory.py index 7a0196d4..8e0ed3be 100644 --- a/execution_engine/converter/parser/factory.py +++ b/execution_engine/converter/parser/factory.py @@ -17,6 +17,7 @@ def __init__(self, builder: ExecutionEngineBuilder | None = None): self.action_converters = CriterionConverterFactory() self.goal_converters = CriterionConverterFactory() self.time_from_event_converters = TemporalIndicatorConverterFactory() + self.relative_time_converters = TemporalIndicatorConverterFactory() if builder is None: builder = default_execution_engine_builder() @@ -33,6 +34,9 @@ def __init__(self, builder: ExecutionEngineBuilder | None = None): for temporal_converter in builder.time_from_event_converters: self.time_from_event_converters.register(temporal_converter) + for relative_time_converter in builder.relative_time_converters: + self.relative_time_converters.register(relative_time_converter) + def get_parser(self, parser_version: int) -> FhirRecommendationParserInterface: """ Return the correct FhirParser based on the version string. 
@@ -52,4 +56,5 @@ def get_parser(self, parser_version: int) -> FhirRecommendationParserInterface: action_converters=self.action_converters, goal_converters=self.goal_converters, time_from_event_converters=self.time_from_event_converters, + relative_time_converters=self.relative_time_converters, ) diff --git a/execution_engine/converter/parser/fhir_parser_v1.py b/execution_engine/converter/parser/fhir_parser_v1.py index 40b21d7d..ab6ab62a 100644 --- a/execution_engine/converter/parser/fhir_parser_v1.py +++ b/execution_engine/converter/parser/fhir_parser_v1.py @@ -1,44 +1,52 @@ -from typing import Union, cast +from typing import Callable, Type, cast from fhir.resources.evidencevariable import ( EvidenceVariable, EvidenceVariableCharacteristic, EvidenceVariableCharacteristicTimeFromEvent, ) +from fhir.resources.extension import Extension from fhir.resources.plandefinition import PlanDefinition, PlanDefinitionAction from execution_engine import fhir +from execution_engine.constants import EXT_RELATIVE_TIME from execution_engine.converter.action.abstract import AbstractAction from execution_engine.converter.characteristic.abstract import AbstractCharacteristic from execution_engine.converter.goal.abstract import Goal from execution_engine.converter.parser.base import FhirRecommendationParserInterface -from execution_engine.omop.criterion.abstract import AbstractCriterion, Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) +from execution_engine.converter.parser.util import wrap_criteria_with_temporal_indicator +from execution_engine.fhir.util import get_extensions, pop_extensions +from execution_engine.util import logic as logic -def characteristic_code_to_criterion_combination_operator( +def characteristic_code_to_expression( code: str, threshold: int | None = None -) -> LogicalCriterionCombination.Operator: +) -> 
Type[logic.BooleanFunction] | Callable: """ Convert a characteristic combination code (from FHIR) to a criterion combination operator (for OMOP). """ - mapping = { - "all-of": LogicalCriterionCombination.Operator.AND, - "any-of": LogicalCriterionCombination.Operator.OR, - "at-least": LogicalCriterionCombination.Operator.AT_LEAST, - "at-most": LogicalCriterionCombination.Operator.AT_MOST, - "exactly": LogicalCriterionCombination.Operator.EXACTLY, - "all-or-none": LogicalCriterionCombination.Operator.ALL_OR_NONE, + simple_ops = { + "all-of": logic.And, + "any-of": logic.Or, + "all-or-none": logic.AllOrNone, + } + + count_ops = { + "at-least": logic.MinCount, + "at-most": logic.MaxCount, + "exactly": logic.ExactCount, + "all-or-none": logic.AllOrNone, } - if code not in mapping: - raise NotImplementedError(f"Unknown combination code: {code}") - return LogicalCriterionCombination.Operator( - operator=mapping[code], threshold=threshold - ) + if code in simple_ops: + return simple_ops[code] + + if code in count_ops: + if threshold is None: + raise ValueError(f"Threshold must be set for operator {code}") + return lambda *args, category: count_ops[code](*args, threshold=threshold) + + raise NotImplementedError(f'Combination "{code}" not implemented') class FhirRecommendationParserV1(FhirRecommendationParserInterface): @@ -55,52 +63,204 @@ class FhirRecommendationParserV1(FhirRecommendationParserInterface): def parse_time_from_event( self, tfes: list[EvidenceVariableCharacteristicTimeFromEvent], - combo: CriterionCombination, - ) -> CriterionCombination: + ) -> list[logic.BaseExpr]: """ - Parses the timeFromEvent elements and updates the CriterionCombination. + Parses `timeFromEvent` elements and converts them into interval-based logical criteria. - Root element timeFromEvent specifies the time within which the population is valid. 
- Non-root elements act as filters for the criteria they are attached to, meaning only criteria within the timeFromEvent timeframe will be observed to determine if the patient is part of the population. + The `timeFromEvent` elements specify time intervals that constrain when the associated + criteria must occur. These intervals serve as temporal filters, ensuring that only criteria + occurring within the defined time window contribute to determining population membership. + + This function processes each `timeFromEvent` element by converting it into an interval-based + logical criterion. These criteria are later used to enforce temporal constraints on individual + criteria within a characteristic. + + Args: + tfes (list[EvidenceVariableCharacteristicTimeFromEvent]): + A list of `timeFromEvent` elements defining temporal constraints. + + Returns: + list[logic.BaseExpr]: + A list of interval-based logical expressions representing the extracted temporal constraints. + + Raises: + ValueError: + If any `timeFromEvent` element does not yield a valid `logic.BaseExpr`. + + Notes: + - Each `timeFromEvent` element is processed using a registered converter that transforms + it into an interval-based logical expression. + - These interval-based criteria are meant to be combined later (typically with AND) and + applied to individual criteria rather than an entire combination. + """ + interval_criteria = [] + + for tfe in tfes: + converter = self.time_from_event_converters.get(tfe) + + interval_criterion = converter.to_interval_criterion() + + if not isinstance(interval_criterion, logic.BaseExpr): + raise ValueError( + f"Expected instance of BaseExpr, got {type(interval_criterion)}" + ) + + interval_criteria.append(interval_criterion) + + return interval_criteria + + def parse_relative_time( + self, + relative_time: list[Extension], + ) -> list[logic.BaseExpr]: + """ + Parses `extension[relativeTime]` elements and converts them into interval-based logical criteria. 
+ + The `extension[relativeTime]` elements specify time intervals that constrain when the associated + criteria must occur. These intervals serve as temporal filters, ensuring that only criteria + occurring within the defined time window contribute to determining population membership. + + This function processes each `extension[relativeTime]` element by converting it into an interval-based + logical criterion. These criteria are later used to enforce temporal constraints on individual + criteria within a characteristic. + + Args: + relative_time (list[Extension]): + A list of `extension[relativeTime]` elements defining temporal constraints. + + Returns: + list[logic.BaseExpr]: + A list of interval-based logical expressions representing the extracted temporal constraints. + + Raises: + ValueError: + If any `extension[relativeTime]` element does not yield a valid `logic.BaseExpr`. + + Notes: + - Each `extension[relativeTime]` element is processed using a registered converter that transforms + it into an interval-based logical expression. + - These interval-based criteria are meant to be combined later (typically with AND) and + applied to individual criteria rather than an entire combination. + """ + interval_criteria = [] + + for ext in relative_time: + + converter = self.relative_time_converters.get(ext) + + interval_criterion = converter.to_interval_criterion() + + if not isinstance(interval_criterion, logic.BaseExpr): + raise ValueError( + f"Expected instance of BaseExpr, got {type(interval_criterion)}" + ) + + interval_criteria.append(interval_criterion) + + return interval_criteria + + def parse_timing( + self, characteristic: EvidenceVariableCharacteristic, expr: logic.BaseExpr + ) -> logic.BaseExpr: + """ + Applies temporal constraints to a given criterion expression based on `timeFromEvent` and + the relativeTime extension elements. 
+ + This function aggregates all applicable temporal constraints associated with a characteristic, + including `timeFromEvent` elements and relative timing extensions. These constraints define + the time intervals within which the given criterion must be evaluated. + + The extracted interval criteria are AND-combined and used to wrap **each individual criterion** + rather than the entire criterion combination. This prevents unintended constraints that could + arise if multiple criteria were required to occur simultaneously. Args: - tfes (list[EvidenceVariableCharacteristicTimeFromEvent]): List of timeFromEvent elements. - combo (CriterionCombination): The criterion combination to update. + characteristic (EvidenceVariableCharacteristic): + The characteristic whose timing constraints should be applied. + expr (logic.BaseExpr): + The logical expression representing the criterion to be updated. Returns: - TemporalIndicatorCombination: Updated criterion combination. + logic.BaseExpr: + The criterion expression with the applied temporal constraints. + + Notes: + - If `timeFromEvent` elements are present, they are processed using `parse_time_from_event`. + - If a relative timing extension is present, it is processed separately. + - All extracted interval criteria are AND-combined and used to wrap each individual + criterion within the logical expression to ensure correct temporal evaluation. 
""" - if len(tfes) != 1: - raise ValueError(f"Expected exactly 1 timeFromEvent, got {len(tfes)}") + interval_criteria: list[logic.BaseExpr] = [] + + if characteristic.timeFromEvent is not None: + interval_criteria.extend( + self.parse_time_from_event(characteristic.timeFromEvent) + ) + + relative_time = get_extensions(characteristic, EXT_RELATIVE_TIME) + + if relative_time: + interval_criteria.extend(self.parse_relative_time(relative_time)) + + if interval_criteria: + interval_criterion_combo = logic.And(*interval_criteria) - tfe = tfes[0] + expr = wrap_criteria_with_temporal_indicator(expr, interval_criterion_combo) + + return expr + + def process_action_relative_time( + self, action_def: fhir.RecommendationPlan.Action + ) -> list[logic.BaseExpr]: + """ + Processes the relativeTime extension in the given Action and removes it from + the ActivityDefinition's timing field so that subsequent calls (e.g. process_timing) + do not re-process it. + + Args: + action_def (fhir.RecommendationPlan.Action): + The FHIR action definition that potentially contains relativeTime extensions + in its ActivityDefinition. + + Returns: + list[logic.BaseExpr]: A list of expressions derived from the relativeTime extensions, + or an empty list if none were found. + """ + if ( + action_def.activity_definition_fhir is None + or action_def.activity_definition_fhir.timingTiming is None + ): + return [] - converter = self.time_from_event_converters.get(tfe) + relative_time = pop_extensions( + action_def.activity_definition_fhir.timingTiming, EXT_RELATIVE_TIME + ) - new_combo = converter.to_temporal_combination(combo) + if not relative_time: + return [] - return new_combo + return self.parse_relative_time(relative_time) - def parse_characteristics(self, ev: EvidenceVariable) -> CriterionCombination: + def parse_characteristics(self, ev: EvidenceVariable) -> logic.BooleanFunction: """ - Parses the EvidenceVariable characteristics and returns either a single Criterion - or a CriterionCombination. 
+ Parses the EvidenceVariable characteristics and returns either a BooleanFunction. Root element timeFromEvent specifies the time within which the population is valid. - Non-root elements act as filters for the criteria they are attached to, meaning only criteria within the timeFromEvent timeframe will be observed to determine if the patient is part of the population. + Non-root elements act as filters for the criteria they are attached to, meaning only criteria within the + timeFromEvent timeframe will be observed to determine if the patient is part of the population. Args: ev (EvidenceVariable): The evidence variable to parse. Returns: - CriterionCombination: The parsed criterion combination. + BooleanFunction: The parsed criterion combination. """ def build_criterion( characteristic: EvidenceVariableCharacteristic, is_root: bool - ) -> Union[Criterion, CriterionCombination]: + ) -> logic.BaseExpr: """ - Recursively build Criterion or CriterionCombination from a single + Recursively build Symbol or BooleanFunction from a single EvidenceVariableCharacteristic. Args: @@ -108,43 +268,41 @@ def build_criterion( is_root (bool): Indicates if the characteristic is the root element. Returns: - Union[Criterion, CriterionCombination]: The built criterion or criterion combination. + Union[Symbol, BooleanFunction]: The built criterion or criterion combination. 
""" - combo: CriterionCombination + combo: logic.BaseExpr # If this characteristic is itself a combination if characteristic.definitionByCombination is not None: - operator = characteristic_code_to_criterion_combination_operator( + expr = characteristic_code_to_expression( characteristic.definitionByCombination.code, threshold=None, # or parse an actual threshold if needed ) - combo = LogicalCriterionCombination( - operator=operator, - ) + + children = [] for sub_char in characteristic.definitionByCombination.characteristic: - combo.add(build_criterion(sub_char, is_root=False)) + children.append(build_criterion(sub_char, is_root=False)) + + combo = expr(*children) if characteristic.exclude: - combo = LogicalCriterionCombination.Not( + combo = logic.Not( combo, ) - if characteristic.timeFromEvent is not None: - combo = self.parse_time_from_event( - characteristic.timeFromEvent, combo - ) + combo = self.parse_timing(characteristic, combo) return combo # Else it's a single characteristic converter = self.characteristics_converters.get(characteristic) converter = cast(AbstractCharacteristic, converter) - crit = converter.to_criterion() + crit = converter.to_expression() - if not isinstance(crit, AbstractCriterion): - raise ValueError(f"Expected AbstractCriterion, got {type(crit)}") + if not isinstance(crit, logic.BaseExpr): + raise ValueError(f"Expected BaseExpr, got {type(crit)}") return crit @@ -155,17 +313,16 @@ def build_criterion( and ev.characteristic[0].definitionByCombination is not None ): combo = build_criterion(ev.characteristic[0], is_root=True) - assert isinstance(combo, CriterionCombination) + assert isinstance(combo, logic.BooleanFunction) return combo # Otherwise, gather them under an ALL_OF (AND) combination by default: - combo = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ), - ) + children = [] + for c in ev.characteristic: - combo.add(build_criterion(c, is_root=True)) + 
children.append(build_criterion(c, is_root=True)) + + combo = logic.And(*children) return combo @@ -173,7 +330,7 @@ def parse_actions( self, actions_def: list[fhir.RecommendationPlan.Action], rec_plan: fhir.RecommendationPlan, - ) -> CriterionCombination: + ) -> logic.BooleanFunction: """ Parses the actions of a Recommendation (PlanDefinition) and returns a list of Action objects and the corresponding action selection behavior. @@ -182,19 +339,24 @@ def parse_actions( def action_to_combination( actions_def: list[fhir.RecommendationPlan.Action], parent: fhir.RecommendationPlan | fhir.RecommendationPlan.Action, - ) -> CriterionCombination: + ) -> logic.BooleanFunction: # loop through PlanDefinition.action elements and find the corresponding Action object (by action.code) - actions: list[Criterion | CriterionCombination] = [] + actions: list[logic.BaseExpr] = [] for action_def in actions_def: # check if is combination of actions if action_def.nested_actions: - # todo: make sure code is used correctly and we don't have a definition? 
action_combination = action_to_combination( action_def.nested_actions, action_def ) actions.append(action_combination) else: + + # first process a possible relativeTime extension in timingTiming + # if there is one, a corresponding logic.Presence (=logic.TemporalMinCount(*args, threshold=1)) + # expression is constructed and later wrapped around the actual criterion + interval_criteria = self.process_action_relative_time(action_def) + action_conv = self.action_converters.get(action_def) action_conv = cast( AbstractAction, action_conv @@ -205,23 +367,32 @@ def action_to_combination( goal = cast(Goal, goal) action_conv.goals.append(goal) - actions.append(action_conv.to_criterion()) + expr = action_conv.to_expression() + + if interval_criteria: + expr = wrap_criteria_with_temporal_indicator( + expr, logic.And(*interval_criteria) + ) - action_combination = self.parse_action_combination_method(parent.fhir()) + actions.append(expr) + + action_combination_expr = self.parse_action_combination_method( + parent.fhir() + ) for action_criterion in actions: - if not isinstance(action_criterion, (CriterionCombination, Criterion)): + if not isinstance( + action_criterion, (logic.Symbol, logic.BooleanFunction) + ): raise ValueError(f"Invalid action type: {type(action_criterion)}") - action_combination.add(action_criterion) - - return action_combination + return action_combination_expr(*actions) return action_to_combination(actions_def, rec_plan) def parse_action_combination_method( self, action_parent: PlanDefinition | PlanDefinitionAction - ) -> CriterionCombination: + ) -> Type[logic.BooleanFunction] | Callable: """ Get the correct action combination based on the action selection behavior. 
""" @@ -234,22 +405,22 @@ def parse_action_combination_method( selection_behavior = selection_behaviors[0] + expr: Type[logic.BooleanFunction] | Callable + match selection_behavior: case "any" | "one-or-more": - operator = LogicalCriterionCombination.Operator("OR") + expr = logic.Or case "all": - operator = LogicalCriterionCombination.Operator("AND") + expr = logic.And case "all-or-none": - operator = LogicalCriterionCombination.Operator("ALL_OR_NONE") + expr = logic.AllOrNone case "exactly-one": - operator = LogicalCriterionCombination.Operator("EXACTLY", threshold=1) + expr = lambda *args: logic.ExactCount(*args, threshold=1) case "at-most-one": - operator = LogicalCriterionCombination.Operator("AT_MOST", threshold=1) + expr = lambda *args: logic.MaxCount(*args, threshold=1) case _: raise NotImplementedError( f"Selection behavior {selection_behavior} not implemented." ) - return LogicalCriterionCombination( - operator=operator, - ) + return expr diff --git a/execution_engine/converter/parser/fhir_parser_v2.py b/execution_engine/converter/parser/fhir_parser_v2.py index 83d9c41c..86eb0b61 100644 --- a/execution_engine/converter/parser/fhir_parser_v2.py +++ b/execution_engine/converter/parser/fhir_parser_v2.py @@ -1,3 +1,5 @@ +from typing import Callable, Type + from fhir.resources.codeableconcept import CodeableConcept from fhir.resources.plandefinition import PlanDefinition, PlanDefinitionAction @@ -5,9 +7,7 @@ from execution_engine.converter.criterion import get_extension_by_url from execution_engine.converter.parser.fhir_parser_v1 import FhirRecommendationParserV1 from execution_engine.fhir.util import get_coding -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) +from execution_engine.util import logic as logic class FhirRecommendationParserV2(FhirRecommendationParserV1): @@ -20,7 +20,7 @@ class FhirRecommendationParserV2(FhirRecommendationParserV1): def parse_action_combination_method( self, action_parent: 
PlanDefinition | PlanDefinitionAction - ) -> LogicalCriterionCombination: + ) -> Type[logic.BooleanFunction] | Callable: """ Parses the action combination method from an extension to a PlanDefinition or PlanDefinitionAction. """ @@ -41,28 +41,22 @@ def parse_action_combination_method( except ValueError: threshold = None + expr: Type[logic.BooleanFunction] | Callable + match method_code: case "all": - operator = LogicalCriterionCombination.Operator("AND") + expr = logic.And case "any": - operator = LogicalCriterionCombination.Operator("OR") + expr = logic.Or case "at-most": - operator = LogicalCriterionCombination.Operator( - "AT_MOST", threshold=threshold - ) + expr = lambda *args: logic.MaxCount(*args, threshold=threshold) case "exactly": - operator = LogicalCriterionCombination.Operator( - "EXACTLY", threshold=threshold - ) + expr = lambda *args: logic.ExactCount(*args, threshold=threshold) case "at-least": - operator = LogicalCriterionCombination.Operator( - "AT_LEAST", threshold=threshold - ) + expr = lambda *args: logic.MinCount(*args, threshold=threshold) case "one-or-more": - operator = LogicalCriterionCombination.Operator("AT_LEAST", threshold=1) + expr = lambda *args: logic.MinCount(*args, threshold=1) case _: raise ValueError(f"Invalid action combination method: {method_code}") - return LogicalCriterionCombination( - operator=operator, - ) + return expr diff --git a/execution_engine/converter/parser/util.py b/execution_engine/converter/parser/util.py new file mode 100644 index 00000000..c9fa4113 --- /dev/null +++ b/execution_engine/converter/parser/util.py @@ -0,0 +1,91 @@ +import logging +from typing import Callable, cast + +from execution_engine.omop.criterion.abstract import Criterion +from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence +from execution_engine.omop.vocabulary import OMOP_SURGICAL_PROCEDURE +from execution_engine.util import logic +from execution_engine.util.temporal_logic_util import Presence + + +def 
_wrap_criteria_with_factory( + expr: logic.BaseExpr, + factory: Callable[[logic.BaseExpr], logic.TemporalCount], +) -> logic.Expr: + """ + Recursively wraps all Criterion instances within a combination using the specified factory. + + :param expr: A single Criterion or an expression to be processed. + :param factory: A callable that takes a Criterion or expression and returns a TemporalCount. + :return: A new TemporalCount where all Criterion instances have been wrapped using the factory. + :raises ValueError: If an unexpected element type is encountered. + """ + + new_expr: logic.Expr + + if isinstance(expr, Criterion): + new_expr = factory(expr) + elif isinstance(expr, logic.Expr): + + # Create a new combination of the same type with the same operator + args = [] + + interval_criterion = ( + expr.interval_criterion if hasattr(expr, "interval_criterion") else None + ) + + # Loop through all elements + for element in expr.args: + + if element == interval_criterion: + # interval_criterion must not be wrapped! 
+ args.append(element) + if isinstance(element, logic.Expr): + # Recursively wrap nested combinations + args.append(_wrap_criteria_with_factory(element, factory)) + elif isinstance(element, Criterion): + # Wrap individual criteria with the factory + + if ( + isinstance(element, ProcedureOccurrence) + and element.concept.concept_id == OMOP_SURGICAL_PROCEDURE + and element.concept.vocabulary_id == "SNOMED" + ): + logging.warning( + "Removing Surgical Procedure Criterion in TimeFromEvent-SurgicalOperationDate" + ) + continue + + args.append(factory(element)) + + else: + raise ValueError(f"Unexpected element type: {type(element)}") + + new_expr = expr.__class__(*args) + else: + raise ValueError(f"Unexpected element type: {type(expr)}") + + return new_expr + + +def wrap_criteria_with_temporal_indicator( + expr: logic.BaseExpr, + interval_criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Wraps all Criterion instances in a combination with a TemporalCount (with interval_criterion). + + :param expr: A single Criterion or an expression to be wrapped. + :param interval_criterion: A Criterion or expression that defines the temporal interval. + :return: A new expression where all Criterion instances are wrapped with a TemporalCount (with interval_criterion). 
+ """ + temporal_combo_factory = lambda criterion: Presence( + criterion=criterion, interval_criterion=interval_criterion + ) + + new_combo = cast( + logic.TemporalMinCount, + _wrap_criteria_with_factory(expr, temporal_combo_factory), + ) + + return new_combo diff --git a/execution_engine/converter/recommendation_factory.py b/execution_engine/converter/recommendation_factory.py index 15f95372..64f57978 100644 --- a/execution_engine/converter/recommendation_factory.py +++ b/execution_engine/converter/recommendation_factory.py @@ -1,10 +1,18 @@ from execution_engine import fhir from execution_engine.builder import ExecutionEngineBuilder +from execution_engine.converter.parser.base import FhirRecommendationParserInterface from execution_engine.converter.parser.factory import FhirRecommendationParserFactory from execution_engine.fhir.client import FHIRClient +from execution_engine.fhir.recommendation import ( + RecommendationPlan, + RecommendationPlanCollection, +) from execution_engine.omop import cohort -from execution_engine.omop.cohort import PopulationInterventionPair +from execution_engine.omop.cohort.population_intervention_pair import ( + PopulationInterventionPairExpr, +) from execution_engine.omop.criterion.visit_occurrence import PatientsActiveDuringPeriod +from execution_engine.util import logic as logic class FhirToRecommendationFactory: @@ -57,30 +65,12 @@ def parse_recommendation_from_url( fhir_connector=fhir_client, ) - pi_pairs: list[PopulationInterventionPair] = [] - - base_criterion = PatientsActiveDuringPeriod() - - for rec_plan in rec.plans(): - pi_pair = PopulationInterventionPair( - name=rec_plan.name, - url=rec_plan.url, - base_criterion=base_criterion, - ) - - # parse population and create criteria - population_criteria = parser.parse_characteristics(rec_plan.population) - pi_pair.set_population(population_criteria) - - # parse intervention and create criteria - actions = parser.parse_actions(rec_plan.actions, rec_plan) - 
pi_pair.add_intervention(actions) - - pi_pairs.append(pi_pair) + # Recursively build a single expression from the nested plans/collections: + expr = self._parse_collection(rec.plans(), parser) recommendation = cohort.Recommendation( - pi_pairs, - base_criterion=base_criterion, + expr=expr, + base_criterion=PatientsActiveDuringPeriod(), url=rec.url, name=rec.name, title=rec.title, @@ -90,3 +80,46 @@ def parse_recommendation_from_url( ) return recommendation + + def _parse_collection( + self, + plan_or_collection: RecommendationPlanCollection | RecommendationPlan, + parser: FhirRecommendationParserInterface, + ) -> logic.Expr: + """ + Recursively parse a single RecommendationPlan or a nested RecommendationPlanCollection + into a logic.Expr. If it's a collection, gather each sub-item's expression and combine + them using parser.parse_action_combination_method(...). + """ + if isinstance(plan_or_collection, RecommendationPlan): + # Parse a leaf plan: build a population + interventions expression (e.g. 
PopulationInterventionPairExpr) + population_criteria = parser.parse_characteristics( + plan_or_collection.population + ) + intervention_criteria = parser.parse_actions( + plan_or_collection.actions, plan_or_collection + ) + return PopulationInterventionPairExpr( + population_expr=population_criteria, + intervention_expr=intervention_criteria, + name=plan_or_collection.name, + url=plan_or_collection.url, + base_criterion=PatientsActiveDuringPeriod(), + ) + + elif isinstance(plan_or_collection, RecommendationPlanCollection): + # Recursively parse all sub-items + sub_exprs = [ + self._parse_collection(sub_item, parser) + for sub_item in plan_or_collection.plans + ] + + combination_op = parser.parse_action_combination_method( + plan_or_collection.fhir + ) + + # Combine all sub-expressions with the appropriate operator + return combination_op(*sub_exprs) + + else: + raise TypeError("Unknown plan_or_collection type.") diff --git a/execution_engine/omop/criterion/combination/__init__.py b/execution_engine/converter/relative_time/__init__.py similarity index 100% rename from execution_engine/omop/criterion/combination/__init__.py rename to execution_engine/converter/relative_time/__init__.py diff --git a/execution_engine/converter/relative_time/abstract.py b/execution_engine/converter/relative_time/abstract.py new file mode 100644 index 00000000..0a6c96d4 --- /dev/null +++ b/execution_engine/converter/relative_time/abstract.py @@ -0,0 +1,90 @@ +from abc import abstractmethod +from typing import cast + +from fhir.resources.extension import Extension + +from execution_engine.converter.criterion import parse_value +from execution_engine.converter.temporal_indicator import TemporalIndicator +from execution_engine.fhir.util import get_coding, get_extension +from execution_engine.omop.vocabulary import AbstractVocabulary +from execution_engine.util import logic +from execution_engine.util.value import ValueNumeric + + +class RelativeTime(TemporalIndicator): + """ + 
extension[relativeTime] in the context of CPG-on-EBM-on-FHIR. + """ + + _event_vocabulary: type[AbstractVocabulary] + _event_code: str + _value: ValueNumeric | None + + def __init__( + self, + value: ValueNumeric | None, + ) -> None: + """ + Initialize the drug administration action. + """ + super().__init__() + self._value = value + + @classmethod + def valid(cls, fhir: Extension) -> bool: + """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" + + assert isinstance( + fhir, Extension + ), f"Expected Extension type, got {fhir.__class__.__name__}" + + context = get_extension(fhir, "contextCode") + + if not context: + raise ValueError("Required relativeTime:contextCode not found") + + cc = get_coding(context.valueCodeableConcept) + + return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code + + @classmethod + def from_fhir(cls, fhir: Extension) -> "TemporalIndicator": + """ + Creates a new TemporalIndicator from a FHIR PlanDefinition. + """ + assert isinstance( + fhir, Extension + ), f"Expected Extension type, got {fhir.__class__.__name__}" + + value = None + + offset = get_extension(fhir, "offset") + + if offset: + value = cast(ValueNumeric, parse_value(offset.valueDuration)) + + return cls(value) + + @abstractmethod + def to_interval_criterion(self) -> logic.BaseExpr: + """ + Returns the criterion that returns the intervals during the enclosed criterion/combination is evaluated. + + This criterion comes from a "timeFromEvent" field in EvidenceVariable.characteristic and therefore + specifies some time window (a.k.a. interval) during which the actual characteristic is supposed to happen. + For example, the characteristic could be some kind of measurement to be performed, and the timeFromEvent + could be "post surgical". 
+ + The interval criterion returned by the TimeFromEvent class is later AND-combined (if there are more than one + timeFromEvent requirements defined - they are always supposed to be simultaneously fulfilled, i.e. AND-combined) + - and a logic.Presence TemporalIndicator is instantiated, with the AND-combination of interval criteria, and + each single criterion contained in the characteristic to which this timeFromEvent belongs is wrapped with that + logic.Presence( *args, interval_criterion=interval_criterion). + + + Note that it is not the (potential) combination of single criteria in characteristic (if characteristic + contains a definitionByCombination element) that is wrapped with logic.Presence, but the single criteria, + because e.g. AND-combining single measurements to be performed would likely result in no positive intervals, + because measurements are not performed simultaneously. + """ + raise NotImplementedError("must be implemented by class") diff --git a/execution_engine/converter/temporal.py b/execution_engine/converter/temporal.py index f06741a5..3a654cd0 100644 --- a/execution_engine/converter/temporal.py +++ b/execution_engine/converter/temporal.py @@ -3,7 +3,7 @@ from fhir.resources.element import Element -from execution_engine.converter.time_from_event.abstract import TemporalIndicator +from execution_engine.converter.temporal_indicator import TemporalIndicator class TemporalIndicatorConverterFactory: diff --git a/execution_engine/converter/temporal_indicator.py b/execution_engine/converter/temporal_indicator.py new file mode 100644 index 00000000..9b747786 --- /dev/null +++ b/execution_engine/converter/temporal_indicator.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod + +from fhir.resources.element import Element + +from execution_engine.util import logic + + +class TemporalIndicator(ABC): + """ + EvidenceVariable.characteristic.timeFromEvent in the context of CPG-on-EBM-on-FHIR. 
+ """ + + @classmethod + @abstractmethod + def valid(cls, fhir: Element) -> bool: + """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" + raise NotImplementedError("must be implemented by class") + + @classmethod + @abstractmethod + def from_fhir(cls, fhir: Element) -> "TemporalIndicator": + """ + Creates a new TemporalIndicator from a FHIR PlanDefinition. + """ + raise NotImplementedError("must be implemented by class") + + @abstractmethod + def to_interval_criterion(self) -> logic.BaseExpr: + """ + Returns the criterion that returns the intervals during the enclosed criterion/combination is evaluated. + + This criterion (in FHIR) specifies some time window (a.k.a. interval) during which the actual criterion or + combination of criteria are supposed to happen. For example, the criterion could be some kind of measurement + to be performed, and the temporal indicator could be "post surgical". + + The interval criterion returned by this class is later AND-combined (if there are more than one + temporal requirements defined - they are always supposed to be simultaneously fulfilled, i.e. AND-combined) + - and a logic.Presence TemporalIndicator is instantiated, with the AND-combination of interval criteria, and + each single criterion contained in the characteristic to which this timeFromEvent belongs is wrapped with that + logic.Presence( *args, interval_criterion=interval_criterion). + + Note that it is not the (potential) combination of single criteria that is wrapped with logic.Presence, but the + single criteria, because e.g. AND-combining single measurements to be performed would likely result in no + positive intervals, because measurements are not performed simultaneously. 
+ """ + raise NotImplementedError("must be implemented by class") diff --git a/execution_engine/converter/time_from_event/abstract.py b/execution_engine/converter/time_from_event/abstract.py index 566f4f7c..fcba5392 100644 --- a/execution_engine/converter/time_from_event/abstract.py +++ b/execution_engine/converter/time_from_event/abstract.py @@ -1,74 +1,38 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Callable, cast -from fhir.resources.element import Element from fhir.resources.evidencevariable import EvidenceVariableCharacteristicTimeFromEvent from execution_engine.converter.criterion import parse_value +from execution_engine.converter.temporal_indicator import TemporalIndicator from execution_engine.fhir.util import get_coding from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - TemporalIndicatorCombination, -) from execution_engine.omop.vocabulary import AbstractVocabulary +from execution_engine.util import logic from execution_engine.util.value import ValueNumeric def _wrap_criteria_with_factory( - combo: CriterionCombination, - factory: Callable[[Criterion | CriterionCombination], TemporalIndicatorCombination], -) -> CriterionCombination: + combo: logic.BooleanFunction, + factory: Callable[[logic.BaseExpr], logic.TemporalCount], +) -> logic.BooleanFunction: """ Recursively wraps all Criterion instances within a combination using the factory. 
""" - # Create a new combination of the same type with the same operator - new_combo = combo.__class__(operator=combo.operator) - + children = [] # Loop through all elements - for element in combo: - if isinstance(element, LogicalCriterionCombination): + for element in combo.args: + if isinstance(element, logic.BooleanFunction): # Recursively wrap nested combinations - new_combo.add(_wrap_criteria_with_factory(element, factory)) + children.append(_wrap_criteria_with_factory(element, factory)) elif isinstance(element, Criterion): # Wrap individual criteria with the factory - new_combo.add(factory(element)) + children.append(factory(element)) else: raise ValueError(f"Unexpected element type: {type(element)}") - return new_combo - - -class TemporalIndicator(ABC): - """ - EvidenceVariable.characteristic.timeFromEvent in the context of CPG-on-EBM-on-FHIR. - """ - - @classmethod - @abstractmethod - def from_fhir(cls, fhir: Element) -> "TemporalIndicator": - """ - Creates a new TemporalIndicator from a FHIR PlanDefinition. 
- """ - raise NotImplementedError("must be implemented by class") - - @classmethod - @abstractmethod - def valid(cls, fhir: Element) -> bool: - """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" - raise NotImplementedError("must be implemented by class") - - @abstractmethod - def to_temporal_combination( - self, combo: Criterion | CriterionCombination - ) -> CriterionCombination: - """ - Wraps Criterion/CriterionCombinaion with a TemporalIndicatorCombination - """ - raise NotImplementedError("must be implemented by class") + # Create a new combination of the same type with the same operator + return combo.__class__(children) class TimeFromEvent(TemporalIndicator): @@ -91,7 +55,21 @@ def __init__( self._value = value @classmethod - def from_fhir(cls, fhir: Element) -> "TemporalIndicator": + def valid(cls, fhir: EvidenceVariableCharacteristicTimeFromEvent) -> bool: + """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" + + assert isinstance( + fhir, EvidenceVariableCharacteristicTimeFromEvent + ), f"Expected timeFromEvent type, got {fhir.__class__.__name__}" + + cc = get_coding(fhir.eventCodeableConcept) + + return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code + + @classmethod + def from_fhir( + cls, fhir: EvidenceVariableCharacteristicTimeFromEvent + ) -> "TemporalIndicator": """ Creates a new TemporalIndicator from a FHIR PlanDefinition. """ @@ -99,42 +77,41 @@ def from_fhir(cls, fhir: Element) -> "TemporalIndicator": fhir, EvidenceVariableCharacteristicTimeFromEvent ), f"Expected timeFromEvent type, got {fhir.__class__.__name__}" - tfe: EvidenceVariableCharacteristicTimeFromEvent = fhir - value = None - if tfe.range is not None and tfe.quantity is not None: + if fhir.range is not None and fhir.quantity is not None: raise ValueError( "Must specify either Range or Quantity in characteristic.timeFromEvent, not both." 
) - if tfe.range is not None: - value = cast(ValueNumeric, parse_value(tfe.range)) + if fhir.range is not None: + value = cast(ValueNumeric, parse_value(fhir.range)) - if tfe.quantity is not None: - value = cast(ValueNumeric, parse_value(tfe.quantity)) + if fhir.quantity is not None: + value = cast(ValueNumeric, parse_value(fhir.quantity)) return cls(value) - @classmethod - def valid(cls, fhir: Element) -> bool: - """Checks if the given FHIR definition is a valid TemporalIndicator in the context of CPG-on-EBM-on-FHIR.""" - - assert isinstance( - fhir, EvidenceVariableCharacteristicTimeFromEvent - ), f"Expected timeFromEvent type, got {fhir.__class__.__name__}" + @abstractmethod + def to_interval_criterion(self) -> logic.BaseExpr: + """ + Returns the criterion that returns the intervals during the enclosed criterion/combination is evaluated. - tfe: EvidenceVariableCharacteristicTimeFromEvent = fhir + This criterion comes from a "timeFromEvent" field in EvidenceVariable.characteristic and therefore + specifies some time window (a.k.a. interval) during which the actual characteristic is supposed to happen. + For example, the characteristic could be some kind of measurement to be performed, and the timeFromEvent + could be "post surgical". - cc = get_coding(tfe.eventCodeableConcept) + The interval criterion returned by the TimeFromEvent class is later AND-combined (if there are more than one + timeFromEvent requirements defined - they are always supposed to be simultaneously fulfilled, i.e. AND-combined) + - and a logic.Presence TemporalIndicator is instantiated, with the AND-combination of interval criteria, and + each single criterion contained in the characteristic to which this timeFromEvent belongs is wrapped with that + logic.Presence( *args, interval_criterion=interval_criterion). 
- return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code - @abstractmethod - def to_temporal_combination( - self, combo: Criterion | CriterionCombination - ) -> CriterionCombination: - """ - Wraps Criterion/CriterionCombinaion with a TemporalIndicatorCombination + Note that it is not the (potential) combination of single criteria in characteristic (if characteristic + contains a definitionByCombination element) that is wrapped with logic.Presence, but the single criteria, + because e.g. AND-combining single measurements to be performed would likely result in no positive intervals, + because measurements are not performed simultaneously. """ raise NotImplementedError("must be implemented by class") diff --git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 5af0d781..9bef1e7f 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -14,11 +14,11 @@ FhirToRecommendationFactory, ) from execution_engine.omop import cohort -from execution_engine.omop.cohort import PopulationInterventionPair +from execution_engine.omop.cohort import PopulationInterventionPairExpr from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.db.celida import tables as result_db -from execution_engine.omop.serializable import Serializable from execution_engine.task import runner +from execution_engine.util.serializable import Serializable class ExecutionEngine: @@ -198,29 +198,7 @@ def load_recommendation_from_database( # we'll run registration again to set all ids recommendation.set_id(rec_db.recommendation_id) - assert ( - recommendation.base_criterion is not None - ), "Base Criterion must be set" - self.register_criterion(recommendation.base_criterion) - - for pi_pair in recommendation.population_intervention_pairs(): - self.register_population_intervention_pair( - pi_pair, rec_db.recommendation_id - ) - - for criterion in pi_pair.flatten(): - 
self.register_criterion(criterion) - - # All objects in the deserialized object graph must have an id. - assert recommendation.id is not None - assert recommendation.base_criterion is not None - assert recommendation.base_criterion.id is not None - - for pi_pair in recommendation.population_intervention_pairs(): - assert pi_pair.id is not None - - for criterion in pi_pair.flatten(): - assert criterion.id is not None + self.register_children(recommendation) return recommendation @@ -280,31 +258,7 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None result = con.execute(query) recommendation.set_id(result.fetchone().recommendation_id) - # Register all child objects. After that, the recommendation - # and all child objects have valid ids (either restored or - # fresh). - - if recommendation.base_criterion is None: - raise ValueError("Base Criterion must be set when storing recommendation") - - self.register_criterion(recommendation.base_criterion) - - for pi_pair in recommendation.population_intervention_pairs(): - self.register_population_intervention_pair( - pi_pair, recommendation_id=recommendation.id - ) - for criterion in pi_pair.flatten(): - self.register_criterion(criterion) - - assert recommendation.id is not None - assert recommendation.base_criterion is not None - assert recommendation.base_criterion.id is not None - - for pi_pair in recommendation.population_intervention_pairs(): - assert pi_pair.id is not None - - for criterion in pi_pair.flatten(): - assert criterion.id is not None + self.register_children(recommendation) # Update the recommendation in the database with the final # JSON representation and execution graph (now that @@ -328,8 +282,40 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None con.execute(update_query) + def register_children(self, recommendation: cohort.Recommendation) -> None: + """ + Registers all child objects of the recommendation in the result database. 
+ + :param recommendation: The Recommendation object. + :type recommendation: cohort.Recommendation + :raises ValueError: If the base criterion is not set. + :raises AssertionError: If the recommendation or any of its child objects do not have a valid ID. + """ + if recommendation.base_criterion is None: + raise ValueError("Base Criterion must be set when storing recommendation") + + self.register_criterion(recommendation.base_criterion) + + for pi_pair in recommendation.population_intervention_pairs(): + self.register_population_intervention_pair( + pi_pair, recommendation_id=recommendation.id + ) + + for criterion in recommendation.atoms(): + self.register_criterion(criterion) + + assert recommendation.id is not None + assert recommendation.base_criterion is not None + assert recommendation.base_criterion.id is not None + + for pi_pair in recommendation.population_intervention_pairs(): + assert pi_pair.id is not None + + for criterion in recommendation.atoms(): + assert criterion.id is not None + def register_population_intervention_pair( - self, pi_pair: PopulationInterventionPair, recommendation_id: int + self, pi_pair: PopulationInterventionPairExpr, recommendation_id: int ) -> None: """ Registers the Population/Intervention Pair in the result database. 
diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index c012a767..792a034c 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -1,20 +1,10 @@ -from typing import Any, Callable, Type +from typing import Any, cast import networkx as nx -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - FixedWindowTemporalIndicatorCombination, - PersonalWindowTemporalIndicatorCombination, - TemporalIndicatorCombination, -) class ExecutionGraph(nx.DiGraph): @@ -28,30 +18,6 @@ def __add__(self, other: "ExecutionGraph") -> "ExecutionGraph": """ return nx.compose(self, other) - @classmethod - def from_criterion_combination( - cls, - population: CriterionCombination, - intervention: CriterionCombination, - base_criterion: Criterion, - ) -> "ExecutionGraph": - """ - Create a graph from a population and intervention criterion combination. 
- """ - from execution_engine.omop.cohort import PopulationInterventionPair - - p = cls.combination_to_expression(population, CohortCategory.POPULATION) - i = cls.combination_to_expression(intervention, CohortCategory.INTERVENTION) - i_filtered = PopulationInterventionPair.filter_symbols(i, p) - - pi = logic.LeftDependentToggle( - p, - i_filtered, - category=CohortCategory.POPULATION_INTERVENTION, - ) - - return cls.from_expression(pi, base_criterion) - def is_sink_of_category( self, expr: logic.Expr, graph: "ExecutionGraph", category: CohortCategory ) -> bool: @@ -75,43 +41,71 @@ def is_sink(self, expr: logic.Expr) -> bool: @classmethod def from_expression( - cls, expr: logic.Expr, base_criterion: Criterion + cls, expr: logic.Expr, base_criterion: Criterion, category: CohortCategory ) -> "ExecutionGraph": """ Create a graph from a cohort query expression. """ - def expression_to_graph( + from execution_engine.omop.cohort import PopulationInterventionPairExpr + + expr_hash = hash(expr) + + graph = cls() + base_node = base_criterion + + graph.add_node( + base_node, + category=CohortCategory.BASE, + store_result=True, + ) + + def traverse( expr: logic.Expr, - graph: ExecutionGraph, parent: logic.Expr | None = None, - ) -> ExecutionGraph: - graph.add_node(expr, category=expr.category, store_result=False) + category: CohortCategory = category, + ) -> None: - if expr.is_Atom: - graph.nodes[expr]["store_result"] = True - graph.add_edge(base_node, expr) + graph.add_node(expr, category=category, store_result=False) if parent is not None: + assert expr in graph.nodes + assert parent in graph.nodes graph.add_edge(expr, parent) - for child in expr.args: - expression_to_graph(child, graph, expr) + if isinstance(expr, PopulationInterventionPairExpr): + # special case for PopulationInterventionPairExpr: + # we need explicitly set the category of the population and intervention nodes - return graph + p, i = expr.left, expr.right - graph = cls() - base_node = logic.Symbol( - 
criterion=base_criterion, - category=CohortCategory.BASE, - ) - graph.add_node( - base_node, - category=CohortCategory.BASE, - store_result=True, - ) + traverse(i, parent=expr, category=CohortCategory.INTERVENTION) + traverse(p, parent=expr, category=CohortCategory.POPULATION) + + # create a subgraph for the pair in order to determine the sink nodes (i.e. the nodes that have no + # outgoing edges) for POPULATION and POPULATION_INTERVENTION and mark them for storing their result + # in the database + subgraph = cast( + ExecutionGraph, graph.subgraph(nx.ancestors(graph, expr) | {expr}) + ) + subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr.id)) + + elif expr == base_node: + # don't need to do anything - only non-base criteria are connected to the base criterion, + # otherwise we get a cyclic graph + pass + elif expr.is_Atom: + assert expr in graph.nodes, "Node not found in graph" + graph.nodes[expr]["store_result"] = True + graph.add_edge(base_node, expr) + else: + for child in expr.args: + traverse(child, parent=expr, category=category) + + traverse(expr, category=category) - expression_to_graph(expr, graph=graph) + if hash(expr) != expr_hash: + raise ValueError("Expression has been modified during traversal") return graph @@ -149,7 +143,7 @@ def combine_from(cls, *graphs: "ExecutionGraph") -> "ExecutionGraph": return combined_graph - def sink_node(self, category: CohortCategory | None = None) -> logic.Expr: + def sink_nodes(self, category: CohortCategory | None = None) -> list[logic.Expr]: """ Get the sink node of the graph. 
@@ -165,12 +159,7 @@ def sink_node(self, category: CohortCategory | None = None) -> logic.Expr: if self.is_sink_of_category(node, self, category) ] - if len(sink_nodes) != 1: - raise ValueError( - f"There must be exactly one sink node, but there are {len(sink_nodes)}" - ) - - return sink_nodes[0] + return sink_nodes def to_cytoscape_dict(self) -> dict: """ @@ -187,7 +176,7 @@ def to_cytoscape_dict(self) -> dict: "id": id(node), "label": str(node), "class": ( - node.criterion.__class__.__name__ + node.__class__.__name__ if isinstance(node, logic.Symbol) else node.__class__.__name__ ), @@ -204,29 +193,31 @@ def to_cytoscape_dict(self) -> dict: self.nodes[node]["store_result"] ), # Convert to string if necessary "is_sink": self.is_sink(node), + "is_atom": node.is_Atom, "bind_params": self.nodes[node]["bind_params"], } } if isinstance(node, logic.Symbol): - node_data["data"]["criterion_id"] = node.criterion._id + assert isinstance( + node, Criterion + ), f"Expected Criterion, got {type(node)}" + + node_data["data"]["criterion_id"] = node.id def criterion_attr(attr: str) -> str | None: - if ( - hasattr(node.criterion, attr) - and getattr(node.criterion, attr) is not None - ): - return str(getattr(node.criterion, attr)) + if hasattr(node, attr) and getattr(node, attr) is not None: + return str(getattr(node, attr)) return None try: - if node.criterion.concept is not None: + if node.concept is not None: node_data["data"].update( { "concept": ( - node.criterion.concept.model_dump() - if node.criterion.concept is not None + node.concept.model_dump() + if node.concept is not None else None ), "value": criterion_attr("value"), @@ -251,7 +242,7 @@ def criterion_attr(attr: str) -> str | None: if self.nodes[node]["category"] == CohortCategory.BASE: node_data["data"]["base_criterion"] = str( - node.criterion + node ) # Ensure this is serializable nodes.append(node_data) @@ -295,7 +286,7 @@ def plot(self) -> None: }[self.nodes[node]["category"]] if node.is_Atom: - label = 
node.criterion.description() + label = node.description() else: label = node.__class__.__name__ symbols = { @@ -305,8 +296,6 @@ def plot(self) -> None: "LeftDependentToggle": "=>", "NonSimplifiableOr": "!|", "NonSimplifiableAnd": "!&", - "NoDataPreservingAnd": "NDP-&", - "NoDataPreservingOr": "NPD-|", "MinCount": "Min", "MaxCount": "Max", "ExactCount": "Exact", @@ -357,7 +346,7 @@ def set_predecessors_store( if hops_remaining <= 0: return for predecessor in graph.predecessors(expr): - if predecessor.category == expr.category: + if self.nodes[predecessor]["category"] == self.nodes[expr]["category"]: set_predecessors_store(predecessor, graph, hops_remaining - 1) if desired_category is not None: @@ -382,153 +371,6 @@ def set_predecessors_store( for sink_node in sink_nodes_of_desired_category: set_predecessors_store(sink_node, self, hops) - @staticmethod - def combination_to_expression( - comb: CriterionCombination, category: CohortCategory - ) -> logic.Expr: - """ - Convert the CriterionCombination into an expression of And, Not, Or objects (and possibly more operators). - - :param comb: The criterion combination. - :param category: The CohortCategory of the expression. - :return: The expression. - """ - - def conjunction_from_combination( - comb: CriterionCombination, - ) -> Type[logic.BooleanFunction] | Callable: - """ - Convert the criterion's operator into a logical conjunction (Not or And or Or) - """ - if isinstance(comb, LogicalCriterionCombination): - if comb.is_root(): - # This is a hack to make the root node a non-simplifiable And node - otherwise, using the - # logic.And, the root node would be simplified to the criterion if there is only one criterion. - # The problem is that we need a non-criterion sink node of the intervention and population in order - # to store the results to the database without the criterion_id (as the result of the whole - # intervention or population of this population/intervention pair). 
- - if ( - comb.operator.operator - != LogicalCriterionCombination.Operator.AND - ): - raise AssertionError( - f"Invalid operator {comb.operator} for root node. Expected AND." - ) - return logic.NonSimplifiableAnd - - # Handle non-commutative combinations. - if isinstance(comb, NonCommutativeLogicalCriterionCombination): - return logic.ConditionalFilter - - op = comb.operator.operator - - # Mapping of simple logical operators. - simple_ops = { - LogicalCriterionCombination.Operator.NOT: logic.Not, - LogicalCriterionCombination.Operator.AND: logic.And, - LogicalCriterionCombination.Operator.OR: logic.Or, - LogicalCriterionCombination.Operator.ALL_OR_NONE: logic.AllOrNone, - } - if op in simple_ops: - return simple_ops[op] - - # Mapping of count-based operators. - count_ops = { - LogicalCriterionCombination.Operator.AT_LEAST: logic.MinCount, - LogicalCriterionCombination.Operator.AT_MOST: logic.MaxCount, - LogicalCriterionCombination.Operator.EXACTLY: logic.ExactCount, - LogicalCriterionCombination.Operator.CAPPED_AT_LEAST: logic.CappedMinCount, - } - if op in count_ops: - if comb.operator.threshold is None: - raise ValueError( - f"Threshold must be set for operator {comb.operator.operator}" - ) - return lambda *args, category: count_ops[op]( - *args, threshold=comb.operator.threshold, category=category - ) - - raise NotImplementedError(f'Logical combination operator "{comb.operator}" not implemented') - - ################################################################################### - elif isinstance(comb, TemporalIndicatorCombination): - - interval_criterion: logic.BaseExpr | None = None - start_time = None - end_time = None - interval_type = None - - if isinstance(comb, PersonalWindowTemporalIndicatorCombination): - - if isinstance(comb.interval_criterion, CriterionCombination): - interval_criterion = _traverse(comb.interval_criterion) - elif isinstance(comb.interval_criterion, Criterion): - interval_criterion = logic.Symbol( - comb.interval_criterion, 
category=category - ) - else: - raise ValueError( - f"Invalid interval criterion type: {type(comb.interval_criterion)}" - ) - - elif isinstance(comb, FixedWindowTemporalIndicatorCombination): - start_time = comb.start_time - end_time = comb.end_time - interval_type = comb.interval_type - - # Ensure a threshold is set. - if comb.operator.threshold is None: - raise ValueError( - f"Threshold must be set for operator {comb.operator.operator}" - ) - - # Map the operator to the corresponding logic function. - op_map = { - TemporalIndicatorCombination.Operator.AT_LEAST: logic.TemporalMinCount, - TemporalIndicatorCombination.Operator.AT_MOST: logic.TemporalMaxCount, - TemporalIndicatorCombination.Operator.EXACTLY: logic.TemporalExactCount, - } - op_func = op_map.get(comb.operator.operator, None) - if op_func is None: - raise NotImplementedError( - f'Temporal combination operator "{str(comb.operator)}" not implemented' - ) - - return lambda *args, category: op_func( - *args, - threshold=comb.operator.threshold, - category=category, - start_time=start_time, - end_time=end_time, - interval_type=interval_type, - interval_criterion=interval_criterion, - ) - - else: - raise ValueError(f"Invalid combination type: {type(comb)}") - - def _traverse(comb: CriterionCombination) -> logic.Expr: - """ - Traverse the criterion combination and creates a collection of logical conjunctions from it. - """ - conjunction = conjunction_from_combination(comb) - components: list[logic.Expr | logic.Symbol] = [] - - for entry in comb: - if isinstance(entry, CriterionCombination): - components.append(_traverse(entry)) - elif isinstance(entry, Criterion): - components.append(logic.Symbol(entry, category=category)) - else: - raise ValueError(f"Invalid entry type: {type(entry)}") - - return conjunction(*components, category=category) - - expression = _traverse(comb) - - return expression - def __eq__(self, other: Any) -> bool: """ Check if two graphs are equal. 
diff --git a/execution_engine/fhir/recommendation.py b/execution_engine/fhir/recommendation.py index 98f6d650..a45e9da7 100644 --- a/execution_engine/fhir/recommendation.py +++ b/execution_engine/fhir/recommendation.py @@ -1,4 +1,5 @@ import logging +from typing import Tuple, Union, cast import fhir from fhir.resources.activitydefinition import ActivityDefinition @@ -23,7 +24,9 @@ def __init__( self.fhir = fhir_connector self._recommendation: PlanDefinition | None = None - self._recommendation_plans: list[RecommendationPlan] = [] + self._recommendation_plans: RecommendationPlanCollection = ( + RecommendationPlanCollection(fhir=None) + ) self.load(url) @@ -96,7 +99,7 @@ def description(self) -> str: return self._recommendation.description - def plans(self) -> list["RecommendationPlan"]: + def plans(self) -> "RecommendationPlanCollection": """ Return the recommendation plans of this recommendation (i.e. the individual, non-overlapping parts or steps of the recommendation). @@ -107,19 +110,86 @@ def load(self, url: str) -> None: """ Load the recommendation from the FHIR server. """ - self._recommendation = self.fetch_recommendation(url) + self._recommendation, self._recommendation_plans = self.fetch_recommendation( + url + ) + logging.info("Recommendation loaded.") + + def _build_population_intervention_pair_collection( + self, plan_def: PlanDefinition + ) -> "RecommendationPlanCollection": + """ + Build and return a RecommendationPlanCollection for the given PlanDefinition. + + If the PlanDefinition is of type "eca-rule", we create a single RecommendationPlan + and add it as the sole entry in the collection. If it is of type "workflow-definition", + we recursively process any sub-actions referencing other PlanDefinitions and add + their collections/plans as entries. + + :param plan_def: The loaded PlanDefinition resource. + :return: A RecommendationPlanCollection containing one or more RecommendationPlans (or nested collections). 
+ """ + + cc = get_coding(plan_def.type, CS_PLAN_DEFINITION_TYPE) + + collection = RecommendationPlanCollection(fhir=plan_def) - for rec_action in self._recommendation.action: - rec_plan = RecommendationPlan( - rec_action.definitionCanonical, + if cc.code == "eca-rule": + plan = RecommendationPlan( + plan_def.url, package_version=self._package_version, fhir_connector=self.fhir, ) - self._recommendation_plans.append(rec_plan) + collection.add_plan(plan) + + elif cc.code == "workflow-definition": + # Recursively build sub-items for each referenced PlanDefinition + if plan_def.action: + for sub_action in plan_def.action: + if sub_action.definitionCanonical: + sub_plan_def = cast( + PlanDefinition, + self.fhir.fetch_resource( + "PlanDefinition", + sub_action.definitionCanonical, + self._package_version, + ), + ) + sub_item = self._build_population_intervention_pair_collection( + sub_plan_def + ) + collection.add_plan(sub_item) + else: + raise ValueError(f"Unknown PlanDefinition type: {cc.code}") + + return collection + + def fetch_recommendation( + self, canonical_url: str + ) -> Tuple[PlanDefinition, "RecommendationPlanCollection"]: + """ + Fetches the PlanDefinition, then checks if it's an 'eca-rule' or 'workflow-definition'. + If it's 'eca-rule', this is effectively a single RecommendationPlan; + if it's 'workflow-definition', it's a RecommendationPlanCollection. + + (Only changes shown below: how you might branch to a new or refactored _build_plan_or_collection.) 
+ """ + plan_def = cast( + PlanDefinition, + self.fhir.fetch_resource( + "PlanDefinition", canonical_url, self._package_version + ), + ) + cc = get_coding(plan_def.type, CS_PLAN_DEFINITION_TYPE) - logging.info("Recommendation loaded.") + plan_collection = self._build_population_intervention_pair_collection(plan_def) + + if cc.code not in ("eca-rule", "workflow-definition"): + raise ValueError(f"Unknown recommendation type: {cc.code}") - def fetch_recommendation(self, canonical_url: str) -> PlanDefinition: + return plan_def, plan_collection + + def fetch_recommendation_DELETEME(self, canonical_url: str) -> PlanDefinition: """ Fetch the recommendation specified by the canonical URL from the FHIR server. @@ -286,8 +356,8 @@ def fetch_recommendation_plan(self, canonical_url: str) -> PlanDefinition: This method checks if the PlanDefinition resource referenced by the canonical URL is a recommendation (i.e. PlanDefinition.type = #workflow-definition). If - it is an recommendation-plan instead (i.e. PlanDefinition.type = #eca-rule), - it will fetch the PlanDefinition that is referenced by the extension[partOf]. + it is a recommendation-plan instead (i.e. PlanDefinition.type = #eca-rule), + it will recursively fetch the PlanDefinitions. :param canonical_url: Canonical URL of the recommendation :return: FHIR PlanDefinition @@ -364,3 +434,50 @@ def is_combination_definition( characteristics defined using EvidenceVariableCharacteristic.definitionByCombination. """ return characteristic.definitionByCombination is not None + + +class RecommendationPlanCollection: + """ + Represents a collection of RecommendationPlan objects or nested RecommendationPlanCollection objects. + + This class is used to aggregate multiple recommendation plans into a single structure, + enabling the recursive composition of recommendation workflows. 
+ """ + + def __init__(self, fhir: PlanDefinition | None) -> None: + """ + Create a collection of RecommendationPlans or nested RecommendationPlanCollection objects. + + :param combination_method: A string indicating how these plans should be combined + (e.g. "AND", "OR", "SEQUENCE"). + """ + self._fhir = fhir + self._plans: list[RecommendationPlan | RecommendationPlanCollection] = [] + + @property + def fhir(self) -> PlanDefinition | None: + """ + Retrieve the FHIR PlanDefinition resource associated with this collection. + + :return: The FHIR PlanDefinition if set; otherwise, None. + """ + return self._fhir + + @property + def plans(self) -> list["RecommendationPlan | RecommendationPlanCollection"]: + """ + Retrieve the list of recommendation plans or nested recommendation plan collections. + + :return: A list containing RecommendationPlan objects or nested RecommendationPlanCollection objects. + """ + return self._plans + + def add_plan( + self, plan: Union[RecommendationPlan, "RecommendationPlanCollection"] + ) -> None: + """ + Add a recommendation plan or a nested recommendation plan collection to this collection. + + :param plan: The RecommendationPlan or RecommendationPlanCollection to be added. + """ + self._plans.append(plan) diff --git a/execution_engine/fhir/terminology.py b/execution_engine/fhir/terminology.py index 1f3d86a9..180d5394 100644 --- a/execution_engine/fhir/terminology.py +++ b/execution_engine/fhir/terminology.py @@ -1,3 +1,6 @@ +import logging +from json import JSONDecodeError + import requests # fixme use caching? 
@@ -121,7 +124,9 @@ def code_in_valueset( :param system: The coding system the code belongs to :return: True if the code is valid within the ValueSet, False otherwise """ - # Prepare the request data + + logging.debug(f"Validating code {code} from system {system} in ValueSet") + data = { "resourceType": "Parameters", "parameter": [ @@ -145,15 +150,23 @@ def code_in_valueset( timeout=30, ) - json_response = response.json() - if response.status_code == 200: - # Look for a parameter named 'result' and check if its value is True + + json_response = response.json() + for param in json_response.get("parameter", []): if param.get("name") == "result": return param.get("valueBoolean", False) return False else: + + try: + json_response = response.json() + except JSONDecodeError: + raise FHIRTerminologyServerException( + f"Error validating code: HTTP Status {response.status_code}" + ) + issues = [ issue["severity"] + ": " + issue["details"]["text"] for issue in json_response["issue"] diff --git a/execution_engine/fhir/util.py b/execution_engine/fhir/util.py index d1a32981..0c9699cb 100644 --- a/execution_engine/fhir/util.py +++ b/execution_engine/fhir/util.py @@ -1,6 +1,7 @@ from fhir.resources.codeableconcept import CodeableConcept from fhir.resources.coding import Coding from fhir.resources.element import Element +from fhir.resources.extension import Extension def get_coding(cc: CodeableConcept, system_uri: str | None = None) -> Coding: @@ -20,15 +21,71 @@ def get_coding(cc: CodeableConcept, system_uri: str | None = None) -> Coding: return coding[0] -def get_extension(base: Element, extension_url: str) -> Element | None: +def get_extensions(base: Element, extension_url: str) -> list[Extension]: """ - Get the extension with the given URL from the given element. + Retrieves all extensions with the given URL from the specified element. + + This function returns a list of extensions that match the given URL, + allowing for multiple occurrences. 
+ + Args: + base (Element): The element containing extensions. + extension_url (str): The URL of the extensions to retrieve. + + Returns: + list[Extension]: A list of matching extensions (empty if none are found). """ if base.extension is None: - return None + return [] + + return [ext for ext in base.extension if ext.url == extension_url] + + +def pop_extensions(base: Element, extension_url: str) -> list[Extension]: + """ + Retrieves and removes all extensions with the given URL from the specified element, + returning them as a list in the order they appeared. + + Args: + base (Element): The element containing extensions. + extension_url (str): The URL of the extensions to retrieve and remove. + + Returns: + list[Extension]: A list of matching extensions (empty if none are found). + """ + if not base.extension: + return [] + + matches = [ext for ext in base.extension if ext.url == extension_url] + keepers = [ext for ext in base.extension if ext.url != extension_url] + + base.extension = keepers - for ext in base.extension: - if ext.url == extension_url: - return ext + return matches + + +def get_extension(base: Element, extension_url: str) -> Extension | None: + """ + Retrieves a single extension with the given URL from the specified element. + + If multiple extensions with the same URL exist, an error is raised to ensure + that only unique extensions are retrieved. + + Args: + base (Element): The element containing extensions. + extension_url (str): The URL of the extension to retrieve. + + Returns: + Extension | None: The matching extension if found, otherwise None. + + Raises: + ValueError: If multiple extensions with the same URL exist. + """ + matching_extensions = get_extensions(base, extension_url) + + if len(matching_extensions) > 1: + raise ValueError( + f"Multiple extensions found with URL '{extension_url}', but only one was expected." 
+ ) - return None + return matching_extensions[0] if matching_extensions else None diff --git a/execution_engine/omop/cohort/__init__.py b/execution_engine/omop/cohort/__init__.py index 81706fdd..bc7fcb5a 100644 --- a/execution_engine/omop/cohort/__init__.py +++ b/execution_engine/omop/cohort/__init__.py @@ -1,2 +1,2 @@ -from .population_intervention_pair import PopulationInterventionPair +from .population_intervention_pair import PopulationInterventionPairExpr from .recommendation import Recommendation diff --git a/execution_engine/omop/cohort/graph_builder.py b/execution_engine/omop/cohort/graph_builder.py new file mode 100644 index 00000000..9329bd26 --- /dev/null +++ b/execution_engine/omop/cohort/graph_builder.py @@ -0,0 +1,151 @@ +import copy + +import networkx as nx + +import execution_engine.util.logic as logic +from execution_engine.constants import CohortCategory +from execution_engine.execution_graph import ExecutionGraph +from execution_engine.omop.cohort.population_intervention_pair import ( + PopulationInterventionPairExpr, +) +from execution_engine.omop.criterion.abstract import Criterion + + +class RecommendationGraphBuilder: + """ + A builder class for constructing ExecutionGraph objects based on + population/intervention expressions. It provides utility methods to filter + intervention criteria by population constraints, and then converts the + filtered expression into an ExecutionGraph ready for execution and storage. + """ + + @classmethod + def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: + """ + Filter (=AND-combine) all symbols by the applied filter function + + Used to filter all intervention criteria (symbols) by the population output in order to exclude + all intervention events outside the population intervals, which may otherwise interfere with corrected + determination of temporal combination, i.e. the presence of an intervention event during some time window. 
+ + :param node: The expression node to be filtered. + :type node: logic.Expr + :param filter_: The filter expression to AND-combine with symbols in the node. + :type filter_: logic.Expr + :return: A new expression in which all symbols are constrained by the filter expression. + :rtype: logic.Expr + """ + + node = copy.copy(node) + + if isinstance(node, logic.Symbol): + return logic.LeftDependentToggle(left=filter_, right=node) + elif isinstance(node, logic.Expr): + if hasattr(node, "interval_criterion") and node.interval_criterion: + # we must not wrap the interval_criterion + interval_criterion = node.interval_criterion + converted_args = [ + cls.filter_symbols(a, filter_) + for a in node.args + if not a == interval_criterion + ] + [interval_criterion] + else: + converted_args = [cls.filter_symbols(a, filter_) for a in node.args] + + if any(a is not b for a, b in zip(node.args, converted_args)): + node.update_args(*converted_args) + + return node + + @classmethod + def filter_intervention_criteria_by_population(cls, expr: logic.Expr) -> logic.Expr: + """ + Filter all intervention criteria in a given expression by the population part of the expression. + + :param expr: The expression that may contain population and intervention parts. + :type expr: logic.Expr + :return: A new expression where all intervention symbols are constrained by the population intervals. + :rtype: logic.Expr + """ + + from execution_engine.omop.cohort import PopulationInterventionPairExpr + + # we might make changes to the expression (e.g. filtering), so we must preserve + # the original expression from the caller + expr = copy.deepcopy(expr) + + def traverse( + expr: logic.Expr, + ) -> None: + if isinstance(expr, PopulationInterventionPairExpr): + p, i = expr.left, expr.right + + # filter all intervention criteria by the output of the population - this is performed to filter out + # intervention events that outside of the population intervals (i.e. 
the time windows during which + # patients are part of the population) as otherwise events outside of the population time may be picked up + # by Temporal criteria that determine the presence of some event or condition during a specific time window. + i = cls.filter_symbols(i, filter_=p) + + expr.update_args(p, i) + + traverse(i) + traverse(p) + + elif not expr.is_Atom: + for child in expr.args: + traverse(child) + + traverse(expr) + + # we need to rehash as the structure has been changed due to insertion of additional nodes + expr.rehash(recursive=True) + + return expr + + @classmethod + def build(cls, expr: logic.Expr, base_criterion: Criterion) -> ExecutionGraph: + """ + Build an ExecutionGraph for a population/intervention expression. + + If the expression is a PopulationInterventionPairExpr, it is wrapped in a + NonSimplifiableAnd to ensure a top-level result entry is generated in the database. + Then the expression is filtered and converted into an ExecutionGraph with the + appropriate sink nodes and edges. + + :param expr: The population/intervention expression to build the graph from. + :type expr: logic.Expr + :param base_criterion: The base criterion used to label the execution graph. + :type base_criterion: Criterion + :return: The constructed ExecutionGraph for the given expression. 
+ :rtype: ExecutionGraph + """ + if isinstance(expr, PopulationInterventionPairExpr): + expr = logic.NonSimplifiableAnd(expr) + + # Make sure the expr is filtered + expr_filtered = cls.filter_intervention_criteria_by_population(expr) + + graph = ExecutionGraph.from_expression( + expr_filtered, + base_criterion=base_criterion, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + p_sink_nodes = graph.sink_nodes(CohortCategory.POPULATION) + graph.set_sink_nodes_store( + bind_params={}, desired_category=CohortCategory.POPULATION_INTERVENTION + ) + + p_combination_node = logic.NonSimplifiableOr(*p_sink_nodes) + graph.add_node( + p_combination_node, store_result=True, category=CohortCategory.POPULATION + ) + graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes) + + if graph.in_degree(base_criterion) != 0: + raise AssertionError("Base criterion must not have incoming edges") + + if not nx.is_directed_acyclic_graph(graph): + raise AssertionError("Graph is not acyclic") + + return graph diff --git a/execution_engine/omop/cohort/population_intervention_pair.py b/execution_engine/omop/cohort/population_intervention_pair.py index cab375f0..457ec6f5 100644 --- a/execution_engine/omop/cohort/population_intervention_pair.py +++ b/execution_engine/omop/cohort/population_intervention_pair.py @@ -1,32 +1,14 @@ -from typing import Any, Dict, cast +from typing import Callable, cast -from sqlalchemy.sql import ( - Alias, - CompoundSelect, - Insert, - Join, - Select, - Subquery, - TableClause, -) -from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList -from sqlalchemy.sql.selectable import CTE - -import execution_engine.util.cohort_logic as logic -from execution_engine.constants import CohortCategory -from execution_engine.execution_graph import ExecutionGraph +import execution_engine.util.logic as logic from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import 
CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) -from execution_engine.omop.criterion.factory import criterion_factory -from execution_engine.omop.serializable import Serializable -from execution_engine.util.sql import SelectInto -class PopulationInterventionPair(Serializable): +class PopulationInterventionPairExpr(logic.LeftDependentToggle): """ + A logical expression that ties together a population expression and an intervention expression, + plus any extra info like name, url, base_criterion, etc. + A population/intervention pair in OMOP as a collection of separate criteria. A population/intervention pair represents an individual recommendation plan (i.e. one part of a single recommendation), @@ -39,293 +21,90 @@ class PopulationInterventionPair(Serializable): """ _name: str - _population: CriterionCombination - _intervention: CriterionCombination - - def __init__( - self, + _url: str + _base_criterion: Criterion + + def __new__( + cls, + population_expr: logic.BaseExpr, + intervention_expr: logic.BaseExpr, + *, name: str, url: str, base_criterion: Criterion, - population: CriterionCombination | None = None, - intervention: CriterionCombination | None = None, - ) -> None: - self._name = name - self._url = url - self._base_criterion = base_criterion - - self.set_criteria(CohortCategory.POPULATION, population) - self.set_criteria(CohortCategory.INTERVENTION, intervention) - - def __repr__(self) -> str: + **kwargs: dict, + ) -> "PopulationInterventionPairExpr": """ - Get the string representation of the population/intervention pair. + Create a new PopulationInterventionPairExpr object. 
""" - return ( - f"{self.__class__.__name__}(\n" - f" name={repr(self._name)},\n" - f" url={repr(self._url)},\n" - f" base_criterion={repr(self._base_criterion)},\n" - f" population={self._population._repr_pretty(level=1).strip()},\n" - f" intervention={self._intervention._repr_pretty(level=1).strip()}\n" - f")" - ) - - def set_criteria( - self, category: CohortCategory, criteria: CriterionCombination | None - ) -> None: - """ - Set the criteria (either population or intervention) of the population/intervention pair. - - :param category: The category of the criteria. - :param criteria: The criteria. - """ - - root_combination = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator("AND"), - root_combination=True, + self = cast( + PopulationInterventionPairExpr, + super().__new__( + cls, left=population_expr, right=intervention_expr, **kwargs + ), ) - if criteria is not None: - root_combination.add(criteria) + self._name = name + self._url = url + self._base_criterion = base_criterion - if category == CohortCategory.POPULATION: - self._population = root_combination - elif category == CohortCategory.INTERVENTION: - self._intervention = root_combination - else: - raise ValueError(f"Invalid category {category}") + return self @property def name(self) -> str: """ - Get the name of the population/intervention pair. + The name of the population/intervention pair. """ return self._name @property def url(self) -> str: """ - Get the url of the population/intervention pair. + The URL of the population/intervention pair. """ return self._url - @classmethod - def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: - """ - Filter (=AND-combine) all symbols by the applied filter function - - Used to filter all intervention criteria (symbols) by the population output in order to exclude - all intervention events outside the population intervals, which may otherwise interfere with corrected - determination of temporal combination, i.e. 
the presence of an intervention event during some time window. - """ - - if isinstance(node, logic.Symbol): - return logic.LeftDependentToggle( - left=filter_, right=node, category=CohortCategory.INTERVENTION - ) - - if hasattr(node, "args") and isinstance(node.args, tuple): - converted_args = [cls.filter_symbols(a, filter_) for a in node.args] - - if any(a is not b for a, b in zip(node.args, converted_args)): - node.args = tuple(converted_args) - - return node - - def execution_graph(self) -> ExecutionGraph: - """ - Get the execution graph for the population/intervention pair. - """ - - p = ExecutionGraph.combination_to_expression( - self._population, category=CohortCategory.POPULATION - ) - i = ExecutionGraph.combination_to_expression( - self._intervention, category=CohortCategory.INTERVENTION - ) - - # filter all intervention criteria by the output of the population - this is performed to filter out - # intervention events that outside of the population intervals (i.e. the time windows during which - # patients are part of the population) as otherwise events outside of the population time may be picked up - # by Temporal criteria that determine the presence of some event or condition during a specific time window. - i = self.filter_symbols(i, filter_=p) - - pi = logic.LeftDependentToggle( - p, - i, - category=CohortCategory.POPULATION_INTERVENTION, - ) - pi_graph = ExecutionGraph.from_expression(pi, self._base_criterion) - - if self._id is None: - raise ValueError("Population/intervention pair ID not set") - - # todo: should we supply self instead of self._id? - pi_graph.set_sink_nodes_store(bind_params=dict(pi_pair_id=self._id)) - - return pi_graph - - def set_population(self, combination: CriterionCombination) -> None: - """ - Set the population criteria. 
- """ - combination.set_root() - self._population = combination - - def add_population(self, criterion: Criterion | CriterionCombination) -> None: - """ - Add a criterion to the population of the population/intervention pair. - """ - self._population.add(criterion) - - def add_intervention(self, criterion: Criterion | CriterionCombination) -> None: - """ - Add a criterion to the intervention of the population/intervention pair. - """ - self._intervention.add(criterion) - - def criteria(self) -> CriterionCombination: + @property + def base_criterion(self) -> Criterion: """ - Get the criteria of the population/intervention pair. + The base criterion of the population/intervention pair. """ - """return self._criteria""" - raise NotImplementedError() + return self._base_criterion - def flatten(self) -> list[Criterion]: + def __reduce__(self) -> tuple[Callable, tuple]: """ - Retrieve all criteria in a flat list (i.e. no nested criterion combinations). + Reduce the expression to its arguments and category. - Includes the base criterion, population and intervention. + Required for pickling (e.g. when using multiprocessing). - :return: A list of individual criteria. + :return: Tuple of the class, arguments, and category. """ - - def _traverse(comb: CriterionCombination) -> list[Criterion]: - criteria = [] - for element in comb: - if isinstance(element, CriterionCombination): - criteria += _traverse(element) - else: - criteria.append(element) - return criteria - return ( - [self._base_criterion] - + _traverse(self._population) - + _traverse(self._intervention) + self._recreate, + ( + self.args, + { + "name": self.name, + "url": self.url, + "base_criterion": self.base_criterion, + } + | {"_id": self._id}, + ), ) - @staticmethod - def _assert_base_table_in_select( - sql: CompoundSelect | Select | SelectInto, base_table_out: str - ) -> None: - """ - Assert that the base table is used in the select statement. 
- - Joining the base table ensures that always just a subset of patients is selected, - not all. - """ - if isinstance(sql, SelectInto) or isinstance(sql, Insert): - sql = sql.select - - def _base_table_in_select(sql_select: Join | Select | Alias) -> bool: - """ - Check if the base table is used in the select statement. - """ - if isinstance(sql_select, Join): - return _base_table_in_select(sql_select.left) or _base_table_in_select( - sql_select.right - ) - elif isinstance(sql_select, Select): - return any( - _base_table_in_select(f) for f in sql_select.get_final_froms() - ) or any(_base_table_in_select(w) for w in sql_select.whereclause) - elif isinstance(sql_select, Alias): - return sql_select.original.name == base_table_out - elif isinstance(sql_select, TableClause): - return sql_select.name == base_table_out - elif isinstance(sql_select, CTE): - return _base_table_in_select(sql_select.original) - elif isinstance(sql_select, BooleanClauseList): - return any( - w.right.element.froms[0].name == base_table_out for w in sql_select - ) - elif isinstance(sql_select, BinaryExpression): - return ( - sql_select.right.element.get_final_froms()[0].name == base_table_out - ) - elif isinstance(sql_select, Subquery): - if isinstance(sql_select.original, CompoundSelect): - return all( - _base_table_in_select(s) for s in sql_select.original.selects - ) - else: - return any( - _base_table_in_select(f) - for f in sql_select.original.get_final_froms() - ) - else: - raise NotImplementedError(f"Unknown type {type(sql_select)}") - - if isinstance(sql, CompoundSelect): - assert all( - _base_table_in_select(s) for s in sql.selects - ), "Base table not used in all selects of compound select" - elif isinstance(sql, Select): - assert _base_table_in_select( - sql - ), f"Base table {base_table_out} not found in select" - else: - raise NotImplementedError(f"Unknown type {type(sql)}") - - def dict(self) -> dict[str, Any]: - """ - Get a dictionary representation of the population/intervention 
pair. - """ - base_criterion = self._base_criterion - population = self._population - intervention = self._intervention - return { - "name": self.name, - "url": self.url, - "base_criterion": { - "class_name": base_criterion.__class__.__name__, - "data": base_criterion.dict(), - }, - "population": { - "class_name": population.__class__.__name__, - "data": population.dict(), - }, - "intervention": { - "class_name": intervention.__class__.__name__, - "data": intervention.dict(), - }, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PopulationInterventionPair": - """ - Create a population/intervention pair from a dictionary. - """ - - base_criterion = criterion_factory(**data["base_criterion"]) - assert isinstance( - base_criterion, Criterion - ), "Base criterion must be a Criterion" - population = cast(CriterionCombination, criterion_factory(**data["population"])) - intervention = cast( - CriterionCombination, criterion_factory(**data["intervention"]) - ) - object = cls( - name=data["name"], - url=data["url"], - base_criterion=base_criterion, + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + del data["data"]["left"] + del data["data"]["right"] + data["data"].update( + { + "population_expr": self.left.dict(include_id=include_id), + "intervention_expr": self.right.dict(include_id=include_id), + "name": self.name, + "url": self.url, + "base_criterion": self.base_criterion.dict(include_id=include_id), + } ) - # The constructor initializes the population and intervention - # slots in a particular way, but we want to use whatever we - # have deserialized instead. This is a bit inefficient because - # we discard the values that were assigned to the two slots in - # the constructor. 
- object._population = population - object._intervention = intervention - return object + return data diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index 4f4b88fc..695805b6 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -1,8 +1,6 @@ -import itertools import re -from typing import Any, Dict, Iterator, Self +from typing import Iterator -import networkx as nx from sqlalchemy import ( Column, Date, @@ -15,23 +13,19 @@ select, ) -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.execution_graph import ExecutionGraph -from execution_engine.omop import cohort - -# ) -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, +from execution_engine.omop.cohort.graph_builder import RecommendationGraphBuilder +from execution_engine.omop.cohort.population_intervention_pair import ( + PopulationInterventionPairExpr, ) -from execution_engine.omop.criterion.factory import criterion_factory +from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.db.celida.tables import ResultInterval -from execution_engine.omop.serializable import Serializable +from execution_engine.util.serializable import SerializableDataClass -class Recommendation(Serializable): +class Recommendation(SerializableDataClass): """ A recommendation OMOP as a collection of separate population/intervention pairs. 
@@ -44,7 +38,7 @@ class Recommendation(Serializable): def __init__( self, - pi_pairs: list[cohort.PopulationInterventionPair], + expr: logic.Expr, base_criterion: Criterion, name: str, title: str, @@ -53,7 +47,7 @@ def __init__( description: str, package_version: str | None = None, ) -> None: - self._pi_pairs: list[cohort.PopulationInterventionPair] = pi_pairs + self._expr: logic.Expr = expr self._base_criterion: Criterion = base_criterion self._name: str = name self._title: str = title @@ -66,15 +60,9 @@ def __repr__(self) -> str: """ Get the string representation of the recommendation. """ - pi_repr = "\n".join( - [(" " + line) for line in repr(self._pi_pairs).split("\n")] - ).strip() - pi_repr = ( - pi_repr[0] + "\n " + pi_repr[1:-2] + pi_repr[-2] + "\n " + pi_repr[-1] - ) return ( f"{self.__class__.__name__}(\n" - f" pi_pairs={pi_repr},\n" + f" expr={repr(self._expr)},\n" f" base_criterion={repr(self._base_criterion)},\n" f" name={repr(self._name)},\n" f" title={repr(self._title)},\n" @@ -111,7 +99,7 @@ def version(self) -> str: """ Get the version of the recommendation. """ - return self._version # + return self._version @property def package_version(self) -> str | None: @@ -142,68 +130,28 @@ def execution_graph(self) -> ExecutionGraph: execution maps of the individual population/intervention pairs of the recommendation. 
""" - p_nodes = [] - pi_nodes = [] - pi_graphs = [] - - for pi_pair in self._pi_pairs: - pi_graph = pi_pair.execution_graph() - - p_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION)) - pi_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION_INTERVENTION)) - pi_graphs.append(pi_graph) - - p_combination_node = logic.NoDataPreservingOr( - *p_nodes, category=CohortCategory.POPULATION - ) - pi_combination_node = logic.NoDataPreservingAnd( - *pi_nodes, category=CohortCategory.POPULATION_INTERVENTION - ) + return RecommendationGraphBuilder.build(self._expr, self._base_criterion) - common_graph = nx.compose_all(pi_graphs) - - common_graph.add_node( - p_combination_node, store_result=True, category=CohortCategory.POPULATION - ) - - common_graph.add_node( - pi_combination_node, - store_result=True, - category=CohortCategory.POPULATION_INTERVENTION, - ) - - common_graph.add_edges_from((src, p_combination_node) for src in p_nodes) - common_graph.add_edges_from((src, pi_combination_node) for src in pi_nodes) - - return common_graph - - def criteria(self) -> CriterionCombination: - """ - Get the criteria of the recommendation. - """ - criteria = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator("OR"), - root_combination=True, - ) - - for pi_pair in self._pi_pairs: - criteria.add(pi_pair.criteria()) - - return criteria - - def flatten(self) -> list[Criterion]: + def atoms(self) -> Iterator[Criterion]: """ Retrieve all criteria in a flat list """ - return list(itertools.chain(*[pi_pair.flatten() for pi_pair in self._pi_pairs])) + yield self._base_criterion + yield from self._expr.atoms() - def population_intervention_pairs( - self, - ) -> Iterator[cohort.PopulationInterventionPair]: + def population_intervention_pairs(self) -> Iterator[PopulationInterventionPairExpr]: """ - Iterate over the population/intervention pairs. + Iterate over all PopulationInterventionPairExpr in the expression tree. 
""" - yield from self._pi_pairs + + def traverse(expr: logic.BaseExpr) -> Iterator[PopulationInterventionPairExpr]: + if isinstance(expr, PopulationInterventionPairExpr): + yield expr + else: + for sub_expr in expr.args: + yield from traverse(sub_expr) + + yield from traverse(self._expr) def __str__(self) -> str: """ @@ -211,18 +159,6 @@ def __str__(self) -> str: """ return f"Recommendation(name='{self._name}', description='{self.description}')" - def __len__(self) -> int: - """ - Get the number of population/intervention pairs. - """ - return len(self._pi_pairs) - - def __getitem__(self, index: int) -> cohort.PopulationInterventionPair: - """ - Get the population/intervention pair at the given index. - """ - return self._pi_pairs[index] - @staticmethod def to_table(name: str) -> Table: """ @@ -276,50 +212,8 @@ def reset_state(self) -> None: """ self._id = None - for pi_pair in self._pi_pairs: + for pi_pair in self.population_intervention_pairs(): pi_pair._id = None - for criterion in pi_pair.flatten(): - criterion._id = None - - def dict(self) -> dict: - """ - Get the combination as a dictionary. - """ - base_criterion = self._base_criterion - return { - "population_intervention_pairs": [c.dict() for c in self._pi_pairs], - "base_criterion": { - "class_name": base_criterion.__class__.__name__, - "data": base_criterion.dict(), - }, - "recommendation_name": self._name, - "recommendation_title": self._title, - "recommendation_url": self._url, - "recommendation_version": self._version, - "recommendation_package_version": self._package_version, - "recommendation_description": self._description, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create a combination from a dictionary. 
- """ - base_criterion = criterion_factory(**data["base_criterion"]) - assert isinstance( - base_criterion, Criterion - ), "Base criterion must be a Criterion" - - return cls( - pi_pairs=[ - cohort.PopulationInterventionPair.from_dict(c) - for c in data["population_intervention_pairs"] - ], - base_criterion=base_criterion, - name=data["recommendation_name"], - title=data["recommendation_title"], - url=data["recommendation_url"], - version=data["recommendation_version"], - description=data["recommendation_description"], - package_version=data["recommendation_package_version"], - ) + + for criterion in self.atoms(): + criterion._id = None diff --git a/execution_engine/omop/concepts.py b/execution_engine/omop/concepts.py index 57c9ec37..96d135bd 100644 --- a/execution_engine/omop/concepts.py +++ b/execution_engine/omop/concepts.py @@ -1,7 +1,10 @@ import pandas as pd from pydantic import BaseModel +from execution_engine.util import serializable + +@serializable.register_class class Concept(BaseModel, frozen=True): # type: ignore """Represents an OMOP Standard Vocabulary concept.""" @@ -35,23 +38,14 @@ def is_custom(self) -> bool: return self.concept_id < 0 -class CustomConcept(Concept): +@serializable.register_class +class CustomConcept(Concept, frozen=True): """Represents a custom concept.""" - def __init__( - self, name: str, concept_code: str, domain_id: str, vocabulary_id: str - ) -> None: - """Creates a custom concept.""" - super().__init__( - concept_id=-1, - concept_name=name, - concept_code=concept_code, - domain_id=domain_id, - vocabulary_id=vocabulary_id, - concept_class_id="Custom", - standard_concept=None, - invalid_reason=None, - ) + concept_id: int = -1 + concept_class_id: str = "Custom" + standard_concept: str | None = None + invalid_reason: str | None = None @property def id(self) -> int: diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index 16627385..8ceaece8 100644 --- 
a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -1,6 +1,5 @@ -import copy -from abc import ABC, abstractmethod -from typing import Any, Dict, Self, Type, TypedDict, cast +from abc import abstractmethod +from typing import Type, TypedDict, cast import sqlalchemy from sqlalchemy import CTE, Alias, ColumnElement, Date, Integer @@ -11,7 +10,6 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.meta import SignatureReprABCMeta from execution_engine.omop.db.base import DateTimeWithTimeZone from execution_engine.omop.db.celida.tables import IntervalTypeEnum, ResultInterval from execution_engine.omop.db.omop.tables import ( @@ -25,12 +23,21 @@ VisitDetail, VisitOccurrence, ) -from execution_engine.omop.serializable import Serializable +from execution_engine.util import logic from execution_engine.util.interval import IntervalType +from execution_engine.util.serializable import SerializableDataClassABC from execution_engine.util.sql import SelectInto, select_into from execution_engine.util.types import PersonIntervals, TimeRange -__all__ = ["AbstractCriterion", "Criterion"] +__all__ = [ + "Criterion", + "column_interval_type", + "create_conditional_interval_column", + "SQL_ONE_SECOND", + "observation_start_datetime", + "observation_end_datetime", + "run_id", +] Domain = TypedDict( "Domain", @@ -89,53 +96,7 @@ def create_conditional_interval_column(condition: ColumnElement) -> ColumnElemen ) -class AbstractCriterion(Serializable, ABC, metaclass=SignatureReprABCMeta): - """ - Abstract base class for Criterion and CriterionCombination. - """ - - _base: bool = False - - def is_base(self) -> bool: - """ - Check if this criterion is the base criterion. - """ - return self._base - - def set_base(self, value: bool = True) -> None: - """ - Set the base criterion. 
- """ - self._base = value - - @property - def type(self) -> str: - """ - Get the type of the criterion. - """ - return self.__class__.__name__ - - def copy(self) -> "AbstractCriterion": - """ - Copy the criterion. - """ - return copy.copy(self) - - @abstractmethod - def description(self) -> str: - """ - Return a description of the criterion. - """ - raise NotImplementedError() - - def __str__(self) -> str: - """ - Get the name of the criterion. - """ - return self.description() - - -class Criterion(AbstractCriterion): +class Criterion(SerializableDataClassABC, logic.Symbol): """A criterion in a recommendation.""" _OMOP_TABLE: Type[Base] @@ -164,6 +125,11 @@ class Criterion(AbstractCriterion): during the recommendation period). """ + _base: bool = False + """ + Specifies whether this criterion is the base criterion (i.e. the criterion that selects the initial cohort). + """ + DOMAINS: dict[str, Domain] = { "condition": { "table": ConditionOccurrence, @@ -209,6 +175,34 @@ class Criterion(AbstractCriterion): Flag to indicate whether the filter_datetime function has been called. """ + def __init__(self) -> None: + super().__init__() + + def is_base(self) -> bool: + """ + Check if this criterion is the base criterion. + """ + return self._base + + def set_base(self, value: bool = True) -> None: + """ + Set the base criterion. + """ + self._base = value + + @abstractmethod + def description(self) -> str: + """ + Return a description of the criterion. + """ + raise NotImplementedError() + + def __str__(self) -> str: + """ + Get the name of the criterion. + """ + return self.description() + def _set_omop_variables_from_domain(self, domain_id: str) -> None: """ Set the OMOP table and column prefix based on the domain ID. 
@@ -226,6 +220,19 @@ def _set_omop_variables_from_domain(self, domain_id: str) -> None: self._static = cast(bool, domain["static"]) self._table = cast(Base, domain["table"]).__table__.alias(self.table_alias) + # @property + # def type(self) -> str: + # """ + # Get the type of the criterion. + # """ + # return self.__class__.__name__ + # + # def copy(self) -> "AbstractCriterion": + # """ + # Copy the criterion. + # """ + # return copy.copy(self) + @property def table_alias(self) -> str: """ @@ -284,9 +291,12 @@ def create_query(self) -> Select: ), "Query must select 4 columns: person_id, interval_start, interval_end, interval_type" # assert that the output columns are person_id, interval_start, interval_end, type - assert set([c.name for c in query.selected_columns]) == set( - ["person_id", "interval_start", "interval_end", "interval_type"] - ), "Query must select 4 columns: person_id, interval_start, interval_end, interval_type" + assert set([c.name for c in query.selected_columns]) == { + "person_id", + "interval_start", + "interval_end", + "interval_type", + }, "Query must select 4 columns: person_id, interval_start, interval_end, interval_type" return query @@ -614,14 +624,6 @@ def sql_insert_into_result_table(query: Select) -> SelectInto: return query - @classmethod - @abstractmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create a criterion from a JSON object. 
- """ - raise NotImplementedError() - def cte_interval_starts( self, query: Select, diff --git a/execution_engine/omop/criterion/combination/combination.py b/execution_engine/omop/criterion/combination/combination.py deleted file mode 100644 index 8ef990cc..00000000 --- a/execution_engine/omop/criterion/combination/combination.py +++ /dev/null @@ -1,296 +0,0 @@ -from abc import ABCMeta -from typing import Any, Dict, Iterator, Sequence, Union, cast - -from execution_engine.omop.criterion.abstract import AbstractCriterion, Criterion - -__all__ = ["CriterionCombination"] - - -def snake_to_camel(s: str) -> str: - return "".join(word.capitalize() for word in s.lower().split("_")) - - -class CriterionCombination(AbstractCriterion, metaclass=ABCMeta): - """ - Base class for a combination of criteria (temporal or logical). - """ - - class Operator: - """ - Operators for criterion combinations. - """ - - def __init__(self, operator: str, threshold: int | None = None): - self.operator = operator - self.threshold = threshold - - def __str__(self) -> str: - """ - Get the string representation of the operator. - """ - if self.operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - return f"{self.operator}(threshold={self.threshold})" - else: - return f"{self.operator}" - - def __repr__(self) -> str: - """ - Get the string representation of the operator. - """ - if self.operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - return f'{self.__class__.__qualname__}(operator="{self.operator}", threshold={self.threshold})' - else: - return f'{self.__class__.__qualname__}(operator="{self.operator}")' - - def __eq__(self, other: object) -> bool: - """ - Check if the operator is equal to another operator. 
- """ - if not isinstance(other, CriterionCombination.Operator): - return NotImplemented - return self.operator == other.operator and self.threshold == other.threshold - - def __init__( - self, - operator: Operator, - criteria: Sequence[Union[Criterion, "CriterionCombination"]] | None = None, - root_combination: bool = False, - ): - """ - Initialize the criterion combination. - """ - super().__init__() - self._operator = operator - - self._criteria: list[Union[Criterion, CriterionCombination]] - self._root = root_combination - - if criteria is None: - self._criteria = [] - else: - self._criteria = cast( - list[Union[Criterion, "CriterionCombination"]], criteria - ) - - def add(self, criterion: Union[Criterion, "CriterionCombination"]) -> None: - """ - Add a criterion to the combination. - """ - self._criteria.append(criterion) - - def add_all( - self, criteria: Sequence[Union[Criterion, "CriterionCombination"]] - ) -> None: - """ - Add multiple criteria to the combination. - """ - self._criteria.extend(criteria) - - def __str__(self) -> str: - """ - Get the name of the criterion combination. - """ - return f"{self.__class__.__name__}({self.operator})" - - @property - def operator(self) -> "CriterionCombination.Operator": - """ - Get the operator of the criterion combination (i.e. the type of combination, e.g. AND, OR, AT_LEAST, etc.). - """ - return self._operator - - def set_root(self, value: bool = True) -> None: - """ - Sets whether this criterion combination is at the root of a tree of criteria / combinations. - """ - self._root = value - - def is_root(self) -> bool: - """ - Returns whether this criterion combination is at the root of a tree of criteria / combinations. - """ - return self._root - - def __iter__(self) -> Iterator[Union[Criterion, "CriterionCombination"]]: - """ - Iterate over the criteria in the combination. 
- """ - for criterion in self._criteria: - yield criterion - - def __len__(self) -> int: - """ - Get the number of criteria in the combination. - """ - return len(self._criteria) - - def __getitem__(self, index: int) -> Union[Criterion, "CriterionCombination"]: - """ - Get the criterion at the specified index. - """ - return self._criteria[index] - - def description(self) -> str: - """ - Description of this combination. - """ - return str(self) - - def dict(self) -> dict[str, Any]: - """ - Get the dictionary representation of the criterion combination. - """ - return { - "operator": self._operator.operator, - "threshold": self._operator.threshold, - "criteria": [ - { - "class_name": criterion.__class__.__name__, - "data": criterion.dict(), - } - for criterion in self._criteria - ], - "root": self._root, - } - - def __invert__(self) -> AbstractCriterion: - """ - Invert the criterion combination. - """ - # Would be cycle if imported at top-level. - from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - ) - - if ( - isinstance(self, LogicalCriterionCombination) - and self.operator.operator == LogicalCriterionCombination.Operator.NOT - ): - return self._criteria[0] - else: - copy = self.__class__( - operator=self._operator, - criteria=self._criteria, - ) - return LogicalCriterionCombination.Not(copy) - - def invert(self) -> AbstractCriterion: - """ - Invert the criterion combination. - """ - return ~self - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "CriterionCombination": - """ - Create a criterion combination from a dictionary. 
- """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - operator = cls.Operator(data["operator"], data["threshold"]) - - combination = cls( - operator=operator, - root_combination=data["root"], - ) - - for criterion in data["criteria"]: - combination.add(criterion_factory(**criterion)) - - return combination - - def _build_repr( - self, - children: Sequence[tuple[str | None, Any]], - params: list[tuple[str, Any]], - level: int = 0, - children_param_name: str = "criteria", - children_is_sequence: bool = True, - ) -> str: - """ - Builds a multi-line string for this criterion combination, - properly indenting each level but avoiding double-indenting if - a child already handles indentation in its own _repr_pretty. - """ - indent = " " * (2 * level) - child_indent = " " * (2 * (level + 1)) - - op = snake_to_camel(self.operator.operator) - - lines: list[str] = [] - kw_lines: list[str] = [] - criteria_lines: list[str] = [] - - method_defined = hasattr(self, op) and callable(getattr(self, op)) - - if self._root: - params.append(("root_combination", self._root)) - - # if there is a specific method for this operator, use it - if method_defined: - lines.append(f"{indent}{self.__class__.__name__}.{op}(") - if self.operator.threshold is not None: - params.append(("threshold", self.operator.threshold)) - else: - lines.append(f"{indent}{self.__class__.__name__}(") - params.append(("operator", self.operator)) - - # Each child on its own line - for key, child in children: - if hasattr(child, "_repr_pretty"): - # The child already handles its own indentation and multi-line formatting. - child_repr = child._repr_pretty(level + 1) - # We'll put a comma on the last line that the child produces. 
- child_lines = child_repr.split("\n") - child_lines[-1] += "," # add trailing comma to the last line - if key is None: - criteria_lines.extend(child_lines) - else: - # If you have a key like "left=" or "right=", prepend that to the first line - # or handle it similarly. One approach is: - child_lines[0] = f"{child_indent}{key}={child_lines[0].lstrip()}" - criteria_lines.extend(child_lines) - else: - # Fallback to normal repr, which we indent at this level - child_repr = repr(child) - if key is None: - criteria_lines.append(child_indent + child_repr + ",") - else: - criteria_lines.append(f"{child_indent}{key}={child_repr},") - - for key, value in params: - kw_lines.append(f"{child_indent}{key}={repr(value)},") - - if method_defined: - lines.extend(criteria_lines) - lines.extend(kw_lines) - elif children_is_sequence: - lines.extend(kw_lines) - lines.append(f"{child_indent}{children_param_name}=[") - lines.extend(criteria_lines) - lines.append(f"{child_indent}],") - else: - assert len(criteria_lines) <= 1 - lines.extend(kw_lines) - if criteria_lines: - lines.append( - f"{child_indent}{children_param_name}={criteria_lines[0].strip()}" - ) - - lines.append(f"{indent})") - - return "\n".join(lines) - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - - return self._build_repr(children, params=[], level=level) - - def __repr__(self) -> str: - """ - Get the string representation of the criterion combination. 
- """ - return self._repr_pretty(0) diff --git a/execution_engine/omop/criterion/combination/logical.py b/execution_engine/omop/criterion/combination/logical.py deleted file mode 100644 index bb624a66..00000000 --- a/execution_engine/omop/criterion/combination/logical.py +++ /dev/null @@ -1,292 +0,0 @@ -from typing import Any, Dict, Iterator, Union - -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination - - -class LogicalCriterionCombination(CriterionCombination): - """ - A combination of criteria. - """ - - class Operator(CriterionCombination.Operator): - """Operators for criterion combinations.""" - - NOT = "NOT" - AND = "AND" - OR = "OR" - AT_LEAST = "AT_LEAST" - CAPPED_AT_LEAST = "CAPPED_AT_LEAST" - AT_MOST = "AT_MOST" - EXACTLY = "EXACTLY" - ALL_OR_NONE = "ALL_OR_NONE" - - def __init__(self, operator: str, threshold: int | None = None): - assert operator in [ - "NOT", - "AND", - "OR", - "AT_LEAST", - "CAPPED_AT_LEAST", - "AT_MOST", - "EXACTLY", - "ALL_OR_NONE", - ], f"Invalid operator {operator}" - - self.operator = operator - if operator in ["AT_LEAST", "CAPPED_AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - threshold is not None - ), f"Threshold must be set for operator {operator}" - self.threshold = threshold - - @classmethod - def Not( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create a NOT "combination" of a single criterion. - """ - return cls( - operator=cls.Operator(cls.Operator.NOT), - criteria=[criterion], - ) - - @classmethod - def And( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create an AND combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.AND), - criteria=criteria, - ) - - @classmethod - def Or( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create an OR combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.OR), - criteria=criteria, - ) - - @classmethod - def AtLeast( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def CappedAtLeast( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an CAPPED_AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.CAPPED_AT_LEAST, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def AtMost( - cls, - *criteria: Union[Criterion, "CriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an AT_MOST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_MOST, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def Exactly( - cls, - *criteria: Union[Criterion, "LogicalCriterionCombination"], - threshold: int, - ) -> "LogicalCriterionCombination": - """ - Create an EXACTLY combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.EXACTLY, threshold=threshold), - criteria=criteria, - ) - - @classmethod - def AllOrNone( - cls, - *criteria: Union[Criterion, "LogicalCriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create an ALL_OR_NONE combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.ALL_OR_NONE), - criteria=criteria, - ) - - -class NonCommutativeLogicalCriterionCombination(LogicalCriterionCombination): - """ - A combination of criteria that is not commutative. - """ - - _left: Union[Criterion, CriterionCombination] - _right: Union[Criterion, CriterionCombination] - - class Operator(CriterionCombination.Operator): - """Operators for criterion combinations.""" - - CONDITIONAL_FILTER = "CONDITIONAL_FILTER" - - def __init__(self, operator: str, threshold: None = None): - assert operator in [ - "CONDITIONAL_FILTER", - ], f"Invalid operator {operator}" - assert threshold is None - self.operator = operator - self.threshold = threshold - - def __init__( - self, - operator: "NonCommutativeLogicalCriterionCombination.Operator", - left: Union[Criterion, CriterionCombination] | None = None, - right: Union[Criterion, CriterionCombination] | None = None, - root_combination: bool = False, - ): - """ - Initialize the criterion combination. - """ - super().__init__(operator=operator) - - self._criteria = [] - if left is not None: - self._left = left - if right is not None: - self._right = right - - self._root = root_combination - - @property - def left(self) -> Union[Criterion, CriterionCombination]: - """ - Get the left criterion. - """ - return self._left - - @property - def right(self) -> Union[Criterion, CriterionCombination]: - """ - Get the right criterion. - """ - return self._right - - def __str__(self) -> str: - """ - Get the string representation of the criterion combination. - """ - return f"{self.operator}({', '.join(str(c) for c in self._criteria)})" - - def __eq__(self, other: object) -> bool: - """ - Check if the criterion combination is equal to another criterion combination. 
- """ - if not isinstance(other, NonCommutativeLogicalCriterionCombination): - return NotImplemented - return ( - self.operator == other.operator - and self._left == other._left - and self._right == other._right - ) - - def __iter__(self) -> Iterator[Union[Criterion, "CriterionCombination"]]: - """ - Iterate over the criteria in the combination. - """ - yield self._left - yield self._right - - def dict(self) -> dict: - """ - Get the dictionary representation of the criterion combination. - """ - left = self._left - right = self._right - return { - "operator": self._operator.operator, - "left": {"class_name": left.__class__.__name__, "data": left.dict()}, - "right": {"class_name": right.__class__.__name__, "data": right.dict()}, - } - - def _repr_pretty(self, level: int = 0) -> str: - children = [ - ("left", self._left), - ("right", self._right), - ] - return self._build_repr(children, params=[], level=level) - - @classmethod - def from_dict( - cls, data: Dict[str, Any] - ) -> "NonCommutativeLogicalCriterionCombination": - """ - Create a criterion combination from a dictionary. - """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - return cls( - operator=cls.Operator(data["operator"]), - left=criterion_factory(**data["left"]), - right=criterion_factory(**data["right"]), - ) - - @classmethod - def ConditionalFilter( - cls, - left: Union[Criterion, "CriterionCombination"], - right: Union[Criterion, "CriterionCombination"], - ) -> "LogicalCriterionCombination": - """ - Create a CONDITIONAL_FILTER combination of criteria. - - A conditional filter returns `right` iff `left` is POSITIVE, otherwise NEGATIVE. 
- - | left | right | Result | - |----------|----------|----------| - | NEGATIVE | * | NEGATIVE | - | NO_DATA | * | NEGATIVE | - | POSITIVE | POSITIVE | POSITIVE | - | POSITIVE | NEGATIVE | NEGATIVE | - | POSITIVE | NO_DATA | NO_DATA | - """ - return cls( - operator=cls.Operator(cls.Operator.CONDITIONAL_FILTER), - left=left, - right=right, - ) diff --git a/execution_engine/omop/criterion/combination/temporal.py b/execution_engine/omop/criterion/combination/temporal.py deleted file mode 100644 index ca9f6930..00000000 --- a/execution_engine/omop/criterion/combination/temporal.py +++ /dev/null @@ -1,511 +0,0 @@ -from abc import ABC -from datetime import time -from enum import StrEnum -from typing import Any, Dict, Iterator, Union - -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination - - -class TimeIntervalType(StrEnum): - """ - Types of time intervals to aggregate criteria over. - """ - - MORNING_SHIFT = "morning_shift" - AFTERNOON_SHIFT = "afternoon_shift" - NIGHT_SHIFT = "night_shift" - NIGHT_SHIFT_BEFORE_MIDNIGHT = "night_shift_before_midnight" - NIGHT_SHIFT_AFTER_MIDNIGHT = "night_shift_after_midnight" - DAY = "day" - ANY_TIME = "any_time" - - def __repr__(self) -> str: - """ - Get the string representation of the time interval type. - """ - return f"{self.__class__.__name__}.{self.name}" - - -class TemporalIndicatorCombination(CriterionCombination, ABC): - """ - TemporalIndicatorCombination is an abstract base class for constructing temporal indicator - combinations used to evaluate patient data over time. It encapsulates the common logic such as - operator definitions (e.g., AT_LEAST, AT_MOST, EXACTLY) and the management of one or more criteria. - This class serves as the foundation for more specialized implementations that define how the time - windows for evaluation are determined. 
- """ - - class Operator(CriterionCombination.Operator): - """Operators for criterion combinations.""" - - AT_LEAST = "AT_LEAST" - AT_MOST = "AT_MOST" - EXACTLY = "EXACTLY" - - def __init__(self, operator: str, threshold: int | None = None): - assert operator in [ - "AT_LEAST", - "AT_MOST", - "EXACTLY", - ], f"Invalid operator {operator}" - - self.operator = operator - if operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - threshold is not None - ), f"Threshold must be set for operator {operator}" - self.threshold = threshold - - def __init__( - self, - operator: Operator, - criterion: Union[Criterion, "CriterionCombination"] | None = None, - ): - if criterion is not None: - if not isinstance(criterion, (Criterion, CriterionCombination)): - raise ValueError( - f"Invalid criterion - expected Criterion or CriterionCombination, got {type(criterion)}" - ) - criteria = [criterion] - else: - criteria = None - - super().__init__(operator=operator, criteria=criteria) - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - - return self._build_repr( - children, - params=[], - level=level, - children_param_name="criterion", - children_is_sequence=False, - ) - - -class FixedWindowTemporalIndicatorCombination(TemporalIndicatorCombination): - """ - FixedWindowTemporalIndicatorCombination implements a temporal indicator combination that relies on - fixed time window specifications. It supports two mutually exclusive methods for defining these windows: - either via a pre-defined TimeIntervalType (e.g., morning, afternoon, or night shifts) or through explicit - start_time and end_time values. This class is intended for scenarios where the same evaluation window - applies uniformly across all patients, and it enforces validation to ensure only one method of window - specification is used. 
- """ - - interval_type: TimeIntervalType | None = None - start_time: time | None = None - end_time: time | None = None - - def __init__( - self, - operator: TemporalIndicatorCombination.Operator, - criterion: Union[Criterion, "CriterionCombination"] | None = None, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ): - super().__init__(operator=operator, criterion=criterion) - - if interval_type: - if start_time is not None or end_time is not None: - raise ValueError( - "start_time/end_time cannot be used together with interval_type" - ) - # Validate the interval_type if needed - self.interval_type = interval_type - self.start_time = None - self.end_time = None - else: - # Must have start_time and end_time - if start_time is None or end_time is None: - raise ValueError( - "Either interval_type OR both start_time & end_time must be provided" - ) - if start_time >= end_time: - raise ValueError("start_time must be less than end_time") - - self.interval_type = interval_type - self.start_time = start_time - self.end_time = end_time - - def __str__(self) -> str: - """ - Get the string representation of the criterion combination. - """ - if self.interval_type: - return f"{super().__str__()} for {self.interval_type.value}" - elif self.start_time and self.end_time: - return f"{super().__str__()} from {self.start_time.strftime('%H:%M:%S')} to {self.end_time.strftime('%H:%M:%S')}" - else: - return super().__str__() - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - params = [ - ("interval_type", self.interval_type), - ("start_time", self.start_time), - ("end_time", self.end_time), - ] - return self._build_repr( - children, - params=params, - level=level, - children_param_name="criterion", - children_is_sequence=False, - ) - - def dict(self) -> Dict: - """ - Get the dictionary representation of the criterion combination. 
- """ - - d = super().dict() - d["start_time"] = self.start_time.isoformat() if self.start_time else None - d["end_time"] = self.end_time.isoformat() if self.end_time else None - d["interval_type"] = self.interval_type - - return d - - @classmethod - def from_dict( - cls, data: Dict[str, Any] - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a criterion combination from a dictionary. - """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - operator = cls.Operator(data["operator"], data["threshold"]) - - combination = cls( - operator=operator, - interval_type=data["interval_type"], - # start_time and end_time is in isoformat ! - start_time=( - time.fromisoformat(data["start_time"]) if data["start_time"] else None - ), - end_time=time.fromisoformat(data["end_time"]) if data["end_time"] else None, - ) - - for criterion in data["criteria"]: - combination.add(criterion_factory(**criterion)) - - return combination - - @classmethod - def Presence( - cls, - criterion: Union[Criterion, "CriterionCombination"], - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=1), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def MinCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=threshold), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def MaxCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - interval_criterion: Criterion | CriterionCombination | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an AT_MOST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_MOST, threshold=threshold), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def ExactCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_type: TimeIntervalType | None = None, - start_time: time | None = None, - end_time: time | None = None, - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create an EXACTLY combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.EXACTLY, threshold=threshold), - criterion=criterion, - interval_type=interval_type, - start_time=start_time, - end_time=end_time, - ) - - @classmethod - def MorningShift( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a MorningShift combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.MORNING_SHIFT) - - @classmethod - def AfternoonShift( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a AfternoonShift combination of criteria. 
- """ - return cls.Presence(criterion, TimeIntervalType.AFTERNOON_SHIFT) - - @classmethod - def NightShift( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a NightShift combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.NIGHT_SHIFT) - - @classmethod - def NightShiftBeforeMidnight( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a NightShiftBeforeMidnight combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.NIGHT_SHIFT_BEFORE_MIDNIGHT) - - @classmethod - def NightShiftAfterMidnight( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a NightShiftAfterMidnight combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.NIGHT_SHIFT_AFTER_MIDNIGHT) - - @classmethod - def Day( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a Day combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.DAY) - - @classmethod - def AnyTime( - cls, - criterion: Union[Criterion, "CriterionCombination"], - ) -> "FixedWindowTemporalIndicatorCombination": - """ - Create a AnyTime combination of criteria. - """ - return cls.Presence(criterion, TimeIntervalType.ANY_TIME) - - -class PersonalWindowTemporalIndicatorCombination(TemporalIndicatorCombination): - """ - PersonalWindowTemporalIndicatorCombination implements a temporal indicator combination based on - person-specific time windows. Instead of using fixed start/end times or a global TimeIntervalType, this - class leverages an interval_criterion to dynamically generate evaluation windows tailored to each - patient. 
This design is ideal for situations where the timing of events (such as post-operative periods) - varies between patients, enabling more personalized temporal assessments. - """ - - _interval_criterion: Criterion | CriterionCombination - - def __init__( - self, - operator: TemporalIndicatorCombination.Operator, - criterion: Union[Criterion, "CriterionCombination"] | None, - interval_criterion: Criterion | CriterionCombination, - ): - super().__init__(operator=operator, criterion=criterion) - - if not isinstance(interval_criterion, (Criterion, CriterionCombination)): - raise ValueError( - f"Invalid criterion - expected Criterion or CriterionCombination, got {type(interval_criterion)}" - ) - - self._interval_criterion = interval_criterion - - def __iter__(self) -> Iterator[Union[Criterion, "CriterionCombination"]]: - """ - Iterate over the criteria in the combination - first criteria, then interval criterion if present. - """ - yield from super().__iter__() - yield self._interval_criterion - - def __str__(self) -> str: - """ - Get the string representation of the criterion combination. - """ - base_str = super().__str__() - return f"{base_str} [Personal Windows via: {self.interval_criterion}]" - - def _repr_pretty(self, level: int = 0) -> str: - children = [(None, c) for c in self._criteria] - params = [ - ("interval_criterion", self.interval_criterion), - ] - return self._build_repr( - children, - params=params, - level=level, - children_param_name="criterion", - children_is_sequence=False, - ) - - @property - def interval_criterion(self) -> Criterion | CriterionCombination: - """ - Get the interval criterion. - """ - return self._interval_criterion - - def dict(self) -> Dict: - """ - Get the dictionary representation of the criterion combination. 
- """ - - d = super().dict() - d["interval_criterion"] = { - "class_name": self.interval_criterion.__class__.__name__, - "data": self.interval_criterion.dict(), - } - - return d - - @classmethod - def from_dict( - cls, data: Dict[str, Any] - ) -> "PersonalWindowTemporalIndicatorCombination": - """ - Create a criterion combination from a dictionary. - """ - - from execution_engine.omop.criterion.factory import ( - criterion_factory, # needs to be here to avoid circular import - ) - - operator = cls.Operator(data["operator"], data["threshold"]) - - combination = cls( - operator=operator, - criterion=None, - interval_criterion=criterion_factory(**data["interval_criterion"]), - ) - - for criterion in data["criteria"]: - combination.add(criterion_factory(**criterion)) - - return combination - - @classmethod - def Presence( - cls, - criterion: Union[Criterion, "CriterionCombination"], - interval_criterion: Criterion | CriterionCombination, - ) -> "PersonalWindowTemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=1), - criterion=criterion, - interval_criterion=interval_criterion, - ) - - @classmethod - def MinCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_criterion: Criterion | CriterionCombination, - ) -> "TemporalIndicatorCombination": - """ - Create an AT_LEAST combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.AT_LEAST, threshold=threshold), - criterion=criterion, - interval_criterion=interval_criterion, - ) - - @classmethod - def MaxCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_criterion: Criterion | CriterionCombination, - ) -> "TemporalIndicatorCombination": - """ - Create an AT_MOST combination of criteria. 
- """ - return cls( - operator=cls.Operator(cls.Operator.AT_MOST, threshold=threshold), - criterion=criterion, - interval_criterion=interval_criterion, - ) - - @classmethod - def ExactCount( - cls, - criterion: Union[Criterion, "CriterionCombination"], - threshold: int, - interval_criterion: Criterion | CriterionCombination, - ) -> "TemporalIndicatorCombination": - """ - Create an EXACTLY combination of criteria. - """ - return cls( - operator=cls.Operator(cls.Operator.EXACTLY, threshold=threshold), - criterion=criterion, - interval_criterion=interval_criterion, - ) diff --git a/execution_engine/omop/criterion/concept.py b/execution_engine/omop/criterion/concept.py index c8ec3faf..ac33ad2c 100644 --- a/execution_engine/omop/criterion/concept.py +++ b/execution_engine/omop/criterion/concept.py @@ -1,14 +1,12 @@ -from typing import Any, Dict, cast +from abc import ABC from sqlalchemy.sql import Select from execution_engine.constants import OMOPConcepts from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.meta import SignatureReprMeta from execution_engine.util.types import Timing from execution_engine.util.value import Value -from execution_engine.util.value.factory import value_factory __all__ = ["ConceptCriterion"] @@ -25,7 +23,7 @@ # TODO: Only use weight etc from the current encounter/visit! -class ConceptCriterion(Criterion, metaclass=SignatureReprMeta): +class ConceptCriterion(Criterion, ABC): """ Abstract class for a criterion based on an OMOP concept and optional value. 
@@ -45,7 +43,7 @@ def __init__( value: Value | None = None, static: bool | None = None, timing: Timing | None = None, - override_value_required: bool | None = None, + value_required: bool | None = None, ): super().__init__() @@ -67,10 +65,8 @@ def __init__( # if static is None, then the criterion is static if the concept is in the STATIC_CLINICAL_CONCEPTS list self._static = concept.concept_id in STATIC_CLINICAL_CONCEPTS - if override_value_required is not None and isinstance( - override_value_required, bool - ): - self._value_required = override_value_required + if value_required is not None and isinstance(value_required, bool): + self._value_required = value_required @property def concept(self) -> Concept: @@ -132,39 +128,3 @@ def description(self) -> str: desc += "]" return desc - - def dict(self) -> dict[str, Any]: - """ - Get a JSON representation of the criterion. - """ - return { - "concept": self._concept.model_dump(), - "value": ( - self._value.model_dump(include_meta=True) - if self._value is not None - else None - ), - "static": self._static, - "timing": ( - self._timing.model_dump(include_meta=True) - if self._timing is not None - else None - ), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ConceptCriterion": - """ - Create a criterion from a JSON representation. 
- """ - - return cls( - concept=Concept(**data["concept"]), - value=( - cast(Value, value_factory(**data["value"])) - if data["value"] is not None - else None - ), - static=data["static"], - timing=Timing(**data["timing"]) if data["timing"] is not None else None, - ) diff --git a/execution_engine/omop/criterion/custom/tidal_volume.py b/execution_engine/omop/criterion/custom/tidal_volume.py index d0ef3801..39480698 100644 --- a/execution_engine/omop/criterion/custom/tidal_volume.py +++ b/execution_engine/omop/criterion/custom/tidal_volume.py @@ -54,7 +54,7 @@ def __init__( value: Value | None = None, static: bool | None = None, timing: Timing | None = None, - override_value_required: bool | None = None, + value_required: bool | None = None, forward_fill: bool = True, ): super().__init__( @@ -62,7 +62,7 @@ def __init__( value=value, static=static, timing=timing, - override_value_required=override_value_required, + value_required=value_required, ) self._table = self._cte() diff --git a/execution_engine/omop/criterion/device_exposure.py b/execution_engine/omop/criterion/device_exposure.py new file mode 100644 index 00000000..7bc55222 --- /dev/null +++ b/execution_engine/omop/criterion/device_exposure.py @@ -0,0 +1,7 @@ +__all__ = ["DeviceExposure"] + +from execution_engine.omop.criterion.continuous import ContinuousCriterion + + +class DeviceExposure(ContinuousCriterion): + """A device_exposure criterion in a recommendation.""" diff --git a/execution_engine/omop/criterion/drug_exposure.py b/execution_engine/omop/criterion/drug_exposure.py index fd08b032..90c1668c 100644 --- a/execution_engine/omop/criterion/drug_exposure.py +++ b/execution_engine/omop/criterion/drug_exposure.py @@ -1,5 +1,4 @@ import logging -from typing import Any, Dict from sqlalchemy import Column, and_, case, func, select, true from sqlalchemy.sql import Select @@ -17,7 +16,6 @@ from execution_engine.util.interval import IntervalType from execution_engine.util.sql import SelectInto from 
execution_engine.util.types import Dosage -from execution_engine.util.value.factory import value_factory __all__ = ["DrugExposure"] @@ -349,33 +347,3 @@ def description(self) -> str: parts.append(f"route={route.concept_name}") return f"{self.__class__.__name__}[" + ", ".join(parts) + "]" - - def dict(self) -> dict[str, Any]: - """ - Return a dictionary representation of the criterion. - """ - return { - "ingredient_concept": self._ingredient_concept.model_dump(), - "dose": ( - self._dose.model_dump(include_meta=True) - if self._dose is not None - else None - ), - "route": self._route.model_dump() if self._route is not None else None, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "DrugExposure": - """ - Create a drug exposure criterion from a dictionary representation. - """ - - dose = value_factory(**data["dose"]) if data["dose"] is not None else None - - assert dose is None or isinstance(dose, Dosage), "Dose must be a Dosage or None" - - return cls( - ingredient_concept=Concept(**data["ingredient_concept"]), - dose=dose, - route=Concept(**data["route"]) if data["route"] is not None else None, - ) diff --git a/execution_engine/omop/criterion/factory.py b/execution_engine/omop/criterion/factory.py deleted file mode 100644 index 706e7669..00000000 --- a/execution_engine/omop/criterion/factory.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Type - -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - PersonalWindowTemporalIndicatorCombination, - TemporalIndicatorCombination, -) -from execution_engine.omop.criterion.concept import ConceptCriterion -from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence 
-from execution_engine.omop.criterion.custom import TidalVolumePerIdealBodyWeight -from execution_engine.omop.criterion.drug_exposure import DrugExposure -from execution_engine.omop.criterion.measurement import Measurement -from execution_engine.omop.criterion.observation import Observation -from execution_engine.omop.criterion.point_in_time import PointInTimeCriterion -from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence -from execution_engine.omop.criterion.visit_detail import VisitDetail -from execution_engine.omop.criterion.visit_occurrence import ( - ActivePatients, - PatientsActiveDuringPeriod, - VisitOccurrence, -) - -__all__ = ["criterion_factory", "register_criterion_class"] - -class_map: dict[str, Type[Criterion] | Type[CriterionCombination]] = { - "ConceptCriterion": ConceptCriterion, - "LogicalCriterionCombination": LogicalCriterionCombination, - "TemporalCriterionCombination": TemporalIndicatorCombination, - "NonCommutativeLogicalCriterionCombination": NonCommutativeLogicalCriterionCombination, - "ConditionOccurrence": ConditionOccurrence, - "DrugExposure": DrugExposure, - "Measurement": Measurement, - "Observation": Observation, - "ProcedureOccurrence": ProcedureOccurrence, - "VisitOccurrence": VisitOccurrence, - "ActivePatients": ActivePatients, - "PatientsActiveDuringPeriod": PatientsActiveDuringPeriod, - "TidalVolumePerIdealBodyWeight": TidalVolumePerIdealBodyWeight, - "VisitDetail": VisitDetail, - "PointInTimeCriterion": PointInTimeCriterion, - "PersonalWindowTemporalIndicatorCombination": PersonalWindowTemporalIndicatorCombination, -} - - -def register_criterion_class( - class_name: str, - criterion_class: Type[Criterion] | Type[CriterionCombination], -) -> None: - """ - Register a criterion class. - - :param class_name: The name of the criterion class. - :param criterion_class: The criterion class. 
- """ - class_map[class_name] = criterion_class - - -def criterion_factory(class_name: str, data: dict) -> Criterion | CriterionCombination: - """ - Create a criterion from a dictionary representation. - - :param class_name: The name of the criterion class. - :param data: The dictionary representation of the criterion. - :return: The criterion. - :raises ValueError: If the class name is not recognized. - """ - - if class_name not in class_map: - raise ValueError(f"Unknown criterion class {class_name}") - - return class_map[class_name].from_dict(data) diff --git a/execution_engine/omop/criterion/meta.py b/execution_engine/omop/criterion/meta.py deleted file mode 100644 index 9012e45f..00000000 --- a/execution_engine/omop/criterion/meta.py +++ /dev/null @@ -1,150 +0,0 @@ -import abc -import inspect -from typing import Any, Type, TypeVar - -T = TypeVar("T", bound="SignatureReprMeta") - - -class SignatureReprMeta(type): - """ - A metaclass that automatically captures constructor arguments and generates a __repr__ method. - - This metaclass wraps the `__init__` method of a class to store the arguments passed at instantiation, - allowing `__repr__` to dynamically generate an informative string representation, omitting default values. - - Features: - - Captures constructor arguments at instantiation time (`self._init_args`). - - Automatically generates a `__repr__` if the class does not define one. - - Ensures that `__repr__` only displays arguments that differ from the default values. - - Prevents redundant re-capturing when a parent class’s `__init__` is invoked via `super()`. - - Retains the original `__init__` function signature for accurate introspection. 
- - Usage: - ```python - class MyClass(metaclass=SignatureReprMeta): - def __init__(self, x, y=10, z=None): - self.x = x - self.y = y - self.z = z - - obj = MyClass(5, z="test") - print(obj) # Output: MyClass(x=5, z='test') (y is omitted since it uses the default 10) - ``` - - """ - - def __new__( - mcs: Type[T], name: str, bases: tuple[type, ...], namespace: dict[str, Any] - ) -> T: - """ - Wrap the __init__ method and attach a default __repr__ if not defined. - """ - original_init = namespace.get("__init__") - - if not original_init: - # Try to find an __init__ in the bases - for base in reversed(bases): - if base.__init__ is not object.__init__: - original_init = base.__init__ - break - - # We'll define a new __init__ only if there's an actual one to wrap - if original_init: - - def __init__(self: object, *args: Any, **kwargs: Any) -> None: - if type(self) is cls: - assert original_init is not None, "No valid __init__ found!" - sig = inspect.signature(original_init) - bound = sig.bind(self, *args, **kwargs) - bound.apply_defaults() - all_args = dict(bound.arguments) - all_args.pop("self", None) - self._init_args = all_args # type: ignore[attr-defined] - - original_init(self, *args, **kwargs) - - # The rest is basically the same - # But we must create the class first so we can set new_init.__signature__ = ... 
- cls = super().__new__(mcs, name, bases, namespace) - - # If we actually did define new_init, attach it - if original_init: - # Replace/attach the new __init__ to cls - setattr(cls, "__init__", __init__) - - # Manually override the function's signature - init_sig = inspect.signature(original_init) - __init__.__signature__ = init_sig # type: ignore[attr-defined] - - original_repr = namespace.get("__repr__") - - if not original_repr: - # Try to find an __repr__ in the bases - for base in reversed(bases): - if base.__repr__ is not object.__repr__: - original_repr = base.__repr__ - break - - # If no user-defined __repr__, attach a default - if not original_repr or getattr( - original_repr, "__signature_repr_generated__", False - ): - - def __repr__(self: object) -> str: - # Case 1: No real __init__ found at all => just return ClassName() - if original_init is None: - return f"{name}()" - - # Case 2: If the class or its parents do define an __init__, we check - # for _init_args. If the class isn't wrapping init, your class would - # have to set _init_args itself (or you'll just get a normal object repr). - if not hasattr(self, "_init_args"): - return super(type(self), self).__repr__() - - # Build param=value only if they differ from default - sig = inspect.signature(original_init) - parts = [] - for param_name, param in sig.parameters.items(): - if param_name == "self": - continue - default = param.default - if ( - param_name in self._init_args - and self._init_args[param_name] != default - ): - parts.append( - f"{param_name}={repr(self._init_args[param_name])}" - ) - return f"{name}({', '.join(parts)})" - - # Tag it so children know it's auto-generated - __repr__.__signature_repr_generated__ = True # type: ignore[attr-defined] - - setattr(cls, "__repr__", __repr__) - - return cls - - -class SignatureReprABCMeta(SignatureReprMeta, abc.ABCMeta): - """ - A metaclass combining `SignatureReprMeta` and `ABCMeta`. 
- - This metaclass extends `SignatureReprMeta`, allowing abstract base classes (`ABC`) to inherit - automatic argument capturing and dynamic `__repr__` generation. - - Usage: - ```python - class AbstractExample(metaclass=SignatureReprABCMeta): - @abc.abstractmethod - def some_method(self): - pass - - class ConcreteExample(AbstractExample): - def __init__(self, value, flag=True): - self.value = value - self.flag = flag - - obj = ConcreteExample(42) - print(obj) # Output: ConcreteExample(value=42) (flag is omitted since it uses the default True) - ``` - """ diff --git a/execution_engine/omop/criterion/noop.py b/execution_engine/omop/criterion/noop.py new file mode 100644 index 00000000..2b01ed8c --- /dev/null +++ b/execution_engine/omop/criterion/noop.py @@ -0,0 +1,41 @@ +from sqlalchemy import Select, select + +from execution_engine.omop.criterion.abstract import ( + Criterion, + column_interval_type, + observation_end_datetime, + observation_start_datetime, +) +from execution_engine.util.interval import IntervalType + + +class NoopCriterion(Criterion): + """ + Select patients who are post-surgical in the timeframe between the day of the surgery and 6 days after the surgery. + """ + + _static = True + + def _create_query(self) -> Select: + """ + Get the SQL Select query for data required by this criterion. + """ + subquery = self.base_query().subquery() + + query = select( + subquery.c.person_id, + column_interval_type(IntervalType.POSITIVE), + observation_start_datetime.label("interval_start"), + observation_end_datetime.label("interval_end"), + ) + + query = self._filter_base_persons(query, c_person_id=subquery.c.person_id) + query = self._filter_datetime(query) + + return query + + def description(self) -> str: + """ + Get a description of the criterion. 
+ """ + return self.__class__.__name__ diff --git a/execution_engine/omop/criterion/point_in_time.py b/execution_engine/omop/criterion/point_in_time.py index 2fdb6a0d..d79342a0 100644 --- a/execution_engine/omop/criterion/point_in_time.py +++ b/execution_engine/omop/criterion/point_in_time.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, cast - from sqlalchemy import CTE, ColumnElement, Select, select from execution_engine.omop.concepts import Concept @@ -27,7 +25,7 @@ def __init__( value: Value | None = None, static: bool | None = None, timing: Timing | None = None, - override_value_required: bool | None = None, + value_required: bool | None = None, forward_fill: bool = True, ): super().__init__( @@ -35,7 +33,7 @@ def __init__( value=value, static=static, timing=timing, - override_value_required=override_value_required, + value_required=value_required, ) self._forward_fill = forward_fill @@ -103,22 +101,6 @@ def _create_query(self) -> Select: return query - def dict(self) -> dict[str, Any]: - """ - Get a JSON representation of the criterion. - """ - from_super = super().dict() - return from_super | {"forward_fill": self._forward_fill} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PointInTimeCriterion": - """ - Create a criterion from a JSON representation. 
- """ - object = cast("PointInTimeCriterion", super().from_dict(data)) - object._forward_fill = data.get("forward_fill", True) # Backward compat - return object - def process_data( self, data: PersonIntervals, diff --git a/execution_engine/omop/criterion/procedure_occurrence.py b/execution_engine/omop/criterion/procedure_occurrence.py index 0084afe5..c2f853f9 100644 --- a/execution_engine/omop/criterion/procedure_occurrence.py +++ b/execution_engine/omop/criterion/procedure_occurrence.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, cast - from sqlalchemy import ColumnElement, case, func, select from sqlalchemy.sql import Select @@ -13,7 +11,6 @@ from execution_engine.util.interval import IntervalType from execution_engine.util.types import Timing from execution_engine.util.value import ValueNumber -from execution_engine.util.value.factory import value_factory from execution_engine.util.value.time import ValueCount __all__ = ["ProcedureOccurrence"] @@ -160,45 +157,3 @@ def description(self) -> str: parts.append(f"dose={str(self._timing)}") return f"{self.__class__.__name__}[" + ", ".join(parts) + "]" - - def dict(self) -> dict[str, Any]: - """ - Return a dictionary representation of the criterion. - """ - assert self._concept is not None, "Concept must be set" - - return { - "concept": self._concept.model_dump(), - "value": ( - self._value.model_dump(include_meta=True) - if self._value is not None - else None - ), - "timing": ( - self._timing.model_dump(include_meta=True) - if self._timing is not None - else None - ), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ProcedureOccurrence": - """ - Create a procedure occurrence criterion from a dictionary representation. 
- """ - - value = value_factory(**data["value"]) if data["value"] is not None else None - timing = value_factory(**data["timing"]) if data["timing"] is not None else None - - assert ( - isinstance(value, ValueNumber) or value is None - ), "value must be a ValueNumber" - assert ( - isinstance(timing, ValueNumber | Timing) or timing is None - ), "timing must be a ValueNumber" - - return cls( - concept=Concept(**data["concept"]), - value=value, - timing=cast(Timing, timing), - ) diff --git a/execution_engine/omop/criterion/visit_detail.py b/execution_engine/omop/criterion/visit_detail.py index de5f74d2..66bacdde 100644 --- a/execution_engine/omop/criterion/visit_detail.py +++ b/execution_engine/omop/criterion/visit_detail.py @@ -1,4 +1,6 @@ -from typing import Any +from execution_engine.omop.concepts import Concept +from execution_engine.util.types import Timing +from execution_engine.util.value import Value __all__ = ["VisitDetail"] @@ -13,8 +15,21 @@ class VisitDetail(ContinuousCriterion): visit details may be transfers between units of a hospital or a change of bed. 
We d """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + concept: Concept, + value: Value | None = None, + static: bool | None = None, + timing: Timing | None = None, + value_required: bool | None = None, + ): + super().__init__( + concept=concept, + value=value, + static=static, + timing=timing, + value_required=value_required, + ) # visit concepts are mapped automatically to the visit_occurrence table, so map visit_detail explicitly here self._set_omop_variables_from_domain("visit_detail") diff --git a/execution_engine/omop/criterion/visit_occurrence.py b/execution_engine/omop/criterion/visit_occurrence.py index 34b830ca..f6bab235 100644 --- a/execution_engine/omop/criterion/visit_occurrence.py +++ b/execution_engine/omop/criterion/visit_occurrence.py @@ -1,5 +1,3 @@ -from typing import Any, Dict - from sqlalchemy import select from sqlalchemy.sql import Select @@ -83,19 +81,6 @@ def description(self) -> str: """ return f"{self.__class__.__name__}[]" - def dict(self) -> dict[str, Any]: - """ - Get a JSON representation of the criterion. - """ - return {} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActivePatients": - """ - Create a criterion from a JSON representation. - """ - return cls() - class PatientsActiveDuringPeriod(ActivePatients): """ diff --git a/execution_engine/omop/serializable.py b/execution_engine/omop/serializable.py deleted file mode 100644 index 4376c5d7..00000000 --- a/execution_engine/omop/serializable.py +++ /dev/null @@ -1,96 +0,0 @@ -import json -from abc import ABC, abstractmethod -from typing import Any, Dict, Self - - -class Serializable(ABC): - """ - Base class for serializable objects. - """ - - _id: int | None = None - """ - The id is used in the database tables. - """ - - def set_id(self, value: int) -> None: - """ - Assigns the database ID to the object. This can only be done once. 
- - This ID corresponds to the primary key in the database and is set - when the object is persisted. - - :param value: The database ID assigned to the object. - :raises ValueError: If the ID has already been set. - """ - if self._id is not None: - raise ValueError("Database ID has already been set!") - self._id = value - - @property - def id(self) -> int: - """ - Retrieves the database ID of the object. - - This ID is only available after the object has been stored in the database. - - :return: The database ID, or None if the object has not been stored yet. - """ - if self._id is None: - raise ValueError("Database ID has not been set yet!") - return self._id - - def is_persisted(self) -> bool: - """ - Returns True if the object has been stored in the database. - """ - return self._id is not None - - @abstractmethod - def dict(self) -> dict: - """ - Get a dictionary representation of the object. - """ - raise NotImplementedError() - - @classmethod - @abstractmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create an object from a dictionary. - """ - raise NotImplementedError() - - def json(self) -> bytes: - """ - Get a JSON representation of the object. - - The json excludes the id, as this is auto-inserted by the database - and not known during the creation of the object. - """ - - s_json = self.dict() - - return json.dumps(s_json, sort_keys=True).encode() - - @classmethod - def from_json(cls, data: str | bytes) -> Self: - """ - Create a combination from a JSON string. - """ - return cls.from_dict(json.loads(data)) - - def __eq__(self, other: Any) -> bool: - """ - Check if two objects are equal. - """ - if not isinstance(other, self.__class__): - return False - - return self.dict() == other.dict() - - def __hash__(self) -> int: - """ - Get the hash of the object. 
- """ - return hash(self.__class__.__name__.encode() + self.json()) diff --git a/execution_engine/omop/sqlclient.py b/execution_engine/omop/sqlclient.py index eace72ae..b5488029 100644 --- a/execution_engine/omop/sqlclient.py +++ b/execution_engine/omop/sqlclient.py @@ -4,7 +4,7 @@ import pandas as pd import sqlalchemy -from sqlalchemy import and_, bindparam, event, func, select, text +from sqlalchemy import NullPool, and_, bindparam, event, func, select, text from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.pool import ConnectionPoolEntry from sqlalchemy.sql import Insert, Select @@ -83,6 +83,7 @@ def __init__( result_schema: str, timezone: str = "Europe/Berlin", disable_triggers: bool = False, + null_pool: bool = False, ) -> None: """Initialize the OMOP SQL client.""" @@ -97,6 +98,9 @@ def __init__( connection_string, connect_args={"options": "-csearch_path={}".format(self._data_schema)}, future=True, + poolclass=( + NullPool if null_pool else None + ), # <--- ensures no persistent pool ) if disable_triggers: diff --git a/execution_engine/omop/vocabulary.py b/execution_engine/omop/vocabulary.py index 8c212ee1..baa7d1b3 100644 --- a/execution_engine/omop/vocabulary.py +++ b/execution_engine/omop/vocabulary.py @@ -7,6 +7,7 @@ OMOP_INTENSIVE_CARE = 32037 OMOP_INPATIENT_VISIT = 9201 OMOP_OUTPATIENT_VISIT = 9202 +OMOP_SURGICAL_PROCEDURE = 4301351 # OMOP surgical procedure class VocabularyNotFoundError(Exception): @@ -110,6 +111,15 @@ class ICD10GM(AbstractStandardVocabulary): omop_vocab_name = "ICD10GM" +class ICD10CM(AbstractStandardVocabulary): + """ + ICD10 Clinical Modification + """ + + system_uri = "http://hl7.org/fhir/sid/icd-10-cm" + omop_vocab_name = "ICD10CM" + + class UCUM(AbstractStandardVocabulary): """ UCUM vocabulary. 
@@ -192,7 +202,7 @@ class CODEXCELIDA(AbstractVocabulary): vocab_id = "CODEX_CELIDA" map = { "tvpibw": CustomConcept( - name="Tidal volume / ideal body weight (ARDSnet)", + concept_name="Tidal volume / ideal body weight (ARDSnet)", concept_code="tvpibw", domain_id="Measurement", vocabulary_id=vocab_id, @@ -232,6 +242,7 @@ def init(self) -> None: self.register(UCUM) self.register(ATCDE) self.register(ICD10GM) + self.register(ICD10CM) self.register(CODEXCELIDA) def register(self, vocabulary: Type[AbstractVocabulary]) -> None: diff --git a/execution_engine/task/creator.py b/execution_engine/task/creator.py index 1da50cda..72a1f223 100644 --- a/execution_engine/task/creator.py +++ b/execution_engine/task/creator.py @@ -1,11 +1,96 @@ -from typing import cast +import pickle # nosec import networkx as nx +from typing_extensions import Any -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.task.task import Task +def assert_pickle_roundtrip(obj: logic.BaseExpr) -> None: + """ + Serializes 'obj' via pickle (the same method multiprocessing would use), + then deserializes it, and finally compares the original object to the result. + + :param obj: The object to serialize/deserialize. + :raises AssertionError: If the object does not match its clone. + :return: The deserialized clone (for further inspection if needed). + """ + pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) # nosec + clone = pickle.loads(pickled) # nosec + + if isinstance(obj, logic.CountOperator): + if obj.dict()["data"]["threshold"] is None: + raise AssertionError("Threshold must be set") + + if obj == clone: + return # They are considered equal, so nothing to do. + + # If they're unequal, compare their dict() representations to find differences. + d1 = obj.dict() + d2 = clone.dict() + + if d1 == d2: + # If they're unequal but dicts are the same, there's some internal difference + # not visible via .dict(). 
Just alert that we can't show details. + raise AssertionError( + f"Objects differ in __eq__, but their dict() representations are identical.\n" + f"obj: {obj}\n" + f"clone: {clone}" + ) + + # Otherwise, gather all leaf-level differences in d1 vs d2. + diffs = _compare_dicts_leaf_level(d1, d2) + diff_msg = "\n".join(diffs) + + raise AssertionError( + f"Object does not match its clone after round-trip!\n\n" + f"Differences at leaf level in .dict() representations:\n{diff_msg}" + ) + + +def _compare_dicts_leaf_level(d1: Any, d2: Any, path: str = "") -> list[str]: + """ + Recursively compare two dict/list/tuple/scalar structures and return + a list of strings describing differences at the leaf level. + + :param d1, d2: Potentially nested structures (dict, list, tuple, scalar). + :param path: Path string to locate the current point in the structure. + :return: List of difference descriptions. + """ + differences = [] + + # If both are dicts, recurse into matching keys + if isinstance(d1, dict) and isinstance(d2, dict): + all_keys = set(d1.keys()) | set(d2.keys()) + for key in sorted(all_keys): + sub_path = f"{path}.{key}" if path else str(key) + if key not in d1: + differences.append(f"[MISSING IN ORIGINAL] {sub_path} => {d2[key]!r}") + elif key not in d2: + differences.append(f"[MISSING IN CLONE] {sub_path} => {d1[key]!r}") + else: + differences.extend( + _compare_dicts_leaf_level(d1[key], d2[key], sub_path) + ) + + # If both are lists/tuples, compare element by element + elif isinstance(d1, (list, tuple)) and isinstance(d2, (list, tuple)): + if len(d1) != len(d2): + differences.append(f"[LEN MISMATCH] {path} => {len(d1)} vs {len(d2)}") + else: + for i, (item1, item2) in enumerate(zip(d1, d2)): + sub_path = f"{path}[{i}]" + differences.extend(_compare_dicts_leaf_level(item1, item2, sub_path)) + + # Otherwise, treat them as leaf values and compare directly + else: + if d1 != d2: + differences.append(f"[VALUE MISMATCH] {path} => {d1!r} vs {d2!r}") + + return differences 
+ + class TaskCreator: """ A TaskCreator object creates a Task tree for an expression and its dependencies. @@ -23,13 +108,12 @@ def create_tasks_and_dependencies(graph: nx.DiGraph) -> list[Task]: """ def node_to_task(expr: logic.Expr, attr: dict) -> Task: - criterion = cast(logic.Symbol, expr).criterion if expr.is_Atom else None store_result = attr.get("store_result", False) - bind_params = attr.get("bind_params", {}) + bind_params = attr.get("bind_params", {}).copy() + bind_params["category"] = attr["category"] task = Task( expr=expr, - criterion=criterion, bind_params=bind_params, store_result=store_result, ) @@ -46,6 +130,12 @@ def node_to_task(expr: logic.Expr, attr: dict) -> Task: flattened_tasks = list(tasks.values()) + # we will make sure all tasks are depickled correctly [commented out for performance reasons] + # from tqdm import tqdm + # + # for node in tqdm(tasks): + # assert_pickle_roundtrip(node) + assert ( len(set(flattened_tasks)) == len(flattened_tasks) diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index 72b4589c..b9e4444e 100644 --- a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -15,6 +15,7 @@ def get_processing_module( - rectangle (faster, using rectangles intersection/union) :param name: name of the processing module + :param version: version of the processing module """ if name == "rectangle": @@ -38,4 +39,6 @@ def get_processing_module( Interval = namedtuple("Interval", ["lower", "upper", "type"]) IntervalWithCount = namedtuple("IntervalWithCount", ["lower", "upper", "type", "count"]) -IntervalWithTypeCounts = namedtuple("IntervalWithTypeCounts", ["lower", "upper", "counts"]) +IntervalWithTypeCounts = namedtuple( + "IntervalWithTypeCounts", ["lower", "upper", "counts"] +) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index b3546180..91bef1e8 100644 --- 
a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -479,6 +479,7 @@ def filter_count_intervals( min_count: int | None, max_count: int | None, keep_no_data: bool = True, + keep_not_applicable: bool = True, ) -> PersonIntervals: """ Filters the intervals per dict key in the list by count. @@ -487,14 +488,18 @@ def filter_count_intervals( :param min_count: The minimum count of the intervals. :param max_count: The maximum count of the intervals. :param keep_no_data: Whether to keep NO_DATA intervals (irrespective of the count). + :param keep_not_applicable: Whether to keep NOT_APPLICABLE intervals (irrespective of the count). :return: A dict with the unioned intervals. """ result: PersonIntervals = {} interval_filter = [] + if keep_no_data: interval_filter.append(IntervalType.NO_DATA) + if keep_not_applicable: + interval_filter.append(IntervalType.NOT_APPLICABLE) if min_count is None and max_count is None: raise ValueError("min_count and max_count cannot both be None") @@ -668,6 +673,7 @@ def create_time_intervals( # Prepare to collect intervals intervals = [] previous_end = None + def add_interval(interval_start, interval_end, interval_type): nonlocal previous_end effective_start = max(interval_start, start_datetime) @@ -681,11 +687,13 @@ def add_interval(interval_start, interval_end, interval_type): # touching intervals. if previous_end is not None: assert previous_end < effective_start - intervals.append(Interval( - lower=effective_start.timestamp(), - upper=effective_end.timestamp(), - type=interval_type, - )) + intervals.append( + Interval( + lower=effective_start.timestamp(), + upper=effective_end.timestamp(), + type=interval_type, + ) + ) previous_end = effective_end # Current date to process @@ -714,25 +722,22 @@ def add_interval(interval_start, interval_end, interval_type): # overlaps the main datetime range, otherwise fill the day # with an interval of type "not applicable". 
# TODO: what about intervals "before" the main datetime range? - if end_interval < start_datetime: # completely before datetime range + if end_interval < start_datetime: # completely before datetime range day_start = timezone.localize( - datetime.datetime.combine( - current_date, datetime.time(0, 0, 0) - )) + datetime.datetime.combine(current_date, datetime.time(0, 0, 0)) + ) day_end = timezone.localize( - datetime.datetime.combine( - current_date, datetime.time(23, 59, 59) - )) + datetime.datetime.combine(current_date, datetime.time(23, 59, 59)) + ) if (previous_end is not None) and day_start <= previous_end: start = previous_end + datetime.timedelta(seconds=1) else: start = day_start add_interval(start, day_end, IntervalType.NOT_APPLICABLE) - elif end_datetime < start_interval: # completely after datetime range + elif end_datetime < start_interval: # completely after datetime range day_start = timezone.localize( - datetime.datetime.combine( - current_date, datetime.time(0, 0, 0) - )) + datetime.datetime.combine(current_date, datetime.time(0, 0, 0)) + ) if (previous_end is not None) and day_start <= previous_end: start = previous_end + datetime.timedelta(seconds=1) else: @@ -799,10 +804,15 @@ def find_overlapping_personal_windows( return result + def find_rectangles_with_count(data: list[PersonIntervals]) -> PersonIntervals: if len(data) == 0: return {} else: keys = data[0].keys() - return {key: _impl.find_rectangles_with_count([ intervals[key] for intervals in data ]) - for key in keys} + return { + key: _impl.find_rectangles_with_count( + [intervals[key] for intervals in data] + ) + for key in keys + } diff --git a/execution_engine/task/runner.py b/execution_engine/task/runner.py index 9379fb9a..af0ea25d 100644 --- a/execution_engine/task/runner.py +++ b/execution_engine/task/runner.py @@ -191,7 +191,6 @@ def run(self, bind_params: dict) -> None: try: while len(self.completed_tasks) < len(self.tasks): self.enqueue_ready_tasks() - 
logging.info(f"{len(self.completed_tasks)}/{len(self.tasks)} tasks") if self.queue.empty() and not any( task.status == TaskStatus.RUNNING for task in self.tasks @@ -311,13 +310,14 @@ def task_executor_worker() -> None: self.start_workers(task_executor_worker) + task_names = {task.name() for task in self.tasks} + try: while len(self.completed_tasks) < len(self.tasks): if self.stop_event.is_set(): break self.enqueue_ready_tasks() - logging.info(f"{len(self.completed_tasks)}/{len(self.tasks)} tasks") if self.completed_tasks == self.enqueued_tasks and len( self.completed_tasks @@ -331,7 +331,17 @@ def task_executor_worker() -> None: with self.lock: # Update the set of completed tasks + n_completed = len(self.completed_tasks) self.completed_tasks = set(self.shared_results.keys()) + if len(self.completed_tasks) > n_completed: + logging.info( + f"Completed {len(self.completed_tasks)} of {len(self.tasks)} tasks" + ) + if not all(task in task_names for task in self.completed_tasks): + raise TaskError( + "Completed tasks differ from actual tasks " + "- problem with pickling/unpickling in multiprocessing?" 
+ ) except Exception as e: logging.error(f"An error occurred: {e}") diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 0968fbda..daf28fa1 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -6,14 +6,14 @@ from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError -import execution_engine.util.cohort_logic as logic +import execution_engine.util.logic as logic from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.temporal import TimeIntervalType from execution_engine.omop.db.celida.tables import ResultInterval from execution_engine.omop.sqlclient import OMOPSQLClient from execution_engine.settings import get_config from execution_engine.task.process import Interval, get_processing_module +from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType from execution_engine.util.types import PersonIntervals, TimeRange @@ -27,6 +27,7 @@ def get_engine() -> OMOPSQLClient: return OMOPSQLClient( **get_config().omop.model_dump(by_alias=True), timezone=get_config().timezone, + null_pool=True, ) @@ -54,13 +55,11 @@ class Task: def __init__( self, - expr: logic.Expr, - criterion: Criterion | None, + expr: logic.BaseExpr, bind_params: dict | None, store_result: bool = False, ) -> None: self.expr = expr - self.criterion = criterion self.dependencies: list[Task] = [] self.status = TaskStatus.PENDING self.bind_params = bind_params if bind_params is not None else {} @@ -71,7 +70,7 @@ def category(self) -> CohortCategory: """ Returns the category of the task. 
""" - return self.expr.category + return self.bind_params["category"] def get_base_task(self) -> "Task": """ @@ -94,6 +93,32 @@ def find_base_task(task: Task) -> Task: return find_base_task(self) + def select_predecessor_result( + self, expr: logic.BaseExpr, data: list[PersonIntervals] + ) -> PersonIntervals: + """ + Select the result results of the predecessor task from the given expression. + + This is required in expressions where order is important, e.g. in BinaryNonCommutativeOperator. + As the nx.DiGraph (and by inheritance, ExecutionGraph) does not store the order of the predecessors, + we need to find the predecessor task by its expression and select the result from the data. + + :param expr: The expression of the predecessor task. + :param data: The input data. + :return: The result of the predecessor task. + """ + if len(self.dependencies) == 0: + raise ValueError("Task has no dependencies.") + + idx = next((i for i, t in enumerate(self.dependencies) if t.expr == expr), None) + + if idx is None: + raise ValueError( + f"Task with expression '{str(expr)}' not found in dependencies." + ) + + return data[idx] + def run( self, data: list[PersonIntervals], @@ -125,13 +150,12 @@ def run( if len(self.dependencies) == 0 or self.expr.is_Atom: # atomic expressions (i.e. 
criterion) - assert ( - self.criterion is not None - ), "criterion shall not be None for atomic expression" - logging.debug(f"Running criterion - '{self.name()}'") + + assert isinstance(self.expr, Criterion), "Dependency is not a Criterion" + result = self.handle_criterion( - self.criterion, bind_params, base_data, observation_window + self.expr, bind_params, base_data, observation_window ) logging.debug(f"Storing results - '{self.name()}'") @@ -149,41 +173,28 @@ def run( result = self.handle_unary_logical_operator( data, base_data, observation_window ) + elif isinstance(self.expr, logic.TemporalCount): + result = self.handle_temporal_operator(data, observation_window) elif isinstance( self.expr, - ( - logic.And, - logic.Or, - logic.NonSimplifiableAnd, - logic.Count, - logic.CappedCount, - logic.AllOrNone, - ), + (logic.CommutativeOperator), ): result = self.handle_binary_logical_operator(data) - elif isinstance( - self.expr, (logic.LeftDependentToggle, logic.ConditionalFilter) - ): + elif isinstance(self.expr, logic.BinaryNonCommutativeOperator): result = self.handle_left_dependent_toggle( - left=data[0], - right=data[1], + left=self.select_predecessor_result(self.expr.left, data), + right=self.select_predecessor_result(self.expr.right, data), base_data=base_data, observation_window=observation_window, ) - elif isinstance(self.expr, logic.NoDataPreservingAnd): - result = self.handle_no_data_preserving_operator( - data, base_data, observation_window - ) - elif isinstance(self.expr, logic.NoDataPreservingOr): - result = self.handle_no_data_preserving_operator( - data, base_data, observation_window - ) - elif isinstance(self.expr, logic.TemporalCount): - result = self.handle_temporal_operator(data, observation_window) else: raise ValueError(f"Unsupported expression type: {type(self.expr)}") if self.store_result: + if not self.expr.is_Atom: + result = self.insert_negative_intervals( + result, base_data, observation_window + ) logging.debug(f"Storing results - 
'{self.name()}'") self.store_result_in_db(result, base_data, bind_params) @@ -269,7 +280,7 @@ def handle_binary_logical_operator( self, data: list[PersonIntervals] ) -> PersonIntervals: """ - Handles a binary logical operator (And or Or) by merging or intersecting the intervals of the + Handles a binary logical operator by using the appropriate processing function. :param data: The input data. :return: A DataFrame with the merged or intersected intervals. @@ -284,7 +295,7 @@ def handle_binary_logical_operator( if isinstance(self.expr, (logic.And, logic.NonSimplifiableAnd)): result = process.intersect_intervals(data) - elif isinstance(self.expr, logic.Or): + elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): result = process.union_intervals(data) elif isinstance(self.expr, logic.Count): result = process.count_intervals(data) @@ -326,48 +337,6 @@ def handle_binary_logical_operator( return result - def handle_no_data_preserving_operator( - self, - data: list[PersonIntervals], - base_data: PersonIntervals, - observation_window: TimeRange, - ) -> PersonIntervals: - """ - Handles a NoDataPreservingAnd/Or operator. - - These are used to combine POPULATION, INTERVENTION and POPULATION/INTERVENTION results from different - population/intervention pairs into a single result (i.e. the full recommendation's POPULATION etc.). - - The POSITIVE intervals are intersected (And) or merged (Or), the NO_DATA intervals are intersected and the - remaining intervals are set to NEGATIVE. - - :param data: The input data. - :param base_data: The result of the base criterion. - :param observation_window: The observation window. - :return: A DataFrame with the merged intervals. - """ - assert isinstance( - self.expr, (logic.NoDataPreservingAnd, logic.NoDataPreservingOr) - ), "Dependency is not a NoDataPreservingAnd / NoDataPreservingOr expression." 
- - if isinstance(self.expr, logic.NoDataPreservingAnd): - result = process.intersect_intervals(data) - elif isinstance(self.expr, logic.NoDataPreservingOr): - result = process.union_intervals(data) - - # todo: the only difference between this function and handle_binary_logical_operator is the following lines - # - can we merge? - result_negative = process.complementary_intervals( - result, - reference=base_data, - observation_window=observation_window, - interval_type=IntervalType.NEGATIVE, - ) - - result = process.concat_intervals([result, result_negative]) - - return result - def handle_left_dependent_toggle( self, left: PersonIntervals, @@ -496,7 +465,11 @@ def get_start_end_from_interval_type( assert isinstance(self.expr, logic.TemporalCount), "Invalid expression type" if self.expr.interval_criterion is not None: + # last element is the indicator windows + assert ( + len(data) >= 2 + ), "TemporalCount with indicator criterion requires at least two inputs" data, indicator_personal_windows = data[:-1], data[-1] result = process.find_overlapping_personal_windows( @@ -537,6 +510,35 @@ def get_start_end_from_interval_type( return result + def insert_negative_intervals( + self, + data: PersonIntervals, + base_data: PersonIntervals, + observation_window: TimeRange, + ) -> PersonIntervals: + """ + Inserts negative intervals into the result. + + Usually, negative intervals are implicit. This functions fills all gaps between other intervals with negative + intervals. + + :param data: The input data. + :param base_data: The result of the base criterion. + :param observation_window: The observation window. + :return: A DataFrame with the merged intervals. 
+ """ + + data_negative = process.complementary_intervals( + data, + reference=base_data, + observation_window=observation_window, + interval_type=IntervalType.NEGATIVE, + ) + + result = process.concat_intervals([data, data_negative]) + + return result + def store_result_in_db( self, result: PersonIntervals, @@ -565,7 +567,7 @@ def store_result_in_db( return pi_pair_id = bind_params.get("pi_pair_id", None) - criterion_id = self.criterion.id if self.expr.is_Atom else None # type: ignore # when expr.is_Atom, criterion is not None + criterion_id = self.expr.id if self.expr.is_Atom else None # type: ignore # when expr.is_Atom, criterion is not None if self.expr.is_Atom: assert pi_pair_id is None, "pi_pair_id shall be None for criterion" @@ -625,13 +627,16 @@ def name(self) -> str: Uniqueness is guaranteed by prepending the base64-encoded hash of the Task object. """ - return f"[{self.id()}] {str(self)}" + if self.expr.is_Atom: + return f"[{self.id()}] {str(self)}" + else: + return f"[{self.id()}] {self.expr.__class__.__name__}()" def id(self) -> str: """ Returns the id of the Task object. """ - hash_value = hash((str(self.expr), json.dumps(self.bind_params))) + hash_value = hash((self.expr, json.dumps(self.bind_params))) # Determine the number of bytes needed. Python's hash returns a value based on the platform's pointer size. # It's 8 bytes for 64-bit systems and 4 bytes for 32-bit systems. @@ -653,7 +658,4 @@ def __repr__(self) -> str: """ Returns a string representation of the Task object. 
""" - if self.expr.is_Atom: - return f"Task(criterion={self.expr}, category={self.expr.category})" - else: - return f"Task({self.expr}), category={self.expr.category})" + return f"Task({self.expr}), category={self.category})" diff --git a/execution_engine/util/__init__.py b/execution_engine/util/__init__.py index 67cf1f92..92a8cbae 100644 --- a/execution_engine/util/__init__.py +++ b/execution_engine/util/__init__.py @@ -1,7 +1,20 @@ +import datetime from abc import ABCMeta from typing import Any +def datetime_converter(obj: Any) -> Any: + """ + Convert datetime objects to ISO format strings. + + Used in json.dumps() to serialize datetime objects. + """ + if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): + return obj.isoformat() + + return obj + + class AbstractPrivateMethods(ABCMeta): """ A metaclass that prevents overriding of methods decorated with @typing.final. diff --git a/execution_engine/util/cohort_logic.py b/execution_engine/util/cohort_logic.py deleted file mode 100644 index d07e784c..00000000 --- a/execution_engine/util/cohort_logic.py +++ /dev/null @@ -1,863 +0,0 @@ -from abc import ABC, abstractmethod -from datetime import time -from typing import Any, Callable, cast - -from execution_engine.constants import CohortCategory -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.temporal import TimeIntervalType - - -class BaseExpr(ABC): - """ - Base class for expressions and symbols, defining common properties. - """ - - args: tuple - - @classmethod - def _recreate(cls, args: Any, kwargs: dict) -> "Expr": - """ - Recreate an expression from its arguments and category. - """ - return cast(Expr, cls(*args, **kwargs)) - - def __new__(cls, *args: Any, **kwargs: Any) -> "BaseExpr": - """ - Initialize a new instance of the class. 
- """ - new_self = super().__new__(cls) - - # we must not allow the __init__ function because of possible infinite recursion when using the __new__ function - # (see https://pdarragh.github.io/blog/2017/05/22/oddities-in-pythons-new-method/) - if "__init__" in cls.__dict__: - raise AttributeError( - f"__init__ is not allowed in subclass {cls.__name__} of BaseExpr" - ) - - return new_self - - @abstractmethod - def __eq__(self, other: Any) -> bool: - """ - Check if this expression is equal to another expression. - """ - raise NotImplementedError("__eq__ must be implemented by subclasses") - - @abstractmethod - def __hash__(self) -> int: - """ - Get the hash of this expression. - """ - raise NotImplementedError("__hash__ must be implemented by subclasses") - - @property - def is_Atom(self) -> bool: - """ - Check if the object is an atom (not divisible into smaller parts). - - :return: True if atom, False otherwise. To be overridden in subclasses. - """ - raise NotImplementedError("is_Atom must be implemented by subclasses") - - @property - def is_Not(self) -> bool: - """ - Check if the object is a Not type. - - :return: True if Not type, False otherwise. To be overridden in subclasses. - """ - raise NotImplementedError("is_Not must be implemented by subclasses") - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return self._recreate, (self.args, self.get_instance_variables()) - - def get_instance_variables(self, immutable: bool = False) -> dict | tuple: - """ - Return all instance variables of the object. - - If immutable is True, return as an immutable tuple of key-value pairs. - If immutable is False, return as a mutable dictionary. 
- """ - instance_vars = { - key: value - for key, value in vars(self).items() - if not key.startswith("_") # Exclude private or special attributes - and key != "args" - } - - if immutable: - return tuple(sorted(instance_vars.items())) - else: - return instance_vars - - -class Expr(BaseExpr): - """ - Class for expressions that require a category. - """ - - category: CohortCategory - - def __new__(cls, *args: Any, category: CohortCategory) -> "Expr": - """ - Initialize an expression with given arguments and a mandatory category. - - :param args: Arguments for the expression. - :param category: Mandatory category of the expression. - """ - self = cast(Expr, super().__new__(cls, *args)) - self.args = args - self.category = category - - return self - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - return f"{self.__class__.__name__}({', '.join(map(repr, self.args))}, category='{self.category}')" - - def __eq__(self, other: Any) -> bool: - """ - Check if this expression is equal to another expression. - - :param other: The other expression. - :return: True if equal, False otherwise. - """ - return isinstance(other, self.__class__) and hash(self) == hash(other) - - def __hash__(self) -> int: - """ - Get the hash of this expression. - - :return: Hash of the expression. - """ - return hash( - (self.__class__, self.args, self.get_instance_variables(immutable=True)) - ) - - @property - def is_Atom(self) -> bool: - """ - Check if the expression is an atom. Returns False for general expressions. - - :return: False for Expr. - """ - return False - - @property - def is_Not(self) -> bool: - """ - Check if the expression is a Not type. - - :return: True if Not type, False otherwise. - """ - return isinstance(self, Not) - - -class Symbol(BaseExpr): - """ - Class representing a symbolic variable. 
- """ - - criterion: Criterion - category: CohortCategory - - def __new__(cls, criterion: Criterion, category: CohortCategory) -> "Symbol": - """ - Initialize a symbol. - - :param criterion: The criterion of the symbol. - """ - self = cast(Symbol, super().__new__(cls)) - self.args = () - self.criterion = criterion - self.category = category - - return self - - def __eq__(self, other: Any) -> bool: - """ - Check if this symbol is equal to another symbol. - - :param other: The other symbol. - :return: True if equal, False otherwise. - """ - return isinstance(other, Symbol) and self.criterion == other.criterion - - def __hash__(self) -> int: - """ - Get the hash of this symbol. - - :return: Hash of the symbol. - """ - return hash(self.criterion) - - def __repr__(self) -> str: - """ - Represent the symbol. - - :return: Name of the symbol. - """ - return self.criterion.description() - - @property - def is_Atom(self) -> bool: - """ - Check if the Symbol is an atom. Always returns True for Symbol. - - :return: True as Symbol is always an atom. - """ - return True - - @property - def is_Not(self) -> bool: - """ - Check if the Symbol is a Not type. Always returns False for Symbol. - - :return: False as Symbol is never a Not type. - """ - return False - - -class BooleanFunction(Expr): - """ - Base class for boolean functions like OR, AND, and NOT. - """ - - _repr_join_str: str | None = None - - def __eq__(self, other: Any) -> bool: - """ - Check if this operator is equal to another operator. - - :param other: The other operator. - :return: True if equal, False otherwise. - """ - return ( - isinstance(other, self.__class__) - and self.args == other.args - and self.get_instance_variables(immutable=True) - == other.get_instance_variables(immutable=True) - ) - - # Needs to be defined again (although it is the same as in Expr) because we define __eq__ here - def __hash__(self) -> int: - """ - Get the hash of this operator. - - :return: Hash of the operator. 
- """ - return super().__hash__() - - @property - def is_Atom(self) -> bool: - """ - Boolean functions are not atoms. - - :return: False - """ - return False - - @property - def is_Not(self) -> bool: - """ - Check if the BooleanFunction is a Not type. - - :return: True if Not type, False otherwise. - """ - return isinstance(self, Not) - - def __repr__(self) -> str: - """ - Represent the BooleanFunction in a readable format. - """ - if self._repr_join_str is not None: - return "(" + f" {self._repr_join_str} ".join(map(repr, self.args)) + ")" - else: - return super().__repr__() - - -class Or(BooleanFunction): - """ - Class representing a logical OR operation. - """ - - _repr_join_str = "|" - - def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: - """ - Create a new Or object. - """ - if len(args) == 1 and isinstance(args[0], BaseExpr): - return args[0] - - return super().__new__(cls, *args, **kwargs) - - -class And(BooleanFunction): - """ - Class representing a logical AND operation. - """ - - _repr_join_str = "&" - - def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: - """ - Create a new And object. - """ - if len(args) == 1 and isinstance(args[0], BaseExpr): - return args[0] - - return super().__new__(cls, *args, **kwargs) - - -class Not(BooleanFunction): - """ - Class representing a logical NOT operation. - """ - - def __repr__(self) -> str: - """ - Represent the NOT operation as a string. - """ - return f"~{self.args[0]}" - - def __new__(cls, *args: Any, **kwargs: Any) -> "Not": - """ - Create a new Or object. - """ - if len(args) > 1: - raise ValueError("Not can only have one argument") - - return cast(Not, super().__new__(cls, *args, **kwargs)) - - -class Count(BooleanFunction, ABC): - """ - Class representing a logical COUNT operation. - - Adds a "threshold" parameter of type int. - - This class should not be instantiated directly, but rather through one of its subclasses. 
- """ - - count_min: int | None = None - count_max: int | None = None - - -class MinCount(Count): - """ - Class representing a logical MIN_COUNT operation. - """ - - def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MinCount": - """ - Create a new MinCount object. - """ - self = cast(MinCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return ( - self._recreate, - (self.args, {"category": self.category, "threshold": self.count_min}), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" - - -class MaxCount(Count): - """ - Class representing a logical MAX_COUNT operation. - """ - - def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MaxCount": - """ - Create a new MaxCount object. - """ - self = cast(MaxCount, super().__new__(cls, *args, **kwargs)) - self.count_max = threshold - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return ( - self._recreate, - (self.args, {"category": self.category, "threshold": self.count_max}), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - return f"{self.__class__.__name__}(threshold={self.count_max}; {', '.join(map(repr, self.args))}, category='{self.category}')" - - -class ExactCount(Count): - """ - Class representing a logical EXACT_COUNT operation. 
- """ - - def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "ExactCount": - """ - Create a new ExactCount object. - """ - self = cast(ExactCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - self.count_max = threshold - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return ( - self._recreate, - (self.args, {"category": self.category, "threshold": self.count_min}), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" - - -class CappedCount(BooleanFunction, ABC): - """ - Base class representing a COUNT operation with an upper cap. - - This class distinguishes COUNT operations that are subject to an implicit - maximum constraint, ensuring that they do not exceed what is achievable - given external limitations. - - Unlike regular COUNT operations, the threshold in this class is not assumed - to be unbounded. However, no explicit handling of the maximum occurs here; - it is enforced externally. - - This class should not be instantiated directly but used as a base for specific - capped count operations like CappedMinCount. - """ - - count_min: int | None = None - count_max: int | None = None - - -class CappedMinCount(CappedCount): - """ - Class representing a MIN_COUNT operation with an implicit upper cap. - - This behaves like MinCount but acknowledges that the minimum required count - is subject to an external upper constraint. If the requested threshold exceeds - what is achievable, the actual threshold will be limited to the maximum possible - count, which is determined externally. 
- - The enforcement of this cap does not occur within this class; rather, it is - expected to be handled by the surrounding logic. - - The threshold parameter defines the minimum number of overlapping intervals - required, but in practice, it will not exceed the externally imposed cap. - """ - - def __new__( - cls, *args: Any, threshold: int | None, **kwargs: Any - ) -> "CappedMinCount": - """ - Create a new CappedMinCount object. - """ - self = cast(CappedMinCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g., when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return ( - self._recreate, - (self.args, {"category": self.category, "threshold": self.count_min}), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - return f"{self.__class__.__name__}(threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" - - -class AllOrNone(BooleanFunction): - """ - Class representing a logical ALL_OR_NONE operation. - """ - - -class TemporalCount(BooleanFunction, ABC): - """ - Class representing a logical COUNT operation. - - Adds a "threshold" parameter of type int. - - This class should not be instantiated directly, but rather through one of its subclasses. - """ - - count_min: int | None = None - count_max: int | None = None - start_time: time | None = None - end_time: time | None = None - interval_type: TimeIntervalType | None = None - interval_criterion: BaseExpr | None = None - - -class TemporalMinCount(TemporalCount): - """ - Class representing a logical temporal MIN_COUNT operation. 
- """ - - def __new__( - cls, - *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, - **kwargs: Any, - ) -> "TemporalMinCount": - """ - Create a new MinCount object. - """ - self = cast(TemporalMinCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - self.start_time = start_time - self.end_time = end_time - self.interval_type = interval_type - self.interval_criterion = interval_criterion - - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return ( - self._recreate, - ( - self.args, - { - "category": self.category, - "threshold": self.count_min, - "start_time": self.start_time, - "end_time": self.end_time, - "interval_type": self.interval_type, - "interval_criterion": self.interval_criterion, - }, - ), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - - if self.start_time is not None and self.end_time is not None: - interval = f"{self.start_time} - {self.end_time}" - elif self.interval_type is not None: - interval = self.interval_type.name - elif self.interval_criterion is not None: - interval = repr(self.interval_criterion) - else: - interval = "None" - - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" - - -class TemporalMaxCount(TemporalCount): - """ - Class representing a logical MAX_COUNT operation. - """ - - def __new__( - cls, - *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, - **kwargs: Any, - ) -> "TemporalMaxCount": - """ - Create a new MaxCount object. 
- """ - self = cast(TemporalMaxCount, super().__new__(cls, *args, **kwargs)) - self.count_max = threshold - self.start_time = start_time - self.end_time = end_time - self.interval_type = interval_type - self.interval_criterion = interval_criterion - - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return ( - self._recreate, - ( - self.args, - { - "category": self.category, - "threshold": self.count_max, - "start_time": self.start_time, - "end_time": self.end_time, - "interval_type": self.interval_type, - "interval_criterion": self.interval_criterion, - }, - ), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - - if self.start_time is not None and self.end_time is not None: - interval = f"{self.start_time} - {self.end_time}" - elif self.interval_type is not None: - interval = self.interval_type.name - elif self.interval_criterion is not None: - interval = repr(self.interval_criterion) - else: - interval = "None" - - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_max}; {', '.join(map(repr, self.args))}, category='{self.category}')" - - -class TemporalExactCount(TemporalCount): - """ - Class representing a logical EXACT_COUNT operation. - """ - - def __new__( - cls, - *args: Any, - threshold: int | None, - start_time: time | None, - end_time: time | None, - interval_type: TimeIntervalType | None, - interval_criterion: BaseExpr | None, - **kwargs: Any, - ) -> "TemporalExactCount": - """ - Create a new ExactCount object. 
- """ - self = cast(TemporalExactCount, super().__new__(cls, *args, **kwargs)) - self.count_min = threshold - self.count_max = threshold - self.start_time = start_time - self.end_time = end_time - self.interval_type = interval_type - self.interval_criterion = interval_criterion - - return self - - def __reduce__(self) -> tuple[Callable, tuple]: - """ - Reduce the expression to its arguments and category. - - Required for pickling (e.g. when using multiprocessing). - - :return: Tuple of the class, arguments, and category. - """ - return ( - self._recreate, - ( - self.args, - { - "category": self.category, - "threshold": self.count_min, - "start_time": self.start_time, - "end_time": self.end_time, - "interval_type": self.interval_type, - "interval_criterion": self.interval_criterion, - }, - ), - ) - - def __repr__(self) -> str: - """ - Represent the expression in a readable format. - """ - - if self.start_time is not None and self.end_time is not None: - interval = f"{self.start_time} - {self.end_time}" - elif self.interval_type is not None: - interval = self.interval_type.name - elif self.interval_criterion is not None: - interval = repr(self.interval_criterion) - else: - interval = "None" - - return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min}; {', '.join(map(repr, self.args))}, category='{self.category}')" - - -class NonSimplifiableAnd(BooleanFunction): - """ - A NonSimplifiableAnd object represents a logical AND operation that cannot be simplified. - - Simplified here means that when this operator is used on a single argument, still this operator is returned - instead of the argument itself, as is the case with the sympy.And operator. - - The reason for this operator is that if there is a single population or intervention criterion, the And/Or - operators would simplify to the criterion itself. 
In that case, the _whole_ population or intervention expression of - the respective population/intervention pair, which should be written to the database with criterion_id = None, would - be lost (i.e. not written, because there is no graph execution node that would perform this. This operator prevents - that. - """ - - def __eq__(self, other: Any) -> bool: - """ - Check if this operator is equal to another operator. - - This always yields false to prevent combination of two NonSimplifiableAnd operators when merging a graph. - """ - return False - - def __hash__(self) -> int: - """ - Get the hash of this operator. - - Required because __eq__ yields always False -- and we need distinct hashes for distinct objects, as this - operator should not be merged. - """ - return id(self) - - def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableAnd": - """ - Create a new NonSimplifiableAnd object. - """ - return cast(NonSimplifiableAnd, super().__new__(cls, *args, **kwargs)) - - -# todo: can we rename to more meaningful name? -class NoDataPreservingAnd(BooleanFunction): - """ - A And object represents a logical AND operation. - - See Task.handle_no_data_preserving_operator for the rules. Currently, the only difference between this operator - and the And operator is that during handling of this operator, the negative intervals are added explicitly. - """ - - def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingAnd": - """ - Create a new NoDataPreservingAnd object. - """ - return cast(NoDataPreservingAnd, super().__new__(cls, *args, **kwargs)) - - -class NoDataPreservingOr(BooleanFunction): - """ - A Or object represents a logical OR operation. - - See Task.handle_no_data_preserving_operator for the rules. Currently, the only difference between this operator - and the And operator is that during handling of this operator, the negative intervals are added explicitly. 
- """ - - def __new__(cls, *args: Any, **kwargs: Any) -> "NoDataPreservingOr": - """ - Create a new NoDataPreservingOr object. - """ - return cast(NoDataPreservingOr, super().__new__(cls, *args, **kwargs)) - - -class LeftDependentToggle(BooleanFunction): - """ - A LeftDependentToggle object represents a logical AND operation if the left operand is positive, - otherwise it returns NOT_APPLICABLE. - """ - - def __new__( - cls, left: BaseExpr, right: BaseExpr, **kwargs: Any - ) -> "LeftDependentToggle": - """ - Initialize a LeftDependentToggle object. - """ - return cast(LeftDependentToggle, super().__new__(cls, left, right, **kwargs)) - - @property - def left(self) -> Expr: - """Returns the left operand""" - return self.args[0] - - @property - def right(self) -> Expr: - """Returns the right operand""" - return self.args[1] - - -class ConditionalFilter(BooleanFunction): - """ - A ConditionalFilter object returns the right operand if the left operand is POSITIVE, - and NEGATIVE otherwise - """ - - def __new__( - cls, left: BaseExpr, right: BaseExpr, **kwargs: Any - ) -> "ConditionalFilter": - """ - Initialize a ConditionalFilter object. - """ - return cast(ConditionalFilter, super().__new__(cls, left, right, **kwargs)) - - @property - def left(self) -> Expr: - """Returns the left operand""" - return self.args[0] - - @property - def right(self) -> Expr: - """Returns the right operand""" - return self.args[1] diff --git a/execution_engine/util/enum.py b/execution_engine/util/enum.py index a9653fa5..574ed87c 100644 --- a/execution_engine/util/enum.py +++ b/execution_engine/util/enum.py @@ -23,9 +23,15 @@ class TimeUnit(StrEnum): MONTH = "mo" YEAR = "a" + def __repr__(self) -> str: + """ + Get the string representation of the category. + """ + return f"{self.__class__.__name__}.{self.name}" + def __str__(self) -> str: """ - Returns the string representation of the TimeUnit. + Get the string representation of the category. 
""" return self.name @@ -58,3 +64,29 @@ def to_sql_interval_length_seconds(self) -> ColumnElement: return func.cast(func.extract("EPOCH", self.to_sql_interval()), NUMERIC).label( "duration_seconds" ) + + +class TimeIntervalType(StrEnum): + """ + Types of time intervals to aggregate criteria over. + """ + + MORNING_SHIFT = "morning_shift" + AFTERNOON_SHIFT = "afternoon_shift" + NIGHT_SHIFT = "night_shift" + NIGHT_SHIFT_BEFORE_MIDNIGHT = "night_shift_before_midnight" + NIGHT_SHIFT_AFTER_MIDNIGHT = "night_shift_after_midnight" + DAY = "day" + ANY_TIME = "any_time" + + def __repr__(self) -> str: + """ + Get the string representation of the category. + """ + return f"{self.__class__.__name__}.{self.name}" + + def __str__(self) -> str: + """ + Get the string representation of the category. + """ + return self.name diff --git a/execution_engine/util/interval/typed_interval.py b/execution_engine/util/interval/typed_interval.py index 3c2e4eac..aa35696f 100644 --- a/execution_engine/util/interval/typed_interval.py +++ b/execution_engine/util/interval/typed_interval.py @@ -1418,6 +1418,7 @@ def interval_datetime( :param lower: The lower bound. :param upper: The upper bound. + :param type_: The type of the interval. :return: The new datetime interval. """ return DateTimeInterval.from_atomic(Bound.CLOSED, lower, upper, Bound.CLOSED, type_) # type: ignore # mypy expects "IntervalT", not sure why diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py new file mode 100644 index 00000000..0d11f587 --- /dev/null +++ b/execution_engine/util/logic.py @@ -0,0 +1,1032 @@ +from datetime import time +from typing import Any, Callable, Dict, Iterator, Self, cast + +from execution_engine.util.enum import TimeIntervalType +from execution_engine.util.serializable import Serializable, SerializableABC + + +def arg_to_dict(arg: Any, include_id: bool) -> dict: + """ + Convert an argument to a dictionary representation. + + :param arg: The argument to convert. 
+ :param include_id: Whether to include the ID in the dictionary. + :return: Dictionary representation of the argument. + """ + return arg.dict(include_id=include_id) if isinstance(arg, BaseExpr) else arg + + +class BaseExpr(Serializable): + """ + Base class for expressions and symbols, defining common properties. + """ + + args: tuple + + @property + def is_Atom(self) -> bool: + """ + Check if the object is an atom (not divisible into smaller parts). + + :return: True if atom, False otherwise. To be overridden in subclasses. + """ + raise NotImplementedError("is_Atom must be implemented by subclasses") + + @property + def is_Not(self) -> bool: + """ + Check if the object is a Not type. + + :return: True if Not type, False otherwise. To be overridden in subclasses. + """ + raise NotImplementedError("is_Not must be implemented by subclasses") + + +class Expr(BaseExpr): + """ + Class for expressions that are not Symbols + """ + + @classmethod + def _recreate(cls, args: Any, kwargs: dict) -> "Expr": + """ + Recreate an expression from its arguments and category. + """ + _id = kwargs.pop("_id") + + self = cast(Expr, cls(*args, **kwargs)) + self.set_id(_id) + return self + + def _reduce_helper( + self, ivars_map: dict | None = None, args: tuple | None = None + ) -> tuple[Callable, tuple]: + """ + Return a picklable tuple that calls self._recreate and passes in + (self.args, combined_ivars). + + :param ivars_map: A dictionary for renaming keys in the instance variables. + For each old_key -> new_key in ivars_map, if old_key exists + in data, it will be removed and stored under new_key instead. + :param args: Optionally, the args parameter can be specified. If it is None, self.args is used. 
+ """ + data = self.get_instance_variables() + data["_id"] = self._id # type: ignore[index] + + # Apply key renaming if ivars_map is provided + if ivars_map: + for old_key, new_key in ivars_map.items(): + if old_key in data: + data[new_key] = data.pop(old_key) # type: ignore[index,union-attr] + + if args is None: + args = self.args + + return (self._recreate, (args, data)) + + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, arguments, and category. + """ + # return self._recreate, (self.args, self.get_instance_variables() | {"_id": self._id}) + return self._reduce_helper() + + def get_instance_variables(self, immutable: bool = False) -> dict | tuple: + """ + Return all instance variables of the object. + + If immutable is True, return as an immutable tuple of key-value pairs. + If immutable is False, return as a mutable dictionary. + """ + instance_vars = { + key: value + for key, value in vars(self).items() + if not key.startswith("_") # Exclude private or special attributes + and key != "args" + } + + if immutable: + return tuple(sorted(instance_vars.items())) + else: + return instance_vars + + def __setattr__(self, name: str, value: Any) -> None: + """ + Set an attribute on the object. + + This is overridden to prevent setting attributes on the object. + """ + if name in self.__dict__ and name not in ["args", "_hash"]: + raise AttributeError( + f"Cannot update attributes on {self.__class__.__name__}" + ) + super().__setattr__(name, value) + + def update_args(self, *args: Any) -> None: + """ + Update the arguments of the expression. + + :param args: The new arguments. + """ + self.args = args + self.rehash() + + def __new__(cls, *args: Any, **kwargs: Any) -> "Expr": + """ + Initialize an expression with given arguments. + + :param args: Arguments for the expression. 
+ """ + self = cast(Expr, super().__new__(cls)) + + # we must not allow the __init__ function because of possible infinite recursion when using the __new__ function + # (see https://pdarragh.github.io/blog/2017/05/22/oddities-in-pythons-new-method/) + if "__init__" in cls.__dict__: + raise AttributeError( + f"__init__ is not allowed in subclass {cls.__name__} of BaseExpr" + ) + + self.args = args + + return self + + def __str__(self) -> str: + """ + Represent the expression in a readable format. + """ + return f"{self.__class__.__name__}({', '.join(map(str, self.args))})" + + def rehash(self, recursive: bool = False) -> None: + """ + Recalculate the hash of the object. + """ + + if recursive: + for arg in self.args: + if isinstance(arg, Expr): + arg.rehash(recursive=True) + else: + arg.rehash() + + self._hash = hash( + (self.__class__, self.args, self.get_instance_variables(immutable=True)) + ) + + @property + def is_Atom(self) -> bool: + """ + Check if the expression is an atom. Returns False for general expressions. + + :return: False for Expr. + """ + return False + + @property + def is_Not(self) -> bool: + """ + Check if the expression is a Not type. + + :return: True if Not type, False otherwise. + """ + return isinstance(self, Not) + + def atoms(self) -> Iterator["Symbol"]: + """ + Get all symbols in the expression. + """ + + def traverse(expr: BaseExpr) -> Iterator[Symbol]: + if expr.is_Atom: + assert isinstance(expr, Symbol), f"Expected Symbol, got {expr}" + yield expr + + for sub_expr in expr.args: + yield from traverse(sub_expr) + + yield from traverse(self) + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. 
+ """ + data: Dict[str, Any] = { + "type": self.__class__.__name__, + "data": { + "args": [arg_to_dict(arg, include_id=include_id) for arg in self.args], + }, + } + + if include_id and self._id is not None: + data["data"]["_id"] = self._id + + return data + + +class Symbol(BaseExpr): + """ + Class representing a symbolic variable. + """ + + def __new__(cls, *args: Any, **kwargs: Any) -> "Symbol": + """ + Initialize a symbol. + """ + self = cast(Symbol, super().__new__(cls)) + self.args = () + + return self + + @property + def is_Atom(self) -> bool: + """ + Check if the Symbol is an atom. Always returns True for Symbol. + + :return: True as Symbol is always an atom. + """ + return True + + @property + def is_Not(self) -> bool: + """ + Check if the Symbol is a Not type. Always returns False for Symbol. + + :return: False as Symbol is never a Not type. + """ + return False + + def get_instance_variables(self, immutable: bool = False) -> Dict[str, Any] | tuple: + """ + Return all instance variables of the object. + + If immutable is True, return as an immutable tuple of key-value pairs. + If immutable is False, return as a mutable dictionary. + """ + instance_vars = { + key: value + for key, value in vars(self).items() + if not key.startswith("_") # Exclude private or special attributes + and key != "args" + } + + if immutable: + return tuple(sorted(instance_vars.items())) + else: + return instance_vars + + +class BooleanFunction(Expr): + """ + Base class for boolean functions like OR, AND, and NOT. + """ + + _repr_join_str: str | None = None + + @property + def is_Atom(self) -> bool: + """ + Boolean functions are not atoms. + + :return: False + """ + return False + + @property + def is_Not(self) -> bool: + """ + Check if the BooleanFunction is a Not type. + + :return: True if Not type, False otherwise. + """ + return isinstance(self, Not) + + +class UnaryOperator(BooleanFunction): + """ + Base class for unary operators. 
+ """ + + def __new__(cls, *args: Any, **kwargs: Any) -> "UnaryOperator": + """ + Create a new UnaryOperator object. + """ + if len(args) > 1: + raise ValueError(f"{cls.__name__} can only have one argument") + + return cast(UnaryOperator, super().__new__(cls, *args, **kwargs)) + + +class CommutativeOperator(BooleanFunction, SerializableABC): + """ + Base class for commutative operators. + """ + + def __new__(cls, *args: Any, **kwargs: Any) -> "CommutativeOperator": + """ + Create a new CommutativeOperator object. + """ + return cast(CommutativeOperator, super().__new__(cls, *args, **kwargs)) + + +class Or(CommutativeOperator): + """ + Class representing a logical OR operation. + """ + + _repr_join_str = "|" + + def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: + """ + Create a new Or object. + """ + if len(args) == 1 and isinstance(args[0], BaseExpr): + return args[0] + + return super().__new__(cls, *args, **kwargs) + + +class And(CommutativeOperator): + """ + Class representing a logical AND operation. + """ + + _repr_join_str = "&" + + def __new__(cls, *args: Any, **kwargs: Any) -> BaseExpr: + """ + Create a new And object. + """ + if len(args) == 1 and isinstance(args[0], BaseExpr): + return args[0] + + return super().__new__(cls, *args, **kwargs) + + +class Not(UnaryOperator): + """ + Class representing a logical NOT operation. + """ + + def __str__(self) -> str: + """ + Represent the NOT operation as a string. + """ + return f"~{self.args[0]}" + + def __new__(cls, *args: Any, **kwargs: Any) -> "Not": + """ + Create a new Or object. 
+ """ + if len(args) > 1: + raise ValueError("Not can only have one argument") + + return cast(Not, super().__new__(cls, *args, **kwargs)) + + +class CountOperator(CommutativeOperator, SerializableABC): + """ + Base class for count operators + + This is the BaseClass for Count, TemporalCount and CappedCount - while these three classes + may not define any additional code, we need them to be able to use isinstance on the different subclasses and + distinguish between subclasses from Count, TemporalCount or CappedCount + """ + + count_min: int | None + count_max: int | None + + def __new__( + cls, *args: Any, min_count: int | None, max_count: int | None, **kwargs: Any + ) -> "CountOperator": + """ + Create a new MinCount object. + """ + self = cast(MinCount, super().__new__(cls, *args, **kwargs)) + self.count_min = min_count + self.count_max = max_count + + return self + + def _replace_map(self) -> dict: + replace = {} + + if self.count_min and self.count_max: + assert self.count_min == self.count_max + replace["count_min"] = "threshold" + elif self.count_min: + replace["count_min"] = "threshold" + elif self.count_max: + replace["count_max"] = "threshold" + else: + raise AttributeError("At least one of count_min or count_max must be set") + + return replace + + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, argument. + """ + return self._reduce_helper(self._replace_map()) + + +class Count(CountOperator): + """ + Class representing a logical COUNT operation. + """ + + +class MinCount(Count): + """ + Class representing a logical MIN_COUNT operation. + """ + + def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MinCount": + """ + Create a new MinCount object. 
+ """ + self = cast( + MinCount, + super().__new__(cls, *args, min_count=threshold, max_count=None, **kwargs), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) + + return data + + def __str__(self) -> str: + """ + Represent the expression in a readable format. + """ + return f"{self.__class__.__name__}(threshold={self.count_min})" + + +class MaxCount(Count): + """ + Class representing a logical MAX_COUNT operation. + """ + + def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "MaxCount": + """ + Create a new MaxCount object. + """ + self = cast( + MaxCount, + super().__new__(cls, *args, min_count=None, max_count=threshold, **kwargs), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_max}) + return data + + def __str__(self) -> str: + """ + Represent the expression in a readable format. + """ + return f"{self.__class__.__name__}(threshold={self.count_max})" + + +class ExactCount(Count): + """ + Class representing a logical EXACT_COUNT operation. + """ + + def __new__(cls, *args: Any, threshold: int | None, **kwargs: Any) -> "ExactCount": + """ + Create a new ExactCount object. + """ + self = cast( + ExactCount, + super().__new__( + cls, *args, min_count=threshold, max_count=threshold, **kwargs + ), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) + return data + + def __str__(self) -> str: + """ + Represent the expression in a readable format. 
+ """ + return f"{self.__class__.__name__}(threshold={self.count_min})" + + +class CappedCount(CountOperator, SerializableABC): + """ + Base class representing a COUNT operation with an upper cap. + + This class distinguishes COUNT operations that are subject to an implicit + maximum constraint, ensuring that they do not exceed what is achievable + given external limitations. + + Unlike regular COUNT operations, the threshold in this class is not assumed + to be unbounded. However, no explicit handling of the maximum occurs here; + it is enforced externally. + + This class should not be instantiated directly but used as a base for specific + capped count operations like CappedMinCount. + """ + + +class CappedMinCount(CappedCount): + """ + Class representing a MIN_COUNT operation with an implicit upper cap. + + This behaves like MinCount but acknowledges that the minimum required count + is subject to an external upper constraint. If the requested threshold exceeds + what is achievable, the actual threshold will be limited to the maximum possible + count, which is determined externally. + + The enforcement of this cap does not occur within this class; rather, it is + expected to be handled by the surrounding logic. + + The threshold parameter defines the minimum number of overlapping intervals + required, but in practice, it will not exceed the externally imposed cap. + """ + + def __new__( + cls, *args: Any, threshold: int | None, **kwargs: Any + ) -> "CappedMinCount": + """ + Create a new CappedMinCount object. + """ + return cast( + CappedMinCount, + super().__new__(cls, *args, min_count=threshold, max_count=None, **kwargs), + ) + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) + return data + + def __str__(self) -> str: + """ + Represent the expression in a readable format. 
+ """ + return f"{self.__class__.__name__}(threshold={self.count_min})" + + +class AllOrNone(CommutativeOperator): + """ + Class representing a logical ALL_OR_NONE operation. + """ + + +class TemporalCount(CountOperator, SerializableABC): + """ + Class representing a logical COUNT operation. + + Adds a "threshold" parameter of type int. + + This class should not be instantiated directly, but rather through one of its subclasses. + """ + + count_min: int | None = None + count_max: int | None = None + start_time: time | None = None + end_time: time | None = None + interval_type: TimeIntervalType | None = None + interval_criterion: BaseExpr | None = None + + def __new__( + cls, + *args: Any, + min_count: int | None, + max_count: int | None, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, + **kwargs: Any, + ) -> Self: + """ + Create a new TemporalCount object. + """ + + TemporalCount._validate_time_inputs( + start_time, end_time, interval_type, interval_criterion + ) + + if interval_criterion: + # we need to add the interval_criterion to the list of arguments of this criterion in order to have + # it properly processed + args += (interval_criterion,) + + self = cast( + Self, + super().__new__( + cls, *args, min_count=min_count, max_count=max_count, **kwargs + ), + ) + + self.start_time = ( + time.fromisoformat(start_time) # type: ignore[arg-type] + if isinstance(start_time, str) + else start_time + ) + self.end_time = ( + time.fromisoformat(end_time) # type: ignore[arg-type] + if isinstance(end_time, str) + else end_time + ) + self.interval_type = interval_type + self.interval_criterion = interval_criterion + + return self + + @classmethod + def _validate_time_inputs( + self, + start_time: time | None, + end_time: time | None, + interval_type: TimeIntervalType | None, + interval_criterion: BaseExpr | None, + ) -> None: + + if interval_type: + if start_time 
is not None or end_time is not None: + raise ValueError( + "start_time/end_time cannot be used together with interval_type" + ) + if interval_criterion is not None: + raise ValueError( + "interval_criterion cannot be used together with interval_type" + ) + # Validate the interval_type if needed + self.interval_type = interval_type + self.start_time = None + self.end_time = None + + elif start_time or end_time: + # Must have start_time and end_time + if start_time is None or end_time is None: + raise ValueError( + "Either interval_type or interval_criterion or both start_time & end_time must be provided" + ) + if interval_criterion is not None: + raise ValueError( + "interval_criterion cannot be used together with start_time/end_time" + ) + if start_time >= end_time: + raise ValueError("start_time must be less than end_time") + + elif interval_criterion and not isinstance(interval_criterion, BaseExpr): + raise ValueError( + f"Invalid criterion - expected Criterion or CriterionCombination, got {type(interval_criterion)}" + ) + + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, argument. + """ + if self.interval_criterion: + + if len(self.args) <= 1: + raise ValueError( + "More than one argument required if interval_criterion is set" + ) + + args, pop = self.args[:-1], self.args[-1] + + if pop != self.interval_criterion: + raise ValueError( + f"Expected last argument to be the interval_criterion, got {str(pop)}" + ) + return self._reduce_helper(self._replace_map(), args=args) + + return super().__reduce__() + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. 
+ """ + data = super().dict(include_id=include_id) + + if self.interval_criterion: + + if len(self.args) <= 1: + raise ValueError( + "More than one argument required if interval_criterion is set" + ) + + args, pop = self.args[:-1], self.args[-1] + + if pop != self.interval_criterion: + raise ValueError( + f"Expected last argument to be the interval_criterion, got {str(pop)}" + ) + + data["data"]["args"] = [ + arg_to_dict(arg, include_id=include_id) for arg in args + ] + + data["data"].update( + { + "start_time": self.start_time.isoformat() if self.start_time else None, + "end_time": self.end_time.isoformat() if self.end_time else None, + "interval_type": self.interval_type, + "interval_criterion": ( + self.interval_criterion.dict(include_id=include_id) + if self.interval_criterion + else None + ), + } + ) + return data + + def __str__(self) -> str: + """ + Represent the expression in a readable format. + """ + + if self.start_time is not None and self.end_time is not None: + interval = f"{self.start_time} - {self.end_time}" + elif self.interval_type is not None: + interval = self.interval_type.name + elif self.interval_criterion is not None: + interval = repr(self.interval_criterion) + else: + interval = "None" + + return f"{self.__class__.__name__}(interval={interval}; threshold={self.count_min})" + + +class TemporalMinCount(TemporalCount): + """ + Class representing a logical temporal MIN_COUNT operation. + """ + + def __new__( + cls, + *args: Any, + threshold: int, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, + **kwargs: Any, + ) -> "TemporalMinCount": + """ + Create a new MinCount object. 
+ """ + self = cast( + TemporalMinCount, + super().__new__( + cls, + *args, + min_count=threshold, + max_count=None, + start_time=start_time, + end_time=end_time, + interval_type=interval_type, + interval_criterion=interval_criterion, + ), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) + return data + + +class TemporalMaxCount(TemporalCount): + """ + Class representing a logical MAX_COUNT operation. + """ + + def __new__( + cls, + *args: Any, + threshold: int, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, + **kwargs: Any, + ) -> "TemporalMaxCount": + """ + Create a new MaxCount object. + """ + self = cast( + TemporalMaxCount, + super().__new__( + cls, + *args, + min_count=None, + max_count=threshold, + start_time=start_time, + end_time=end_time, + interval_type=interval_type, + interval_criterion=interval_criterion, + ), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_max}) + return data + + +class TemporalExactCount(TemporalCount): + """ + Class representing a logical EXACT_COUNT operation. + """ + + def __new__( + cls, + *args: Any, + threshold: int, + start_time: time | None = None, + end_time: time | None = None, + interval_type: TimeIntervalType | None = None, + interval_criterion: BaseExpr | None = None, + **kwargs: Any, + ) -> "TemporalExactCount": + """ + Create a new ExactCount object. 
+ """ + self = cast( + TemporalExactCount, + super().__new__( + cls, + *args, + min_count=threshold, + max_count=threshold, + start_time=start_time, + end_time=end_time, + interval_type=interval_type, + interval_criterion=interval_criterion, + ), + ) + return self + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. + """ + data = super().dict(include_id=include_id) + data["data"].update({"threshold": self.count_min}) + return data + + +class NonSimplifiableAnd(CommutativeOperator): + """ + A NonSimplifiableAnd object represents a logical AND operation that cannot be simplified. + + Simplified here means that when this operator is used on a single argument, still this operator is returned + instead of the argument itself, as is the case with the sympy.And operator. + + The reason for this operator is that if there is a single population or intervention criterion, the And/Or + operators would simplify to the criterion itself. In that case, the _whole_ population or intervention expression of + the respective population/intervention pair, which should be written to the database with criterion_id = None, would + be lost (i.e. not written, because there is no graph execution node that would perform this. This operator prevents + that. + """ + + def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableAnd": + """ + Create a new NonSimplifiableAnd object. + """ + return cast(NonSimplifiableAnd, super().__new__(cls, *args, **kwargs)) + + +class NonSimplifiableOr(CommutativeOperator): + """ + A NonSimplifiableOr object represents a logical Or operation that cannot be simplified. + + Simplified here means that when this operator is used on a single argument, still this operator is returned + instead of the argument itself, as is the case with the sympy.Or operator. 
+ + The reason for this operator is that if there is a single population or intervention criterion, the And/Or + operators would simplify to the criterion itself. In that case, the _whole_ population or intervention expression of + the respective population/intervention pair, which should be written to the database with criterion_id = None, would + be lost (i.e. not written, because there is no graph execution node that would perform this. This operator prevents + that. + """ + + def __new__(cls, *args: Any, **kwargs: Any) -> "NonSimplifiableOr": + """ + Create a new NonSimplifiableOr object. + """ + return cast(NonSimplifiableOr, super().__new__(cls, *args, **kwargs)) + + +class BinaryNonCommutativeOperator(BooleanFunction, SerializableABC): + """ + Base class for binary non-commutative operators. + + This class should not be instantiated directly but used as a base for specific + binary non-commutative operators like LeftDependentToggle. + """ + + def update_args(self, *args: Any) -> None: + """ + Update the arguments of the expression. + + :param args: The new arguments. + """ + if len(args) != 2: + raise ValueError( + f"{self.__class__.__name__} requires exactly two arguments" + ) + super().update_args(*args) + + def __new__( + cls, left: BaseExpr, right: BaseExpr, **kwargs: Any + ) -> "BinaryNonCommutativeOperator": + """ + Create a new BinaryNonCommutativeOperator object. + """ + return cast( + BinaryNonCommutativeOperator, super().__new__(cls, left, right, **kwargs) + ) + + @property + def left(self) -> Expr: + """Returns the left operand""" + return self.args[0] + + @property + def right(self) -> Expr: + """Returns the right operand""" + return self.args[1] + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. 
+ """ + data = super().dict(include_id=include_id) + del data["data"]["args"] + data["data"].update( + { + "left": arg_to_dict(self.left, include_id=include_id), + "right": arg_to_dict(self.right, include_id=include_id), + } + ) + return data + + +class LeftDependentToggle(BinaryNonCommutativeOperator): + """ + A LeftDependentToggle object represents a logical AND operation if the left operand is positive, + otherwise it returns NOT_APPLICABLE. + """ + + +class ConditionalFilter(BinaryNonCommutativeOperator): + """ + A ConditionalFilter object returns the right operand if the left operand is POSITIVE, + and NEGATIVE otherwise + + + A conditional filter returns `right` iff `left` is POSITIVE, otherwise NEGATIVE. + + | left | right | Result | + |----------|----------|----------| + | NEGATIVE | * | NEGATIVE | + | NO_DATA | * | NEGATIVE | + | POSITIVE | POSITIVE | POSITIVE | + | POSITIVE | NEGATIVE | NEGATIVE | + | POSITIVE | NO_DATA | NO_DATA | + """ diff --git a/execution_engine/util/serializable.py b/execution_engine/util/serializable.py new file mode 100644 index 00000000..768fe13f --- /dev/null +++ b/execution_engine/util/serializable.py @@ -0,0 +1,547 @@ +import abc +import inspect +import json +from typing import Any, Dict, Self, final + +from pydantic import BaseModel + +from execution_engine.util import datetime_converter + +__class_registry: dict[str, type] = {} +""" +Registry of classes that can be serialized. +""" + + +def register_class(cls: type) -> type: + """ + Register a class for serialization. + + This function may be used as a decorator to register a class in the serialization registry. + + :param cls: The class to register. + :return: The same class (for decorator chaining). + :raises ValueError: If the class name is already registered. 
+ """ + if cls.__name__ in __class_registry: + raise ValueError(f"Class {cls.__name__} is already registered.") + + __class_registry[cls.__name__] = cls + return cls + + +def resolve_class(name: str) -> type: + """ + Resolve a registered class by name. + + :param name: The name of the class to retrieve. + :return: The corresponding class object. + :raises ValueError: If the class is not found in the registry. + """ + cls = __class_registry.get(name) + + if cls is None: + raise ValueError(f"Class {name} is not registered.") + + return cls + + +def is_class_registered(name: str) -> bool: + """ + Check if a class is registered under the given name. + + :param name: The name of the class to check. + :return: True if the class is registered, False otherwise. + """ + return name in __class_registry + + +class RegisteredPostInitMeta(type): + """ + Metaclass that automatically registers a class for serialization and calls + a custom __post_init__ method (if defined) once after regular object initialization. + """ + + def __call__(cls, *args: Any, do_post_init: bool = True, **kwargs: Any) -> Self: + """ + Create and return a new instance of the class, then call __post_init__ if defined. + + This overrides the default object construction process to perform any + custom post-initialization logic defined in __post_init__. If __post_init__ + exists, it is called exactly once, after the instance is created. + + :param args: Positional arguments used during object creation. + :param kwargs: Keyword arguments used during object creation. + :return: The newly created instance of the class. 
+ """ + instance = super().__call__(*args, **kwargs) + + if ( + do_post_init + and hasattr(instance, "__post_init__") + and not getattr(instance, "_post_initialized", False) + ): + instance.__post_init__() + instance._post_initialized = True + + return instance + + def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Self: + """ + Create and return a new class, registering it in the serialization registry. + + This method intercepts the class creation process itself. It registers the + newly defined class to __class_registry, allowing instances to be properly + deserialized later. + + :param name: The name of the newly created class. + :param bases: A tuple of base classes. + :param attrs: A dictionary of attributes/methods for the new class. + :return: The newly created class object. + :raises ValueError: If the class is already registered. + """ + new_class = super().__new__(mcs, name, bases, attrs) + register_class(new_class) + return new_class + + +def immutable_setattr(self: Self, key: str, value: Any) -> None: + """ + Prevent setting attributes on an immutable object. + + This function is assigned to an instance's __setattr__ method in order to enforce + immutability after object creation. Any attempt to set an attribute on the instance + after initialization will raise an AttributeError. + + :param self: The instance on which the attribute assignment was attempted. + :param key: The name of the attribute being set. + :param value: The value being assigned. + :raises AttributeError: Always, to enforce immutability. + """ + raise AttributeError( + f"Cannot set attribute {key} on immutable object {self.__class__.__name__}" + ) + + +class Serializable(metaclass=RegisteredPostInitMeta): + """ + Base class for making objects serializable. + + Stores construction arguments, manages an optional database ID, and provides serialization + and deserialization to and from dictionaries or JSON. + + Note that Serializable classes are immutable. 
This means that once an object is created, + its attributes cannot be changed. This is enforced by overriding the __setattr__ method in + the __post_init__ method. + The rationale behind this is to provide a fixed hash value for the object, which is + calculated only once during the object's lifetime. This is important for caching and + serialization purposes. + """ + + _id: int | None = None + """ + The id is used in the database tables. + """ + + _hash: int + """ + The hash of the object. This is calculated based on the class name and the JSON representation + of the object. It is used to ensure that the object is immutable. + """ + + def __post_init__(self) -> None: + """ + Create a new instance of the object. + """ + self.rehash() + + self.__setattr__ = immutable_setattr # type: ignore[assignment] + + def set_id(self, value: int, overwrite: bool = False) -> None: + """ + Assigns the database ID to the object. This can only be done once. + + This ID corresponds to the primary key in the database and is set + when the object is persisted. + + :param value: The database ID assigned to the object. + :param overwrite: If True, allows overwriting an existing ID. + :raises ValueError: If the ID has already been set. + """ + if self._id is not None and not overwrite: + raise ValueError("Database ID has already been set!") + self._id = value + + @property + def id(self) -> int: + """ + Retrieves the database ID of the object. + + This ID is only available after the object has been stored in the database. + + :return: The database ID, or None if the object has not been stored yet. + """ + if self._id is None: + raise ValueError("Database ID has not been set yet!") + return self._id + + def is_persisted(self) -> bool: + """ + Returns True if the object has been stored in the database. + """ + return self._id is not None + + def dict(self, include_id: bool = False) -> dict: + """ + Get a dictionary representation of the object. 
+ """ + data = {} + + def serialize(val: Any) -> Any: + if isinstance(val, Serializable): + return val.dict(include_id=include_id) + elif isinstance(val, BaseModel): + return { + "type": val.__class__.__name__, + "data": val.model_dump(), + } + elif isinstance(val, (list, tuple)): + return type(val)(serialize(item) for item in val) + return val + + for key, val in self.get_instance_variables().items(): # type: ignore[union-attr] + data[key] = serialize(val) + + if include_id and self._id is not None: + data["_id"] = self._id + + return {"type": self.__class__.__name__, "data": data} + + def get_instance_variables(self, immutable: bool = False) -> Dict[str, Any] | tuple: + """ + Get the instance variables of the object. + + This is only required if the subclass doesn't provide an own implementation of dict(). + """ + raise NotImplementedError( + "Method get_instance_variables must be implemented in subclasses." + ) + + @final + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Self: + """ + Create an object from a dictionary produced by dict(). + + Expected format: + { + "type": , + "data": { + "_id": ... , (optional) + "args": [...], (optional, positional arguments) + ... any additional keys ... + } + } + """ + class_name = data["type"] + content = data["data"] + + # Look up the target class from our registry + sub_cls = resolve_class(class_name) + + # Extract _id if present + db_id = content.pop("_id", None) + + pos_args = content.pop("args", []) + var_kwargs = content + + def deserialize_item(item: Any) -> Any: + if isinstance(item, dict) and "type" in item: + return cls.from_dict(item) + return item + + pos_args = [deserialize_item(arg) for arg in pos_args] + var_kwargs = {k: deserialize_item(v) for k, v in var_kwargs.items()} + + obj = sub_cls(*pos_args, **var_kwargs) + + if db_id is not None: + obj.set_id(db_id) + + return obj + + def json(self) -> bytes: + """ + Get a JSON representation of the object. 
+
+        The json excludes the id, as this is auto-inserted by the database
+        and not known during the creation of the object.
+        """
+        s_json = self.dict()
+
+        return json.dumps(s_json, default=datetime_converter, sort_keys=True).encode()
+
+    @classmethod
+    def from_json(cls, data: str | bytes) -> Self:
+        """
+        Create an object from a JSON string.
+        """
+        return cls.from_dict(json.loads(data))
+
+    def __eq__(self, other: Any) -> bool:
+        """
+        Check if two objects are equal.
+        """
+        if not isinstance(other, self.__class__):
+            return False
+
+        return hash(self) == hash(other)
+
+    def rehash(self) -> None:
+        """
+        Recalculate the hash of the object.
+        """
+        self._hash = hash(self.__class__.__name__.encode() + self.json())
+
+    def __hash__(self) -> int:
+        """
+        Get the hash of the object.
+        """
+        return self._hash
+
+    # def __reduce__(self) -> Tuple[Callable, tuple]:
+    #     """
+    #     Support pickling of the object.
+    #     """
+    #     return self.__class__.from_dict, (self.dict(include_id=True),)
+
+    @final
+    def __repr__(self) -> str:
+        """
+        Get a string representation of the object.
+        """
+        rep_dict = self.dict(include_id=False)
+        data = rep_dict["data"].copy()
+
+        return self.build_repr(rep_dict["type"], data, indent=0)
+
+    def build_repr(self, cls: str, data: Dict, indent: int) -> str:
+        """
+        Build a string representation of the object.
+ """ + indent_unit = " " + + def pformat_value(val: Any, indent: int) -> str: + current_indent = indent_unit * indent + next_indent = indent_unit * (indent + 1) + # Check for a serialized object (dict with "type" and "data") + if isinstance(val, dict) and "type" in val and "data" in val: + + if is_class_registered(val["type"]) and issubclass( + resolve_class(val["type"]), BaseModel + ): + flat_data = ", ".join( + f"{k}={repr(v)}" for k, v in val["data"].items() + ) + return f"{val['type']}({flat_data})" + + return self.build_repr(val["type"], val["data"], indent + 1) + elif isinstance(val, list): + if not val: + return "[]" + items = [pformat_value(item, indent + 1) for item in val] + return ( + "[\n" + + ",\n".join(next_indent + item for item in items) + + "\n" + + current_indent + + "]" + ) + elif isinstance(val, tuple): + if not val: + return "()" + # Special case for single-element tuples + if len(val) == 1: + item = pformat_value(val[0], indent + 1) + return "(\n" + next_indent + item + ",\n" + current_indent + ")" + else: + items = [pformat_value(item, indent + 1) for item in val] + return ( + "(\n" + + ",\n".join(next_indent + item for item in items) + + "\n" + + current_indent + + ")" + ) + elif isinstance(val, dict): + if not val: + return "{}" + items = [] + for k, v in val.items(): + formatted_v = pformat_value(v, indent + 1) + items.append(f"{next_indent}{repr(k)}: {formatted_v}") + return "{\n" + ",\n".join(items) + "\n" + current_indent + "}" + else: + return repr(val) + + # Extract positional arguments (if any) + pos_args = data.pop("args", []) + pos_args_str = [pformat_value(arg, indent=indent + 1) for arg in pos_args] + + # Format remaining keyword arguments + kw_args_str = [ + f"{key}={pformat_value(value, indent=indent+1)}" + for key, value in data.items() + ] + + # Combine both sets of arguments + all_args = pos_args_str + kw_args_str + outer_indent = indent_unit * indent + inner_indent = indent_unit * (indent + 1) + + if all_args: + 
args_joined = ",\n".join(inner_indent + arg for arg in all_args) + return f"{cls}(\n{args_joined}\n{outer_indent})" + else: + return f"{cls}()" + + def __str__(self) -> str: + """ + Get a string representation of the object. + """ + data = {} + + for key, val in self.get_instance_variables().items(): # type: ignore[union-attr] + data[key] = str(val) + + return f"{self.__class__.__name__}({data})" + + +def get_constructor_vars(cls: type) -> set[str]: + """ + Get the variables that are passed to the constructor of a class. + + :param cls: The class. + :return: A set of variable names. + """ + self_var = "self" + original_init = cls.__init__ + if cls.__init__ == object.__init__ and cls.__new__ != object.__new__: + original_init = cls.__new__ + self_var = "cls" + + sig = inspect.signature(original_init) + + return {arg for arg in sig.parameters if arg != self_var} + + +class SerializableDataClassMeta(RegisteredPostInitMeta): + """ + Base class for making objects serializable. Stores construction + arguments, manages an optional database ID, and provides serialization + and deserialization to and from dictionaries or JSON. + """ + + def __call__(cls, *args: Any, **kwargs: Any) -> Self: + """ + Creates and returns a new instance of the class, ensuring that arguments + are assigned to protected attributes. Uses the function signature to bind + positional and keyword arguments and raises a ValueError if protected + attributes are missing. + + :param args: Positional arguments for object creation. + :param kwargs: Keyword arguments for object creation. + :return: Newly created instance of the class. 
+ """ + instance = super().__call__(*args, do_post_init=False, **kwargs) + + self_var = "self" + original_init = cls.__init__ + if cls.__init__ == object.__init__ and cls.__new__ != object.__new__: + original_init = cls.__new__ + self_var = "cls" + + sig = inspect.signature(original_init) + if self_var == "cls": + # For __new__, the first argument should be the class (instance's class) + bound = sig.bind(cls, *args, **kwargs) + else: + # For __init__, the first argument is the instance + bound = sig.bind(instance, *args, **kwargs) + + bound.arguments.pop(self_var, None) # Remove 'self' or 'cls' + + protected_instance_vars = set( + arg + for arg in vars(instance) + if arg.startswith("_") and not arg.startswith("__") + ) + + # Check if the set of arguments passed to __init__ is a subset of the protected instance variables + if not set(f"_{arg}" for arg in bound.arguments) <= protected_instance_vars: + raise AttributeError( + "All arguments passed to __init__ must be assigned to the instance as protected attributes" + " in a SerializableDataClass." + ) + + if hasattr(instance, "__post_init__") and not getattr( + instance, "_post_initialized", False + ): + instance.__post_init__() + instance._post_initialized = True # type: ignore[attr-defined] + + return instance + + +class SerializableABCMeta(RegisteredPostInitMeta, abc.ABCMeta): + """ + Metaclass that combines SerializableDataClassMeta logic with ABCMeta + to allow abstract methods in serializable classes. + """ + + +class SerializableABC(metaclass=SerializableABCMeta): + """ + Abstract base class for serializable objects. This class allows + the use of abstract methods in serializable classes. + """ + + +class SerializableDataClassABCMeta(SerializableDataClassMeta, abc.ABCMeta): + """ + Metaclass that combines SerializableDataClassMeta logic with ABCMeta + to allow abstract methods in serializable data classes. 
+ """ + + +class SerializableDataClass(Serializable, metaclass=SerializableDataClassMeta): + """ + Serializable data class that ensures arguments passed to __init__ + are protected attributes. Automatically registers subclasses and + supports dict() and JSON exports. + """ + + def get_instance_variables(self, immutable: bool = False) -> dict | tuple: + """ + Get a dictionary representation of the criterion. + """ + + if immutable: + raise NotImplementedError( + "get_instance_variables() must be implemented in subclasses." + ) + + return { + var: getattr(self, f"_{var}") + for var in get_constructor_vars(self.__class__) + } + + +class SerializableDataClassABC( + SerializableDataClass, metaclass=SerializableDataClassABCMeta +): + """ + Abstract variant of a serializable data class. Requires subclasses + to implement abstract methods, while keeping the serialization + features and protected attribute checks. + """ diff --git a/execution_engine/util/temporal_logic_util.py b/execution_engine/util/temporal_logic_util.py new file mode 100644 index 00000000..42c06e51 --- /dev/null +++ b/execution_engine/util/temporal_logic_util.py @@ -0,0 +1,156 @@ +from datetime import time + +from execution_engine.util import logic +from execution_engine.util.enum import TimeIntervalType + + +def Presence( + criterion: logic.BaseExpr, + *, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalMinCount: + """ + Create a presence combination of criteria. 
+ """ + return logic.TemporalMinCount( + criterion, + threshold=1, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def MinCount( + criterion: logic.BaseExpr, + *, + threshold: int, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalMinCount: + """ + Create a minimum count combination of criteria. + """ + return logic.TemporalMinCount( + criterion, + threshold=threshold, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def MaxCount( + criterion: logic.BaseExpr, + *, + threshold: int, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalMaxCount: + """ + Create a maximum count combination of criteria. + """ + return logic.TemporalMaxCount( + criterion, + threshold=threshold, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def ExactCount( + criterion: logic.BaseExpr, + *, + threshold: int, + interval_type: TimeIntervalType | None = None, + start_time: time | None = None, + end_time: time | None = None, + interval_criterion: logic.BaseExpr | None = None, +) -> logic.TemporalExactCount: + """ + Create an exact count combination of criteria. + """ + return logic.TemporalExactCount( + criterion, + threshold=threshold, + interval_type=interval_type, + start_time=start_time, + end_time=end_time, + interval_criterion=interval_criterion, + ) + + +def MorningShift( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a morning shift combination of criteria. 
+ """ + return Presence(criterion, interval_type=TimeIntervalType.MORNING_SHIFT) + + +def AfternoonShift( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create an afternoon shift combination of criteria. + """ + return Presence(criterion, interval_type=TimeIntervalType.AFTERNOON_SHIFT) + + +def NightShift( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a night shift combination of criteria. + """ + return Presence(criterion, interval_type=TimeIntervalType.NIGHT_SHIFT) + + +def NightShiftBeforeMidnight( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a night shift before midnight combination of criteria. + """ + return Presence( + criterion, interval_type=TimeIntervalType.NIGHT_SHIFT_BEFORE_MIDNIGHT + ) + + +def NightShiftAfterMidnight( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a night shift after midnight combination of criteria. + """ + return Presence( + criterion, interval_type=TimeIntervalType.NIGHT_SHIFT_AFTER_MIDNIGHT + ) + + +def Day( + criterion: logic.BaseExpr, +) -> logic.TemporalMinCount: + """ + Create a day combination of criteria. + """ + return Presence(criterion, interval_type=TimeIntervalType.DAY) + + +def AnyTime(criterion: logic.BaseExpr) -> logic.TemporalMinCount: + """ + Any time overlap + """ + return Presence(criterion, interval_type=TimeIntervalType.ANY_TIME) diff --git a/execution_engine/util/types.py b/execution_engine/util/types.py index cf099700..7ff9982a 100644 --- a/execution_engine/util/types.py +++ b/execution_engine/util/types.py @@ -5,6 +5,7 @@ import pytz from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from execution_engine.util import serializable from execution_engine.util.enum import TimeUnit from execution_engine.util.interval import ( DateTimeInterval, @@ -17,6 +18,7 @@ PersonIntervals = dict[int, Any] +@serializable.register_class class TimeRange(BaseModel): """ A time range. 
@@ -91,6 +93,7 @@ def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, datetime]: } +@serializable.register_class class Timing(BaseModel): """ The timing of a criterion. @@ -212,6 +215,7 @@ def model_dump( return data +@serializable.register_class class Dosage(Timing): """ A dosage consisting of a dose in addition to the Timing fields. diff --git a/execution_engine/util/value/time.py b/execution_engine/util/value/time.py index da2de370..8d905187 100644 --- a/execution_engine/util/value/time.py +++ b/execution_engine/util/value/time.py @@ -5,6 +5,7 @@ from sqlalchemy import TableClause from sqlalchemy.sql.functions import concat, func +from execution_engine.util import serializable from execution_engine.util.enum import TimeUnit from execution_engine.util.value import ucum_to_postgres from execution_engine.util.value.value import ( @@ -16,6 +17,7 @@ ) +@serializable.register_class class ValuePeriod(ValueNumeric[NonNegativeInt, TimeUnit]): """ A non-negative integer value with a unit of type TimeUnit, where value_min and value_max are not allowed. @@ -52,6 +54,7 @@ def to_sql_interval(self) -> ColumnElement: ) +@serializable.register_class class ValueCount(ValueNumeric[NonNegativeInt, None]): """ A non-negative integer value without a unit. @@ -71,6 +74,7 @@ def supports_units(self) -> bool: return False +@serializable.register_class class ValueDuration(ValueNumeric[float, TimeUnit]): """ A float value with a unit of type TimeUnit. @@ -104,6 +108,7 @@ def to_sql( return super().to_sql(table, c / c_duration_seconds, with_unit) +@serializable.register_class class ValueFrequency(Value): """ A non-negative integer value with no unit, with a Period. 
diff --git a/execution_engine/util/value/value.py b/execution_engine/util/value/value.py index bc6c8f93..fcda3dc6 100644 --- a/execution_engine/util/value/value.py +++ b/execution_engine/util/value/value.py @@ -18,9 +18,14 @@ "ValueNumeric", "ValueNumber", "ValueConcept", + "ValueScalar", "check_int", + "check_unit_none", + "check_value_min_max_none", ] +from execution_engine.util import serializable + ValueT = TypeVar("ValueT") UnitT = TypeVar("UnitT") ValueNumericClassT = TypeVar("ValueNumericClassT", bound="ValueNumeric") @@ -90,6 +95,7 @@ def get_precision(value: float | int) -> int: ) # the number of digits after the decimal point is the precision +@serializable.register_class class Value(BaseModel, ABC): """A value in a criterion.""" @@ -145,6 +151,7 @@ def supports_units(self) -> bool: return hasattr(self, "unit") +@serializable.register_class class ValueNumeric(Value, Generic[ValueT, UnitT]): """ A value of type number. @@ -294,6 +301,7 @@ def eps(number: ValueT) -> float: return and_(*clauses) +@serializable.register_class class ValueNumber(ValueNumeric[float, Concept]): """ A float value with a unit of type Concept. @@ -307,6 +315,7 @@ def _get_unit_clause(self, table: TableClause | None) -> ColumnClause: return c_unit == self.unit.concept_id +@serializable.register_class class ValueScalar(ValueNumeric[float, None]): """ A numeric value without a unit. @@ -314,9 +323,6 @@ class ValueScalar(ValueNumeric[float, None]): unit: None = None - _validate_value = field_validator("value", mode="before")(check_int) - _validate_value_min = field_validator("value_min", mode="before")(check_int) - _validate_value_max = field_validator("value_max", mode="before")(check_int) _validate_no_unit = field_validator("unit", mode="before")(check_unit_none) def supports_units(self) -> bool: @@ -326,6 +332,7 @@ def supports_units(self) -> bool: return False +@serializable.register_class class ValueConcept(Value): """ A value of type concept. 
diff --git a/scripts/execute.py b/scripts/execute.py index 724fbb80..c17ae152 100644 --- a/scripts/execute.py +++ b/scripts/execute.py @@ -93,12 +93,12 @@ urls = [ "covid19-inpatient-therapy/recommendation/no-therapeutic-anticoagulation", - "sepsis/recommendation/ventilation-plan-ards-tidal-volume", "covid19-inpatient-therapy/recommendation/ventilation-plan-ards-tidal-volume", "covid19-inpatient-therapy/recommendation/covid19-ventilation-plan-peep", "covid19-inpatient-therapy/recommendation/prophylactic-anticoagulation", "covid19-inpatient-therapy/recommendation/therapeutic-anticoagulation", "covid19-inpatient-therapy/recommendation/covid19-abdominal-positioning-ards", + "sepsis/recommendation/ventilation-plan-ards-tidal-volume", ] start_datetime = pendulum.parse("2020-01-01 00:00:00+01:00") diff --git a/tests/_fixtures/omop_fixture.py b/tests/_fixtures/omop_fixture.py index 174dece3..c682bc55 100644 --- a/tests/_fixtures/omop_fixture.py +++ b/tests/_fixtures/omop_fixture.py @@ -177,7 +177,7 @@ def celida_recommendation( recommendation_id=recommendation_id, pi_pair_url="https://example.com", pi_pair_name="my_pair", - pi_pair_hash=hash("my_pair"), + pi_pair_hash=str(hash("my_pair")), ) db_session.add(pi_pair) db_session.commit() @@ -185,7 +185,7 @@ def celida_recommendation( criterion = Criterion( criterion_id=criterion_id, criterion_description="my_description", - criterion_hash=hash("my_criterion"), + criterion_hash=str(hash("my_criterion")), ) db_session.add(criterion) db_session.commit() diff --git a/tests/execution_engine/__init__.py b/tests/execution_engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/execution_engine/converter/action/test_assessment.py b/tests/execution_engine/converter/action/test_assessment.py index 821dc736..ad2a09b8 100644 --- a/tests/execution_engine/converter/action/test_assessment.py +++ b/tests/execution_engine/converter/action/test_assessment.py @@ -33,7 +33,7 @@ class TestAssessmentAction: def 
test_assessment_action(self, code, timing, criterion_class): action = AssessmentAction(exclude=False, code=code, timing=timing) - criterion = action.to_criterion() + criterion = action.to_positive_expression() assert isinstance(criterion, criterion_class) assert criterion._concept == code diff --git a/tests/execution_engine/converter/action/test_drug_administration.py b/tests/execution_engine/converter/action/test_drug_administration.py index c58557ba..2b101a85 100644 --- a/tests/execution_engine/converter/action/test_drug_administration.py +++ b/tests/execution_engine/converter/action/test_drug_administration.py @@ -26,7 +26,7 @@ def test_single_dose(self): dosages=[dosage_def], ) - criterion = action.to_criterion() + criterion = action.to_expression() assert isinstance(criterion, DrugExposure) assert criterion._dose == dosage @@ -64,8 +64,8 @@ def test_multiple_doses(self): ], ) - comb = action.to_criterion() - criteria = list(comb) + expr = action.to_expression() + criteria = list(expr.args) assert len(criteria) == 3 assert all(isinstance(c, DrugExposure) for c in criteria) diff --git a/tests/execution_engine/converter/test_converter.py b/tests/execution_engine/converter/test_converter.py index e7d1b260..fc2ed131 100644 --- a/tests/execution_engine/converter/test_converter.py +++ b/tests/execution_engine/converter/test_converter.py @@ -16,10 +16,7 @@ parse_value, select_value, ) -from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) +from execution_engine.util import logic from execution_engine.util.value import ValueConcept, ValueNumber @@ -149,15 +146,16 @@ class TestCriterionConverter: # Continuing the test classes for CriterionConverter and CriterionConverterFactory class MockCriterionConverter(CriterionConverter): - @classmethod - def from_fhir(cls, fhir_definition: Element) -> "CriterionConverter": - return cls(exclude=False) @classmethod def 
valid(cls, fhir_definition: Element) -> bool: return fhir_definition.id == "valid" - def to_positive_criterion(self) -> Criterion | LogicalCriterionCombination: + @classmethod + def from_fhir(cls, fhir_definition: Element) -> "CriterionConverter": + return cls(exclude=False) + + def to_positive_expression(self) -> logic.Symbol: raise NotImplementedError() def test_criterion_converter_factory_register(self): diff --git a/tests/execution_engine/fhir/test_recommendation.py b/tests/execution_engine/fhir/test_recommendation.py index 0aa496ca..b69a0e8f 100644 --- a/tests/execution_engine/fhir/test_recommendation.py +++ b/tests/execution_engine/fhir/test_recommendation.py @@ -65,7 +65,7 @@ def mock_fetch_recommendation(self, test_class): {"status": "draft", "action": []} ) with patch.object( - test_class, "fetch_recommendation", return_value=plan_definition + test_class, "fetch_recommendation", return_value=(plan_definition, None) ) as _fixture: yield _fixture @@ -125,24 +125,7 @@ def test_recommendation_load_with_unknown_type(self, mock_fetch_resource_unknown # Test with pytest.raises( - ValueError, match=r"Unknown recommendation type: unknown-type" - ): - _ = Recommendation( - canonical_url, - package_version="latest", - fhir_connector=FHIRClient("http://fhir.example.com"), - ) - - def test_recommendation_fetch_with_no_partof_extension( - self, mock_fetch_resource_no_partOf - ): - # Setup - canonical_url = "http://test.com/PlanDefinition/123" - - # Test - with pytest.raises( - ValueError, - match=r"No partOf extension found in PlanDefinition, can't fetch recommendation.", + ValueError, match=r"Unknown PlanDefinition type: unknown-type" ): _ = Recommendation( canonical_url, diff --git a/tests/execution_engine/omop/__init__.py b/tests/execution_engine/omop/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/execution_engine/omop/cohort/test_cohort_recommendation.py b/tests/execution_engine/omop/cohort/test_cohort_recommendation.py index 
750e285e..37f8ab0a 100644 --- a/tests/execution_engine/omop/cohort/test_cohort_recommendation.py +++ b/tests/execution_engine/omop/cohort/test_cohort_recommendation.py @@ -1,9 +1,10 @@ import pytest -from execution_engine.omop.cohort import PopulationInterventionPair +from execution_engine.omop.cohort import PopulationInterventionPairExpr from execution_engine.omop.cohort.recommendation import Recommendation from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.visit_occurrence import ActivePatients +from execution_engine.util import logic from tests.mocks.criterion import MockCriterion @@ -11,12 +12,9 @@ class TestRecommendation: def test_serialization(self): # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) original = Recommendation( - pi_pairs=[], + expr=logic.BooleanFunction(), base_criterion=MockCriterion("c"), name="foo", title="bar", @@ -47,7 +45,7 @@ def test_serialization_with_active_patients(self, concept): # specifically because there used to be a problem with the # serialization of that combination. 
original = Recommendation( - pi_pairs=[], + expr=logic.BooleanFunction(), base_criterion=ActivePatients(), name="foo", title="bar", @@ -60,11 +58,12 @@ def test_serialization_with_active_patients(self, concept): deserialized = Recommendation.from_json(json) assert original == deserialized - def test_database_serialization(self, concept): + def test_database_serialization( + self, + concept, + db_session, # use_db_session fixture for a proper database clean up after this test + ): from execution_engine.execution_engine import ExecutionEngine - from execution_engine.omop.criterion.factory import register_criterion_class - - register_criterion_class("MockCriterion", MockCriterion) e = ExecutionEngine(verbose=False) @@ -74,18 +73,18 @@ def test_database_serialization(self, concept): population_criterion = MockCriterion("p") intervention_criterion = MockCriterion("i)") - pi_pair = PopulationInterventionPair( + pi_pair = PopulationInterventionPairExpr( + population_expr=population_criterion, + intervention_expr=intervention_criterion, name="foo", url="foo", base_criterion=base_criterion, - population=population_criterion, - intervention=intervention_criterion, ) pi_pairs = [pi_pair] recommendation = Recommendation( - pi_pairs=pi_pairs, + expr=pi_pair, base_criterion=ActivePatients(), name="foo", title="bar", @@ -105,33 +104,33 @@ def test_database_serialization(self, concept): with pytest.raises(ValueError, match=r"Database ID has not been set yet!"): assert pi_pair.id is None - for criterion in pi_pair.flatten(): - with pytest.raises( - ValueError, match=r"Database ID has not been set yet!" 
- ): - assert criterion.id is None + for criterion in recommendation.atoms(): + with pytest.raises(ValueError, match=r"Database ID has not been set yet!"): + assert criterion.id is None e.register_recommendation(recommendation) assert recommendation.id is not None assert recommendation.base_criterion.id is not None + for pi_pair in recommendation.population_intervention_pairs(): assert pi_pair.id is not None - for criterion in pi_pair.flatten(): - assert criterion.id is not None + + for criterion in recommendation.atoms(): + assert criterion.id is not None rec_loaded = e.load_recommendation_from_database(url) assert recommendation == rec_loaded assert rec_loaded.id == recommendation.id - assert len(rec_loaded._pi_pairs) == len(pi_pairs) + assert len(list(rec_loaded.population_intervention_pairs())) == 1 - for pi_pair_loaded, pi_pair in zip(rec_loaded._pi_pairs, pi_pairs): + for pi_pair_loaded, pi_pair in zip( + list(rec_loaded.population_intervention_pairs()), pi_pairs + ): assert pi_pair_loaded.id == pi_pair.id - assert len(pi_pair_loaded.flatten()) == len(pi_pair.flatten()) - - for criterion_loaded, criterion in zip( - pi_pair_loaded.flatten(), pi_pair.flatten() - ): - assert criterion.id == criterion_loaded.id + for criterion_loaded, criterion in zip( + rec_loaded.atoms(), recommendation.atoms() + ): + assert criterion.id == criterion_loaded.id diff --git a/tests/execution_engine/omop/cohort/test_population_intervention_pair.py b/tests/execution_engine/omop/cohort/test_population_intervention_pair.py index 12910d41..2993cbfc 100644 --- a/tests/execution_engine/omop/cohort/test_population_intervention_pair.py +++ b/tests/execution_engine/omop/cohort/test_population_intervention_pair.py @@ -1,6 +1,7 @@ from execution_engine.omop.cohort.population_intervention_pair import ( - PopulationInterventionPair, + PopulationInterventionPairExpr, ) +from execution_engine.util import logic from tests.mocks.criterion import MockCriterion @@ -8,14 +9,14 @@ class 
TestPopulationInterventionPair: def test_serialization(self): # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) - - original = PopulationInterventionPair( - name="foo", url="bar", base_criterion=MockCriterion("c") + original = PopulationInterventionPairExpr( + population_expr=logic.NonSimplifiableAnd(MockCriterion("population")), + intervention_expr=logic.NonSimplifiableAnd(MockCriterion("intervention")), + name="foo", + url="bar", + base_criterion=MockCriterion("base"), ) json = original.json() - deserialized = PopulationInterventionPair.from_json(json) + deserialized = PopulationInterventionPairExpr.from_json(json) assert original == deserialized diff --git a/tests/execution_engine/omop/criterion/__init__.py b/tests/execution_engine/omop/criterion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/execution_engine/omop/criterion/combination/__init__.py b/tests/execution_engine/omop/criterion/combination/__init__.py index 5829cf14..e69de29b 100644 --- a/tests/execution_engine/omop/criterion/combination/__init__.py +++ b/tests/execution_engine/omop/criterion/combination/__init__.py @@ -1,56 +0,0 @@ -from typing import Any, Dict, Self - -from sqlalchemy import Select, select - -from execution_engine.omop.criterion.abstract import ( - Criterion, - column_interval_type, - observation_end_datetime, - observation_start_datetime, -) -from execution_engine.util.interval import IntervalType - - -class NoopCriterion(Criterion): - """ - Select patients who are post-surgical in the timeframe between the day of the surgery and 6 days after the surgery. - """ - - _static = True - - def _create_query(self) -> Select: - """ - Get the SQL Select query for data required by this criterion. 
- """ - subquery = self.base_query().subquery() - - query = select( - subquery.c.person_id, - column_interval_type(IntervalType.POSITIVE), - observation_start_datetime.label("interval_start"), - observation_end_datetime.label("interval_end"), - ) - - query = self._filter_base_persons(query, c_person_id=subquery.c.person_id) - query = self._filter_datetime(query) - - return query - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create an object from a dictionary. - """ - return cls() - - def description(self) -> str: - """ - Get a description of the criterion. - """ - return self.__class__.__name__ - - def dict(self) -> dict: - """ - Get a dictionary representation of the object. - """ - return {} diff --git a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py index 62f9faee..6c83d4b2 100644 --- a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py @@ -6,15 +6,13 @@ import sympy from execution_engine.constants import CohortCategory -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, - NonCommutativeLogicalCriterionCombination, -) from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence from execution_engine.omop.criterion.drug_exposure import DrugExposure from execution_engine.omop.criterion.measurement import Measurement +from execution_engine.omop.criterion.noop import NoopCriterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.task.process import get_processing_module +from execution_engine.util import logic from execution_engine.util.types import Dosage, TimeRange from execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( @@ -28,7 +26,6 @@ concept_unit_ml, ) 
from tests._testdata import concepts -from tests.execution_engine.omop.criterion.combination import NoopCriterion from tests.execution_engine.omop.criterion.test_criterion import TestCriterion, date_set from tests.functions import ( create_condition, @@ -51,196 +48,83 @@ def intervals_to_df(result, by=None): return df -class TestCriterionCombination: +class TestExpr: """ - Test class for testing criterion combinations (without database). + Test class for testing Expr """ @pytest.fixture def mock_criteria(self): return [MockCriterion(f"c{i}") for i in range(1, 6)] - def test_criterion_combination_init(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - assert combination.operator == operator - assert len(combination) == 0 - - def test_criterion_combination_add(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - for criterion in mock_criteria: - combination.add(criterion) - - assert len(combination) == len(mock_criteria) - - for idx, criterion in enumerate(combination): - assert criterion == mock_criteria[idx] + def test_expr_dict(self, mock_criteria): - def test_criterion_combination_dict(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) + expr = logic.And(*mock_criteria) - for criterion in mock_criteria: - combination.add(criterion) + combination_dict = expr.dict() - combination_dict = combination.dict() assert combination_dict == { - "operator": "AND", - "threshold": None, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], - "root": False, + "type": "And", + "data": {"args": 
[criterion.dict() for criterion in mock_criteria]}, } - def test_criterion_combination_from_dict(self, mock_criteria): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination_data = { - "operator": "AND", - "threshold": None, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], - "root": False, + def test_expr_from_dict(self, mock_criteria): + expr_data = { + "type": "And", + "data": {"args": [criterion.dict() for criterion in mock_criteria]}, } - # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) + expr = logic.Expr.from_dict(expr_data) - combination = LogicalCriterionCombination.from_dict(combination_data) + assert len(expr.args) == len(mock_criteria) - assert combination.operator == operator - assert len(combination) == len(mock_criteria) - - for idx, criterion in enumerate(combination): + for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) @pytest.mark.parametrize( - "operator", + "expr_class", [ - LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND, None - ), - LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AT_LEAST, 10 - ), + logic.And, + lambda *args: logic.MinCount(*args, threshold=10), ], ) - def test_criterion_combination_serialization(self, operator, mock_criteria): + def test_criterion_combination_serialization(self, expr_class, mock_criteria): # Register the mock criterion class - from execution_engine.omop.criterion import factory - factory.register_criterion_class("MockCriterion", MockCriterion) + expr = expr_class(*mock_criteria) - combination = LogicalCriterionCombination( - operator=operator, - ) - for criterion in mock_criteria: - combination.add(criterion) + json = expr.json() + deserialized = logic.Expr.from_json(json) - json 
= combination.json() - deserialized = LogicalCriterionCombination.from_json(json) - assert combination == deserialized + assert expr == deserialized def test_noncommutative_logical_criterion_combination_serialization( self, mock_criteria ): - # Register the mock criterion class - from execution_engine.omop.criterion import factory - - factory.register_criterion_class("MockCriterion", MockCriterion) - - operator = NonCommutativeLogicalCriterionCombination.Operator( - NonCommutativeLogicalCriterionCombination.Operator.CONDITIONAL_FILTER - ) - combination = NonCommutativeLogicalCriterionCombination( - operator=operator, + expr = logic.ConditionalFilter( left=mock_criteria[0], right=mock_criteria[1], ) - json = combination.json() - deserialized = NonCommutativeLogicalCriterionCombination.from_json(json) - assert combination == deserialized + json = expr.json() + deserialized = logic.Expr.from_json(json) - @pytest.mark.parametrize("operator", ["AT_LEAST", "AT_MOST", "EXACTLY"]) - def test_operator_with_threshold(self, operator): - with pytest.raises( - AssertionError, match=f"Threshold must be set for operator {operator}" - ): - LogicalCriterionCombination.Operator(operator) + assert expr == deserialized - def test_operator(self): - with pytest.raises(AssertionError, match=""): - LogicalCriterionCombination.Operator("invalid") - - @pytest.mark.parametrize( - "operator, threshold", - [("AND", None), ("OR", None), ("AT_LEAST", 1), ("AT_MOST", 1), ("EXACTLY", 1)], + @pytest.mark.skip( + reason="repr does not have a fixed argument order, therefore test fails randomly" ) - def test_operator_str(self, operator, threshold): - op = LogicalCriterionCombination.Operator(operator, threshold) - - if operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - repr(op) - == f'LogicalCriterionCombination.Operator(operator="{operator}", threshold={threshold})' - ) - assert str(op) == f"{operator}(threshold={threshold})" - else: - assert ( - repr(op) - == 
f'LogicalCriterionCombination.Operator(operator="{operator}")' - ) - assert str(op) == f"{operator}" - - def test_repr(self): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - assert repr(combination) == ("LogicalCriterionCombination.And(\n" ")") - - def test_add_all(self): - operator = LogicalCriterionCombination.Operator( - LogicalCriterionCombination.Operator.AND - ) - combination = LogicalCriterionCombination( - operator=operator, - ) - - assert len(combination) == 0 + def test_repr(self, mock_criteria): + expr = logic.And(*mock_criteria) - combination.add_all([MockCriterion("c1"), MockCriterion("c2")]) + assert repr(expr) == ("LogicalCriterionCombination.And(\n" ")") - assert len(combination) == 2 + def test_expr_contains_criteria(self, mock_criteria): + expr = logic.And(*mock_criteria) + assert len(expr.args) == len(mock_criteria) - assert str(combination[0]) == str(MockCriterion("c1")) - assert str(combination[1]) == str(MockCriterion("c2")) + for i in range(len(mock_criteria)): + assert expr.args[i] == mock_criteria[i] class TestCriterionCombinationDatabase(TestCriterion): @@ -306,41 +190,43 @@ def run_criteria_test( threshold = None if c.func == sympy.And: - operator = LogicalCriterionCombination.Operator.AND + cls = logic.And elif c.func == sympy.Or: - operator = LogicalCriterionCombination.Operator.OR + cls = logic.Or elif isinstance(c.func, sympy.core.function.UndefinedFunction): if c.func.name in ["MinCount", "MaxCount", "ExactCount"]: assert args[0].is_number, "First argument must be a number (threshold)" - threshold = args[0] + threshold = int(args[0]) args = args[1:] if c.func.name == "MinCount": - operator = LogicalCriterionCombination.Operator.AT_LEAST + cls = lambda *args: logic.MinCount(*args, threshold=threshold) elif c.func.name == "MaxCount": - operator = LogicalCriterionCombination.Operator.AT_MOST + cls = lambda *args: 
logic.MaxCount(*args, threshold=threshold) elif c.func.name == "ExactCount": - operator = LogicalCriterionCombination.Operator.EXACTLY + cls = lambda *args: logic.ExactCount(*args, threshold=threshold) elif c.func.name == "AllOrNone": - operator = LogicalCriterionCombination.Operator.ALL_OR_NONE + cls = lambda *args: logic.AllOrNone(*args) elif c.func.name == "ConditionalFilter": - operator = None + cls = lambda *args: logic.ConditionalFilter(*args) else: raise ValueError(f"Unknown operator {c.func}") else: raise ValueError(f"Unknown operator {c.func}") - c1, c2, c3 = [ - c.copy() for c in criteria - ] # TODO(jmoringe): copy should no longer be necessary + # c1, c2, c3 = [ + # c for c in criteria + # ] # TODO(jmoringe): copy should no longer be necessary + + c1, c2, c3 = criteria for arg in args: if arg.is_Not: if arg.args[0].name == "c1": - c1 = LogicalCriterionCombination.Not(c1) + c1 = logic.Not(c1) elif arg.args[0].name == "c2": - c2 = LogicalCriterionCombination.Not(c2) + c2 = logic.Not(c2) elif arg.args[0].name == "c3": - c3 = LogicalCriterionCombination.Not(c3) + c3 = logic.Not(c3) else: raise ValueError(f"Unknown criterion {arg.args[0].name}") @@ -349,33 +235,25 @@ def run_criteria_test( if hasattr(c.func, "name") and c.func.name == "ConditionalFilter": assert len(c.args) == 2 - comb = NonCommutativeLogicalCriterionCombination.ConditionalFilter( + comb = logic.ConditionalFilter( left=symbols[str(c.args[0])], right=symbols[str(c.args[1])], ) else: - comb = LogicalCriterionCombination( - operator=LogicalCriterionCombination.Operator( - operator, threshold=threshold - ), + comb = cls( + *[symbols[str(symbol)] for symbol in c.atoms() if not symbol.is_number] ) - for symbol in c.atoms(): - if symbol.is_number: - continue - else: - comb.add(symbols[str(symbol)]) - if exclude: - comb = LogicalCriterionCombination.Not(comb) + comb = logic.Not(comb) noop_criterion = NoopCriterion() noop_criterion.set_id(1005) - noop_intervention = 
LogicalCriterionCombination.And(noop_criterion) + noop_intervention = logic.NonSimplifiableAnd(noop_criterion) self.register_criterion(noop_criterion, db_session) - self.insert_criterion_combination( + self.insert_expression( db_session, population=comb, intervention=noop_intervention, diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 7501c6fa..9190b680 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -1,5 +1,4 @@ import datetime -from typing import Any, Dict, Self import pandas as pd import pendulum @@ -8,20 +7,16 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion, column_interval_type -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) -from execution_engine.omop.criterion.combination.temporal import ( - FixedWindowTemporalIndicatorCombination, - PersonalWindowTemporalIndicatorCombination, - TimeIntervalType, -) from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence from execution_engine.omop.criterion.drug_exposure import DrugExposure from execution_engine.omop.criterion.measurement import Measurement +from execution_engine.omop.criterion.noop import NoopCriterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.omop.db.omop import tables +from execution_engine.omop.vocabulary import OMOP_SURGICAL_PROCEDURE from execution_engine.task.process import get_processing_module +from execution_engine.util import logic, temporal_logic_util +from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType from execution_engine.util.types import Dosage, TimeRange from 
execution_engine.util.value import ValueNumber @@ -36,7 +31,6 @@ concept_unit_mg, ) from tests._testdata import concepts -from tests.execution_engine.omop.criterion.combination import NoopCriterion from tests.execution_engine.omop.criterion.test_criterion import TestCriterion from tests.functions import ( create_condition, @@ -50,8 +44,6 @@ process = get_processing_module() -OMOP_SURGICAL_PROCEDURE = 4301351 # OMOP surgical procedure - def intervals_to_df(result, by=None): df = intervals_to_df_orig(result, by, process.normalize_interval) @@ -70,206 +62,128 @@ class TestFixedWindowTemporalIndicatorCombination: def mock_criteria(self): return [MockCriterion(f"c{i}") for i in range(1, 6)] - def test_criterion_combination_init(self, mock_criteria): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - start_time=datetime.time(8, 0), - end_time=datetime.time(16, 0), - ) - - assert combination.operator == operator - assert len(combination) == 0 - - def test_criterion_combination_add(self, mock_criteria): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - start_time=datetime.time(8, 0), - end_time=datetime.time(16, 0), - ) - - for criterion in mock_criteria: - combination.add(criterion) - - assert len(combination) == len(mock_criteria) - - for idx, criterion in enumerate(combination): - assert criterion == mock_criteria[idx] - def test_criterion_combination_dict(self, mock_criteria): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, + + expr = logic.TemporalMinCount( + 
*mock_criteria, + threshold=1, start_time=datetime.time(8, 0), end_time=datetime.time(16, 0), ) - for criterion in mock_criteria: - combination.add(criterion) - - combination_dict = combination.dict() - assert combination_dict == { - "operator": "AT_LEAST", - "threshold": 1, - "start_time": "08:00:00", - "end_time": "16:00:00", - "interval_type": None, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], - "root": False, + expr_dict = expr.dict() + assert expr_dict == { + "type": "TemporalMinCount", + "data": { + "threshold": 1, + "start_time": "08:00:00", + "end_time": "16:00:00", + "interval_type": None, + "interval_criterion": None, + "args": [criterion.dict() for criterion in mock_criteria], + }, } def test_criterion_combination_from_dict(self, mock_criteria): - # Register the mock criterion class - from execution_engine.omop.criterion import factory - factory.register_criterion_class("MockCriterion", MockCriterion) - - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - - combination_data = { - "id": None, - "operator": "AT_LEAST", - "threshold": 1, - "start_time": "08:00:00", - "end_time": "16:00:00", - "interval_type": None, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], + expr_dict = { + "type": "TemporalMinCount", + "data": { + "threshold": 1, + "start_time": "08:00:00", + "end_time": "16:00:00", + "interval_type": None, + "interval_criterion": None, + "args": [criterion.dict() for criterion in mock_criteria], + }, } - combination = FixedWindowTemporalIndicatorCombination.from_dict( - combination_data - ) + expr = logic.Expr.from_dict(expr_dict) - assert combination.operator == operator - assert len(combination) == len(mock_criteria) - assert combination.start_time == datetime.time(8, 0) - assert combination.end_time == datetime.time(16, 0) - 
assert combination.interval_type is None + assert len(expr.args) == len(mock_criteria) + assert expr.start_time == datetime.time(8, 0) + assert expr.end_time == datetime.time(16, 0) + assert expr.interval_type is None + assert expr.interval_criterion is None - for idx, criterion in enumerate(combination): + for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) - combination_data = { - "operator": "AT_LEAST", - "threshold": 1, - "start_time": None, - "end_time": None, - "interval_type": TimeIntervalType.MORNING_SHIFT, - "criteria": [ - {"class_name": "MockCriterion", "data": criterion.dict()} - for criterion in mock_criteria - ], + expr_dict = { + "type": "TemporalMinCount", + "data": { + "threshold": 1, + "start_time": None, + "end_time": None, + "interval_type": TimeIntervalType.MORNING_SHIFT, + "interval_criterion": None, + "args": [criterion.dict() for criterion in mock_criteria], + }, } - combination = FixedWindowTemporalIndicatorCombination.from_dict( - combination_data - ) + expr = logic.Expr.from_dict(expr_dict) - assert combination.operator == operator - assert len(combination) == len(mock_criteria) - assert combination.start_time is None - assert combination.end_time is None - assert combination.interval_type == TimeIntervalType.MORNING_SHIFT + assert len(expr.args) == len(mock_criteria) + assert expr.start_time is None + assert expr.end_time is None + assert expr.interval_type == TimeIntervalType.MORNING_SHIFT + assert expr.interval_criterion is None - for idx, criterion in enumerate(combination): + for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) - @pytest.mark.parametrize("operator", ["AT_LEAST", "AT_MOST", "EXACTLY"]) - def test_operator_with_threshold(self, operator): - with pytest.raises( - AssertionError, match=f"Threshold must be set for operator {operator}" - ): - FixedWindowTemporalIndicatorCombination.Operator(operator) - - def test_operator(self): - with 
pytest.raises(AssertionError, match=""): - FixedWindowTemporalIndicatorCombination.Operator("invalid") - - @pytest.mark.parametrize( - "operator, threshold", - [("AT_LEAST", 1), ("AT_MOST", 1), ("EXACTLY", 1)], - ) - def test_operator_str(self, operator, threshold): - op = FixedWindowTemporalIndicatorCombination.Operator(operator, threshold) - - if operator in ["AT_LEAST", "AT_MOST", "EXACTLY"]: - assert ( - repr(op) - == f'TemporalIndicatorCombination.Operator(operator="{operator}", threshold={threshold})' - ) - assert str(op) == f"{operator}(threshold={threshold})" - else: - assert ( - repr(op) - == f'TemporalIndicatorCombination.Operator(operator="{operator}")' - ) - assert str(op) == f"{operator}" - - def test_repr(self): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_LEAST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - interval_type=TimeIntervalType.MORNING_SHIFT, - ) + # @pytest.mark.skip( + # reason="the repr does not return arguments in a consistent manner" + # ) + def test_repr(self, mock_criteria): + expr = temporal_logic_util.MorningShift(mock_criteria[0]) assert ( - repr(combination) == "FixedWindowTemporalIndicatorCombination(\n" - " interval_type=TimeIntervalType.MORNING_SHIFT,\n" + repr(expr) == "TemporalMinCount(\n" + " MockCriterion(\n" + " name='c1'\n" + " ),\n" " start_time=None,\n" " end_time=None,\n" - ' operator=TemporalIndicatorCombination.Operator(operator="AT_LEAST", threshold=1),\n' + " interval_type=TimeIntervalType.MORNING_SHIFT,\n" + " interval_criterion=None,\n" + " threshold=1\n" ")" ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, + expr = logic.TemporalMinCount( + mock_criteria[0], start_time=datetime.time(8, 0), end_time=datetime.time(16, 0), + threshold=1, ) assert ( - repr(combination) == "FixedWindowTemporalIndicatorCombination(\n" + repr(expr) == "TemporalMinCount(\n" + " 
MockCriterion(\n" + " name='c1'\n" + " ),\n" + " start_time='08:00:00',\n" + " end_time='16:00:00',\n" " interval_type=None,\n" - " start_time=datetime.time(8, 0),\n" - " end_time=datetime.time(16, 0),\n" - ' operator=TemporalIndicatorCombination.Operator(operator="AT_LEAST", threshold=1),\n' + " interval_criterion=None,\n" + " threshold=1\n" ")" ) - def test_add_all(self): - operator = FixedWindowTemporalIndicatorCombination.Operator( - FixedWindowTemporalIndicatorCombination.Operator.AT_MOST, threshold=1 - ) - combination = FixedWindowTemporalIndicatorCombination( - operator=operator, - interval_type=TimeIntervalType.MORNING_SHIFT, - ) - - assert len(combination) == 0 + def test_expr_contains_criteria(self, mock_criteria): + with pytest.raises( + TypeError, + match=r"MinCount\(\) takes 1 positional argument but 5 were given", + ): + expr = temporal_logic_util.MinCount(*mock_criteria) - combination.add_all([MockCriterion("c1"), MockCriterion("c2")]) + expr = logic.TemporalMinCount(*mock_criteria, threshold=1) - assert len(combination) == 2 + assert len(expr.args) == len(mock_criteria) - assert str(combination[0]) == str(MockCriterion("c1")) - assert str(combination[1]) == str(MockCriterion("c2")) + for i in range(len(mock_criteria)): + assert expr.args[i] == mock_criteria[i] c1 = DrugExposure( @@ -334,8 +248,7 @@ def criteria(self, db_session): delir_screening, ] for i, c in enumerate(criteria): - if not c.is_persisted(): - c.set_id(i + 1) + c.set_id(i + 1, overwrite=True) self.register_criterion(c, db_session) return criteria @@ -351,11 +264,11 @@ def run_criteria_test( ): noop_criterion = NoopCriterion() - noop_criterion.set_id(1005) - noop_intervention = LogicalCriterionCombination.And(noop_criterion) + noop_criterion.set_id(1005, overwrite=True) + noop_intervention = logic.And(noop_criterion) self.register_criterion(noop_criterion, db_session) - self.insert_criterion_combination( + self.insert_expression( db_session, population=combination, 
intervention=noop_intervention, @@ -440,7 +353,7 @@ def patient_events(self, db_session, person_visit): # Full Day #################### ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c1, ), { @@ -455,7 +368,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c2, ), { @@ -470,7 +383,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c3, ), { @@ -488,7 +401,7 @@ def patient_events(self, db_session, person_visit): # Explicit Times #################### ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c1, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -509,7 +422,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c1, start_time=datetime.time(8, 30), end_time=datetime.time(18, 59), @@ -530,7 +443,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c2, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -547,7 +460,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c3, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -555,7 +468,7 @@ def patient_events(self, db_session, person_visit): {1: set(), 2: set(), 3: set()}, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c3, start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), @@ -575,7 +488,7 @@ def patient_events(self, db_session, person_visit): # Morning Shifts #################### ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c1, ), { @@ -594,7 
+507,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c2, ), { @@ -609,7 +522,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c3, ), {1: set(), 2: set(), 3: set()}, @@ -618,7 +531,7 @@ def patient_events(self, db_session, person_visit): # Afternoon Shifts #################### ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c1, ), { @@ -637,7 +550,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c2, ), { @@ -656,7 +569,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c3, ), { @@ -674,7 +587,7 @@ def patient_events(self, db_session, person_visit): # Night Shifts #################### ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c1, ), { @@ -693,7 +606,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c2, ), { @@ -708,7 +621,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c3, ), {1: set(), 2: set(), 3: set()}, @@ -717,7 +630,7 @@ def patient_events(self, db_session, person_visit): # Partial Night Shifts (before midnight) ####################### ( - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight( + temporal_logic_util.NightShiftBeforeMidnight( c1, ), { @@ -736,7 +649,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight( + 
temporal_logic_util.NightShiftBeforeMidnight( c2, ), { @@ -751,7 +664,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight( + temporal_logic_util.NightShiftBeforeMidnight( c3, ), {1: set(), 2: set(), 3: set()}, @@ -760,7 +673,7 @@ def patient_events(self, db_session, person_visit): # Partial Night Shifts (after midnight) ####################### ( - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight( + temporal_logic_util.NightShiftAfterMidnight( c1, ), { @@ -775,7 +688,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight( + temporal_logic_util.NightShiftAfterMidnight( c2, ), { @@ -790,7 +703,7 @@ def patient_events(self, db_session, person_visit): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight( + temporal_logic_util.NightShiftAfterMidnight( c3, ), {1: set(), 2: set(), 3: set()}, @@ -873,7 +786,7 @@ def patient_events(self, db_session, visit_occurrence): # Full Day #################### ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c1, ), { @@ -888,7 +801,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c2, ), { @@ -903,7 +816,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( c3, ), { @@ -921,7 +834,7 @@ def patient_events(self, db_session, visit_occurrence): # Explicit Times #################### ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c1, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), @@ -970,7 +883,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c2, start_time=datetime.time(8, 30), 
end_time=datetime.time(16, 59), @@ -995,7 +908,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( c3, start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), @@ -1027,7 +940,7 @@ def patient_events(self, db_session, visit_occurrence): # # Morning Shifts # #################### ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c1, ), { @@ -1074,7 +987,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c2, ), { @@ -1097,7 +1010,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( c3, ), { @@ -1127,7 +1040,7 @@ def patient_events(self, db_session, visit_occurrence): # # Afternoon Shifts # #################### ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c1, ), { @@ -1174,7 +1087,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c2, ), { @@ -1201,7 +1114,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( c3, ), { @@ -1231,7 +1144,7 @@ def patient_events(self, db_session, visit_occurrence): # # Night Shifts # #################### ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c1, ), { @@ -1282,7 +1195,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c2, ), { @@ -1305,7 +1218,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - 
FixedWindowTemporalIndicatorCombination.NightShift( + temporal_logic_util.NightShift( c3, ), { @@ -1407,7 +1320,7 @@ def patient_events(self, db_session, visit_occurrence): "combination,expected", [ ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( bodyweight_measurement_without_forward_fill, ), { @@ -1420,13 +1333,13 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( bodyweight_measurement_without_forward_fill, ), {1: set()}, ), ( - FixedWindowTemporalIndicatorCombination.MorningShift( + temporal_logic_util.MorningShift( bodyweight_measurement_with_forward_fill, ), { @@ -1447,7 +1360,7 @@ def patient_events(self, db_session, visit_occurrence): }, ), ( - FixedWindowTemporalIndicatorCombination.AfternoonShift( + temporal_logic_util.AfternoonShift( bodyweight_measurement_with_forward_fill, ), { @@ -1518,25 +1431,12 @@ def __init__(self) -> None: super().__init__() self._table = tables.ProcedureOccurrence.__table__.alias("po") - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - """ - Create an object from a dictionary. - """ - return cls() - def description(self) -> str: """ Get a description of the criterion. """ return self.__class__.__name__ - def dict(self) -> dict: - """ - Get a dictionary representation of the object. - """ - return {} - def _create_query(self) -> sa.Select: """ Get the SQL Select query for data required by this criterion. @@ -1567,17 +1467,13 @@ def _create_query(self) -> sa.Select: class TestPersonalWindowTemporalIndicatorCombination(TestCriterionCombinationDatabase): """ Test class for testing criterion combinations on the database with individual windows, - i.e. windows whose length is dependend on some patient-specific event (here: surgery) + i.e. 
windows whose length is dependent on some patient-specific event (here: surgery) """ @pytest.fixture def criteria(self, db_session): - # c4.set_id(4) # surgical procedure - - if not c_preop.is_persisted(): - c_preop.set_id(4) - if not bodyweight_measurement_without_forward_fill.is_persisted(): - bodyweight_measurement_without_forward_fill.set_id(5) + c_preop.set_id(4, overwrite=True) + bodyweight_measurement_without_forward_fill.set_id(5, overwrite=True) self.register_criterion(c_preop, db_session) self.register_criterion(bodyweight_measurement_without_forward_fill, db_session) @@ -1622,7 +1518,7 @@ def patient_events(self, db_session, person_visit): # Explicit Times #################### ( - PersonalWindowTemporalIndicatorCombination.Presence( + temporal_logic_util.Presence( bodyweight_measurement_without_forward_fill, interval_criterion=c_preop, ), @@ -1740,17 +1636,17 @@ def patient_events(self, db_session, visit_occurrence): "population,intervention,expected", [ ( - LogicalCriterionCombination.And(c2), - LogicalCriterionCombination.CappedAtLeast( + logic.And(c2), + logic.CappedMinCount( *[ - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( criterion=shift_class(criterion=delir_screening), ) for shift_class in [ - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight, - FixedWindowTemporalIndicatorCombination.MorningShift, - FixedWindowTemporalIndicatorCombination.AfternoonShift, - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight, + temporal_logic_util.NightShiftAfterMidnight, + temporal_logic_util.MorningShift, + temporal_logic_util.AfternoonShift, + temporal_logic_util.NightShiftBeforeMidnight, ] ], threshold=2, @@ -1803,7 +1699,7 @@ def test_at_least_combination_on_database( db_session.add_all(vos) db_session.commit() - self.insert_criterion_combination( + self.insert_expression( db_session, population, intervention, base_criterion, observation_window ) @@ -1828,17 +1724,17 @@ def 
test_at_least_combination_on_database( "population,intervention,expected", [ ( - LogicalCriterionCombination.And(c2), - LogicalCriterionCombination.CappedAtLeast( + logic.And(c2), + logic.CappedMinCount( *[ - FixedWindowTemporalIndicatorCombination.Day( + temporal_logic_util.Day( criterion=shift_class(criterion=delir_screening), ) for shift_class in [ - FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight, - FixedWindowTemporalIndicatorCombination.MorningShift, - FixedWindowTemporalIndicatorCombination.AfternoonShift, - FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight, + temporal_logic_util.NightShiftAfterMidnight, + temporal_logic_util.MorningShift, + temporal_logic_util.AfternoonShift, + temporal_logic_util.NightShiftBeforeMidnight, ] ], threshold=2, @@ -1882,7 +1778,7 @@ def test_at_least_combination_on_database_no_measurements( db_session.add_all([c1]) db_session.commit() - self.insert_criterion_combination( + self.insert_expression( db_session, population, intervention, base_criterion, observation_window ) diff --git a/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py b/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py index f349c331..1996fc60 100644 --- a/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py +++ b/tests/execution_engine/omop/criterion/custom/test_tidal_volume.py @@ -36,7 +36,7 @@ class TestTidalVolumePerIdealBodyWeight(TestCriterion): @pytest.fixture def concept(self): return CustomConcept( - name="Tidal volume / ideal body weight (ARDSnet)", + concept_name="Tidal volume / ideal body weight (ARDSnet)", concept_code="tvpibw", domain_id="Measurement", vocabulary_id="CODEX-CELIDA", diff --git a/tests/execution_engine/omop/criterion/test_criterion.py b/tests/execution_engine/omop/criterion/test_criterion.py index fee7a47b..092e2e9c 100644 --- a/tests/execution_engine/omop/criterion/test_criterion.py +++ b/tests/execution_engine/omop/criterion/test_criterion.py @@ -5,14 +5,15 
@@ import pandas as pd import pendulum import pytest -from sqlalchemy import Column, Date, Integer, MetaData, Table, select +from sqlalchemy import Column, Date, Integer, MetaData, Table, select, update import execution_engine.omop.db.celida.tables as celida_tables from execution_engine.constants import CohortCategory from execution_engine.execution_graph import ExecutionGraph +from execution_engine.omop.cohort import PopulationInterventionPairExpr +from execution_engine.omop.cohort.graph_builder import RecommendationGraphBuilder from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import Criterion -from execution_engine.omop.criterion.combination.combination import CriterionCombination from execution_engine.omop.criterion.visit_occurrence import PatientsActiveDuringPeriod from execution_engine.omop.db.celida.views import ( full_day_coverage, @@ -25,7 +26,7 @@ task, ) from execution_engine.task.process import get_processing_module -from execution_engine.util import cohort_logic +from execution_engine.util import datetime_converter, logic from execution_engine.util.db import add_result_insert from execution_engine.util.interval import IntervalType from execution_engine.util.types import TimeRange @@ -62,6 +63,27 @@ def date_set(tc: Iterable): return set(pendulum.parse(t).date() for t in tc) +def store_execution_graph(graph: ExecutionGraph, db_session, recommendation_id: int): + import json + + from execution_engine.omop.db.celida import tables as result_db + + rec_graph: bytes = json.dumps( + graph.to_cytoscape_dict(), sort_keys=True, default=datetime_converter + ).encode() + + update_query = ( + update(result_db.Recommendation) + .where(result_db.Recommendation.recommendation_id == recommendation_id) + .values( + recommendation_execution_graph=rec_graph, + ) + ) + + db_session.execute(update_query) + db_session.commit() + + class TestCriterion: result_table = celida_tables.ResultInterval.__table__ run_id = 1234 @@ -187,7 
+209,7 @@ def register_criterion(cls, criterion: Criterion, db_session): exists = db_session.query( db_session.query(celida_tables.Criterion) - .filter_by(criterion_id=criterion.id) + .filter_by(criterion_hash=str(hash(criterion))) .exists() ).scalar() @@ -195,7 +217,7 @@ def register_criterion(cls, criterion: Criterion, db_session): new_criterion = celida_tables.Criterion( criterion_id=criterion.id, criterion_description=criterion.description(), - criterion_hash=hash(criterion), + criterion_hash=str(hash(criterion)), ) db_session.add(new_criterion) db_session.commit() @@ -213,14 +235,14 @@ def register_population_intervention_pair(cls, id_, name, db_session): ).scalar() if not exists: - new_criterion = celida_tables.PopulationInterventionPair( + pi_pair = celida_tables.PopulationInterventionPair( pi_pair_id=id_, recommendation_id=-1, pi_pair_url="https://example.com", pi_pair_name=name, pi_pair_hash=hash(name), ) - db_session.add(new_criterion) + db_session.add(pi_pair) db_session.commit() @pytest.fixture @@ -251,7 +273,11 @@ def create_occurrence( ) def insert_criterion(self, db_session, criterion, observation_window: TimeRange): - criterion.set_id(self.criterion_id) + + criterion.set_id( + self.criterion_id + 1 + ) # +1 to avoid collision with the criterion saved in + # omop_fixture.py::celida_recommendation() self.register_criterion(criterion, db_session) query = criterion.create_query() @@ -287,44 +313,30 @@ def insert_criterion(self, db_session, criterion, observation_window: TimeRange) db_session.commit() - def insert_criterion_combination( + def insert_expression( self, db_session, - population: CriterionCombination, - intervention: CriterionCombination, + population: logic.Expr, + intervention: logic.Expr, base_criterion: Criterion, observation_window: TimeRange, ): - graph = ExecutionGraph.from_criterion_combination( - population, intervention, base_criterion - ) - - # we need to add a NoDataPreservingAnd node to insert NEGATIVE intervals - p_sink_node = 
graph.sink_node(CohortCategory.POPULATION) - pi_sink_node = graph.sink_node(CohortCategory.POPULATION_INTERVENTION) - - p_combination_node = cohort_logic.NoDataPreservingAnd( - p_sink_node, category=CohortCategory.POPULATION + # population_expr is assigned a NonSimplifiableAnd to ensure creation of negative intervals + pi_pair = PopulationInterventionPairExpr( + population_expr=logic.NonSimplifiableAnd(population), + intervention_expr=intervention, + base_criterion=base_criterion, + name="Test", + url="https://example.com", ) + pi_pair.set_id(self.pi_pair_id) - pi_combination_node = cohort_logic.NoDataPreservingAnd( - pi_sink_node, category=CohortCategory.POPULATION_INTERVENTION - ) + graph = RecommendationGraphBuilder.build(pi_pair, base_criterion) - graph.add_node( - p_combination_node, store_result=True, category=CohortCategory.POPULATION - ) - graph.add_node( - pi_combination_node, - store_result=True, - category=CohortCategory.POPULATION_INTERVENTION, + store_execution_graph( + graph=graph, db_session=db_session, recommendation_id=self.recommendation_id ) - graph.add_edge(p_sink_node, p_combination_node) - graph.add_edge(pi_sink_node, pi_combination_node) - - graph.set_sink_nodes_store(bind_params={"pi_pair_id": self.pi_pair_id}) - params = observation_window.model_dump() | {"run_id": self.run_id} task_runner = runner.SequentialTaskRunner(graph) diff --git a/tests/execution_engine/omop/criterion/test_drug_exposure.py b/tests/execution_engine/omop/criterion/test_drug_exposure.py index b405a999..cf87d007 100644 --- a/tests/execution_engine/omop/criterion/test_drug_exposure.py +++ b/tests/execution_engine/omop/criterion/test_drug_exposure.py @@ -4,10 +4,8 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.concepts import Concept -from execution_engine.omop.criterion.combination.logical import ( - LogicalCriterionCombination, -) from execution_engine.omop.criterion.drug_exposure import DrugExposure +from execution_engine.util 
import logic from execution_engine.util.enum import TimeUnit from execution_engine.util.types import Dosage from execution_engine.util.value import ValueNumber @@ -43,7 +41,7 @@ def _run_drug_exposure( route=route, ) if exclude: - criterion = LogicalCriterionCombination.Not(criterion) + criterion = logic.Not(criterion) self.insert_criterion(db_session, criterion, observation_window) diff --git a/tests/execution_engine/omop/test_concepts.py b/tests/execution_engine/omop/test_concepts.py index 6a933224..283ab505 100644 --- a/tests/execution_engine/omop/test_concepts.py +++ b/tests/execution_engine/omop/test_concepts.py @@ -81,7 +81,7 @@ def test_is_custom(self): class TestCustomConcept: def test_init(self): custom_concept = CustomConcept( - name="Test Custom Concept", + concept_name="Test Custom Concept", concept_code="CC123", domain_id="Test Domain", vocabulary_id="Test Vocabulary", @@ -97,7 +97,7 @@ def test_init(self): def test_id_property(self): custom_concept = CustomConcept( - name="Test Custom Concept", + concept_name="Test Custom Concept", concept_code="CC123", domain_id="Test Domain", vocabulary_id="Test Vocabulary", @@ -107,7 +107,7 @@ def test_id_property(self): def test_str(self): custom_concept = CustomConcept( - name="Test Custom Concept", + concept_name="Test Custom Concept", concept_code="CC123", domain_id="Test Domain", vocabulary_id="Test Vocabulary", diff --git a/tests/execution_engine/util/test_cohort_logic.py b/tests/execution_engine/util/test_logic.py similarity index 57% rename from tests/execution_engine/util/test_cohort_logic.py rename to tests/execution_engine/util/test_logic.py index 1d8b9d54..feef0886 100644 --- a/tests/execution_engine/util/test_cohort_logic.py +++ b/tests/execution_engine/util/test_logic.py @@ -3,9 +3,8 @@ import pytest -from execution_engine.constants import CohortCategory -from execution_engine.omop.criterion.combination.temporal import TimeIntervalType -from execution_engine.util.cohort_logic import ( +from 
execution_engine.util.enum import TimeIntervalType +from execution_engine.util.logic import ( AllOrNone, And, BooleanFunction, @@ -14,8 +13,6 @@ LeftDependentToggle, MaxCount, MinCount, - NoDataPreservingAnd, - NoDataPreservingOr, NonSimplifiableAnd, Not, Or, @@ -28,34 +25,29 @@ dummy_criterion = MockCriterion( name="dummy_criterion", - exclude=False, ) x, y, z = ( - Symbol(MockCriterion("x", False), category=CohortCategory.POPULATION), - Symbol(MockCriterion("y", False), category=CohortCategory.POPULATION), - Symbol(MockCriterion("z", False), category=CohortCategory.POPULATION), + MockCriterion("x"), + MockCriterion("y"), + MockCriterion("z"), ) # Tests for Expr class TestExpr: - def test_creation_with_category(self): - expr = Expr(category=CohortCategory.POPULATION) - assert expr.category == CohortCategory.POPULATION - def test_is_Atom_false(self): - expr = Expr(category=CohortCategory.POPULATION) + expr = Expr() assert not expr.is_Atom # Tests for Symbol class TestSymbol: def test_symbol_creation(self): - symbol = Symbol(dummy_criterion, category=CohortCategory.POPULATION) - assert str(symbol) == "MockCriterion[dummy_criterion]" - assert symbol.criterion == dummy_criterion + symbol = dummy_criterion + assert repr(symbol) == "MockCriterion(\n name='dummy_criterion'\n)" + assert symbol == dummy_criterion def test_is_Atom_true(self): symbol = x @@ -68,11 +60,11 @@ def test_is_Not_false(self): class TestBooleanFunction: def test_or_creation(self): - or_expr = Or(x, y, category=CohortCategory.POPULATION) + or_expr = Or(x, y) assert isinstance(or_expr, Or) def test_or_is_Atom_false(self): - or_expr = Or(x, y, category=CohortCategory.POPULATION) + or_expr = Or(x, y) assert not or_expr.is_Atom @@ -81,24 +73,35 @@ def test_and_creation_with_multiple_args(self): and_expr = And( x, y, - Not(z, category=CohortCategory.POPULATION), - category=CohortCategory.POPULATION, + Not(z), ) assert isinstance(and_expr, And) assert ( - str(and_expr) == "(MockCriterion[x] & 
MockCriterion[y] & ~MockCriterion[z])" + repr(and_expr) == "And(\n" + " MockCriterion(\n" + " name='x'\n" + " ),\n" + " MockCriterion(\n" + " name='y'\n" + " ),\n" + " Not(\n" + " MockCriterion(\n" + " name='z'\n" + " )\n" + " )\n" + ")" ) assert and_expr.args[0] == x assert and_expr.args[1] == y - assert and_expr.args[2] == Not(z, category=CohortCategory.POPULATION) + assert and_expr.args[2] == Not(z) assert not and_expr.is_Not assert not and_expr.is_Atom def test_and_creation_with_single_arg(self): single_arg = x - and_expr = And(single_arg, category=CohortCategory.POPULATION) + and_expr = And(single_arg) assert and_expr is single_arg @@ -107,30 +110,41 @@ def test_or_creation_with_multiple_args(self): or_expr = Or( x, y, - Not(z, category=CohortCategory.POPULATION), - category=CohortCategory.POPULATION, + Not(z), ) assert isinstance(or_expr, Or) assert ( - str(or_expr) == "(MockCriterion[x] | MockCriterion[y] | ~MockCriterion[z])" + repr(or_expr) == "Or(\n" + " MockCriterion(\n" + " name='x'\n" + " ),\n" + " MockCriterion(\n" + " name='y'\n" + " ),\n" + " Not(\n" + " MockCriterion(\n" + " name='z'\n" + " )\n" + " )\n" + ")" ) assert or_expr.args[0] == x assert or_expr.args[1] == y - assert or_expr.args[2] == Not(z, category=CohortCategory.POPULATION) + assert or_expr.args[2] == Not(z) assert not or_expr.is_Not assert not or_expr.is_Atom def test_or_creation_with_single_arg(self): single_arg = x - or_expr = Or(single_arg, category=CohortCategory.POPULATION) + or_expr = Or(single_arg) assert or_expr is single_arg class TestNot: def test_not_creation(self): - not_expr = Not(x, category=CohortCategory.POPULATION) + not_expr = Not(x) assert isinstance(not_expr, Not) assert str(not_expr) == "~MockCriterion[x]" assert not_expr.args[0] == x @@ -140,59 +154,33 @@ def test_not_creation(self): def test_not_creation_with_multiple_args(self): with pytest.raises(ValueError): - Not(x, y, category=CohortCategory.POPULATION) + Not(x, y) class TestNonSimplifiableAnd: def 
test_non_simplifiable_and_creation(self): - non_simp_and = NonSimplifiableAnd(x, y, category=CohortCategory.POPULATION) + non_simp_and = NonSimplifiableAnd(x, y) assert isinstance(non_simp_and, NonSimplifiableAnd) assert not non_simp_and.is_Not assert not non_simp_and.is_Atom def test_non_simplifiable_and_single_arg(self): single_arg = x - non_simp_and = NonSimplifiableAnd( - single_arg, category=CohortCategory.POPULATION - ) + non_simp_and = NonSimplifiableAnd(single_arg) assert isinstance(non_simp_and, NonSimplifiableAnd) assert non_simp_and.args[0] is single_arg def test_non_simplifiable_and_equality(self): - non_simp_and1 = NonSimplifiableAnd(x, category=CohortCategory.POPULATION) - non_simp_and2 = NonSimplifiableAnd(x, category=CohortCategory.POPULATION) - assert non_simp_and1 != non_simp_and2 - - -class TestNoDataPreservingAnd: - def test_no_data_preserving_and_creation(self): - no_data_and = NoDataPreservingAnd(x, y, category=CohortCategory.POPULATION) - assert isinstance(no_data_and, NoDataPreservingAnd) - assert no_data_and.args[0] == x - assert no_data_and.args[1] == y - - assert not no_data_and.is_Not - assert not no_data_and.is_Atom - - -class TestNoDataPreservingOr: - def test_no_data_preserving_or_creation(self): - no_data_or = NoDataPreservingOr(x, y, category=CohortCategory.POPULATION) - assert isinstance(no_data_or, NoDataPreservingOr) - assert no_data_or.args[0] == x - assert no_data_or.args[1] == y - - assert not no_data_or.is_Not - assert not no_data_or.is_Atom + non_simp_and1 = NonSimplifiableAnd(x) + non_simp_and2 = NonSimplifiableAnd(x) + assert non_simp_and1 == non_simp_and2 class TestLeftDependentToggle: def test_left_dependent_toggle_creation(self): left_expr = x right_expr = y - left_toggle = LeftDependentToggle( - left=left_expr, right=right_expr, category=CohortCategory.POPULATION - ) + left_toggle = LeftDependentToggle(left=left_expr, right=right_expr) assert isinstance(left_toggle, LeftDependentToggle) assert left_toggle.left == 
left_expr == left_toggle.args[0] assert left_toggle.right == right_expr == left_toggle.args[1] @@ -213,20 +201,18 @@ def worker(queue: Queue, symbol: Symbol): class TestSymbolMultiprocessing: @pytest.fixture( params=[ - Expr(1, 2, 3, category=CohortCategory.POPULATION), - Symbol(dummy_criterion, category=CohortCategory.POPULATION), - BooleanFunction(1, 2, 3, category=CohortCategory.POPULATION), - Or(1, 2, 3, category=CohortCategory.POPULATION), - And(1, 2, 3, category=CohortCategory.POPULATION), - Not(1, category=CohortCategory.POPULATION), - MinCount(1, 2, 3, threshold=2, category=CohortCategory.POPULATION), - MaxCount(1, 2, 3, threshold=2, category=CohortCategory.POPULATION), - ExactCount(1, 2, 3, threshold=2, category=CohortCategory.POPULATION), - AllOrNone(1, 2, 3, category=CohortCategory.POPULATION), - NonSimplifiableAnd(1, 2, 3, category=CohortCategory.POPULATION), - NoDataPreservingAnd(1, 2, 3, category=CohortCategory.POPULATION), - NoDataPreservingOr(1, 2, 3, category=CohortCategory.POPULATION), - LeftDependentToggle(left=1, right=2, category=CohortCategory.POPULATION), + Expr(1, 2, 3), + Symbol(dummy_criterion), + BooleanFunction(1, 2, 3), + Or(1, 2, 3), + And(1, 2, 3), + Not(1), + MinCount(1, 2, 3, threshold=2), + MaxCount(1, 2, 3, threshold=2), + ExactCount(1, 2, 3, threshold=2), + AllOrNone(1, 2, 3), + NonSimplifiableAnd(1, 2, 3), + LeftDependentToggle(left=1, right=2), TemporalMinCount( 1, 2, @@ -236,7 +222,6 @@ class TestSymbolMultiprocessing: end_time=None, interval_type=TimeIntervalType.DAY, interval_criterion=None, - category=CohortCategory.POPULATION, ), TemporalMaxCount( 1, @@ -247,7 +232,6 @@ class TestSymbolMultiprocessing: end_time=None, interval_type=TimeIntervalType.MORNING_SHIFT, interval_criterion=None, - category=CohortCategory.POPULATION, ), TemporalExactCount( 1, @@ -258,7 +242,6 @@ class TestSymbolMultiprocessing: end_time=None, interval_type=TimeIntervalType.NIGHT_SHIFT, interval_criterion=None, - category=CohortCategory.POPULATION, 
), ], ids=lambda expr: expr.__class__.__name__, diff --git a/tests/execution_engine/util/test_types.py b/tests/execution_engine/util/test_types.py index a15e6686..350fd66a 100644 --- a/tests/execution_engine/util/test_types.py +++ b/tests/execution_engine/util/test_types.py @@ -67,9 +67,6 @@ def test_dosage_creation(self): assert dosage.count.value == 5 assert dosage.frequency.value == 10 - @pytest.mark.skip( - reason="Timing in Dosage is another type that does not use the use use_enum_values flag" - ) def test_enum_values(self): dosage = Dosage( dose=ValueNumber(value=10, unit=concept_unit_mg), @@ -81,7 +78,7 @@ def test_enum_values(self): assert ( dosage.interval.unit == "h" ) # Check if the enum value is used instead of the enum name - assert not isinstance( + assert isinstance( dosage.interval.unit, TimeUnit ) # Ensure it's not returning the enum member diff --git a/tests/mocks/criterion.py b/tests/mocks/criterion.py index 3559a36c..f2ff7c25 100644 --- a/tests/mocks/criterion.py +++ b/tests/mocks/criterion.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, Self - from sqlalchemy import Select from execution_engine.omop.concepts import Concept @@ -13,12 +11,9 @@ def _create_query(self) -> Select: def __init__( self, name: str, - exclude: bool = False, ): self._id = None self._name = name - self._exclude = exclude - assert not exclude def unique_name(self) -> str: return self._name @@ -35,22 +30,6 @@ def _sql_select_data(self, query: Select) -> Select: def description(self) -> str: return f"MockCriterion[{self._name}]" - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Self: - return cls( - name=data["name"], - exclude=data["exclude"], - ) - - def dict(self) -> dict[str, Any]: - return { - "name": self._name, - "exclude": self._exclude, - } - - def __repr__(self) -> str: - return self.description() - @property def concept(self) -> Concept: # type: ignore pass