From e864f0e465f7c99ee688c082c8ab302190ae9630 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Thu, 6 Mar 2025 15:17:28 +0100 Subject: [PATCH 01/43] cosmetic changes --- execution_engine/omop/db/celida/tables.py | 5 ++--- execution_engine/task/process/rectangle.py | 1 - execution_engine/task/process/rectangle_python.py | 4 +++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/execution_engine/omop/db/celida/tables.py b/execution_engine/omop/db/celida/tables.py index 0db17e26..9f85b407 100644 --- a/execution_engine/omop/db/celida/tables.py +++ b/execution_engine/omop/db/celida/tables.py @@ -27,8 +27,8 @@ from execution_engine.util.interval import IntervalType # Use the "public" schema so that tables in different schemas can -# these enums easily without introducing depedencies between the -# respective schemas. Note that replicate the enum definitions in each +# use these enums easily without introducing dependencies between the +# respective schemas. Note that replicating the enum definitions in each # schema would not work when data must be exchanged between the # schemas because enum definitions in separate schemas, even if # identical in terms of enum values, are considered distinct and @@ -37,7 +37,6 @@ CohortCategoryEnum = Enum(CohortCategory, name="cohort_category", schema="public") - class Recommendation(Base): # noqa: D101 __tablename__ = "recommendation" __table_args__ = {"schema": SCHEMA_NAME} diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index b3546180..71464542 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -713,7 +713,6 @@ def add_interval(interval_start, interval_end, interval_type): # Create the interval with the specified interval_type if it # overlaps the main datetime range, otherwise fill the day # with an interval of type "not applicable". - # TODO: what about intervals "before" the main datetime range? 
if end_interval < start_datetime: # completely before datetime range day_start = timezone.localize( datetime.datetime.combine( diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index bc0609b9..b87db0a2 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -325,7 +325,9 @@ def intersect_interval_lists( :return: The list of intersections. """ return union_rects( - [item for x in left for y in right for item in intersect_rects([x, y])] + [item for x in left + for y in right + for item in intersect_rects([x, y])] ) From a66c3160c23a93a35bcb01131f288069eecaf09b Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Thu, 13 Mar 2025 14:12:51 +0100 Subject: [PATCH 02/43] fix: in TemporalCount, assert that unimplemented cases are not executed --- execution_engine/task/task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 0968fbda..a9b62cd7 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -494,6 +494,8 @@ def get_start_end_from_interval_type( return cnf.start, cnf.end assert isinstance(self.expr, logic.TemporalCount), "Invalid expression type" + assert self.expr.count_min == 1 + assert self.expr.count_max is None if self.expr.interval_criterion is not None: # last element is the indicator windows From 790f57f8621dd9f04c7bde3567b9a4421c8d862f Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 12 Mar 2025 15:42:24 +0100 Subject: [PATCH 03/43] refactor: use just one intervals_to_events function --- execution_engine/task/process/__init__.py | 6 + .../task/process/rectangle_cython.pyx | 137 ++++++++---------- .../task/process/rectangle_python.py | 80 ++++------ 3 files changed, 95 insertions(+), 128 deletions(-) diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index 72b4589c..5f4bf06e 100644 --- 
a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -3,6 +3,7 @@ import sys import types from collections import namedtuple +from typing import TypeVar def get_processing_module( @@ -39,3 +40,8 @@ def get_processing_module( Interval = namedtuple("Interval", ["lower", "upper", "type"]) IntervalWithCount = namedtuple("IntervalWithCount", ["lower", "upper", "type", "count"]) IntervalWithTypeCounts = namedtuple("IntervalWithTypeCounts", ["lower", "upper", "counts"]) + +AnyInterval = Interval | IntervalWithCount | IntervalWithTypeCounts +GeneralizedInterval = None | AnyInterval + +TInterval = TypeVar('TInterval', bound = AnyInterval) diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 688b7fd3..294155a7 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -1,14 +1,13 @@ import copy -from functools import reduce import datetime -from collections import namedtuple +from functools import reduce cimport numpy as np import numpy as np from sortedcontainers import SortedDict -from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts +from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts, AnyInterval from execution_engine.util.interval import IntervalType DEF SCHAR_MIN = -128 @@ -16,10 +15,11 @@ DEF SCHAR_MAX = 127 MODULE_IMPLEMENTATION = "cython" + def intervals_to_events( - intervals: list[Interval], + intervals: list[AnyInterval], closing_offset: int = 1, -) -> list[tuple[int, bool, IntervalType]]: +) -> list[tuple[int, bool, AnyInterval]]: """ Converts the intervals to a list of events. @@ -28,71 +28,9 @@ def intervals_to_events( :param intervals: The intervals. :return: The events. 
""" - events = [(i.lower, True, i.type) for i in intervals] + [ - (i.upper + closing_offset, False, i.type) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - -def intervals_with_count_to_events( - intervals: list[IntervalWithCount], -) -> list[tuple[int, bool, IntervalType, int]]: - """ - Converts the intervals to a list of events. - - The events are a sorted list of the opening/closing points of all rectangles. - - :param intervals: The intervals. - :return: The events. - """ - events = [(i.lower, True, i.type, i.count) for i in intervals] + [ - (i.upper + 1, False, i.type, i.count) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - - -def intersect_rects(list[Interval] intervals) -> list[Interval]: - cdef double x_min = -np.inf - cdef signed char y_min = SCHAR_MAX - cdef double end_point = np.inf - - if not len(intervals): - return [] - - order = IntervalType.intersection_priority() - events = intervals_to_events(intervals) - - for cur_x, start_point, y_max_type in events: - - y_max = order.index(y_max_type) - - if start_point: - if end_point < cur_x: - # we already hit an endpoint and here starts a new one, so the intersection is empty - return [] - - if y_max < y_min: - y_min = y_max - - if cur_x > x_min: - # this point is further than the previously found one, so reset the intersection's start point - x_min = cur_x - - else: - # we found and endpoint - if cur_x > end_point: - # this endpoint lies behind another endpoint, we know we can stop - if x_min > end_point - 1: - return [] - return [Interval(lower=x_min, upper=end_point - 1, type=order[y_min])] - end_point = cur_x - - return [Interval(lower=x_min, upper=end_point - 1, type=order[y_min])] + events = [ (i.lower, True, i) for i in intervals ] \ + + [ (i.upper + closing_offset, False, i) for i in intervals ] + return sorted(events,key=lambda i: i[0]) def union_rects(list[Interval] intervals) -> list[Interval]: @@ -116,7 +54,8 @@ def 
union_rects(list[Interval] intervals) -> list[Interval]: union = [] - for x_min, start_point, y_max_type in events: + for x_min, start_point, interval in events: + y_max_type = interval.type y_max = order.index(y_max_type) if x_min > cur_x: @@ -194,7 +133,7 @@ def union_rects_with_count(list[IntervalWithCount] intervals) -> list[IntervalWi order = IntervalType.union_priority()[::-1] - events = intervals_with_count_to_events(intervals) + events = intervals_to_events(intervals) union = [] @@ -212,7 +151,8 @@ def union_rects_with_count(list[IntervalWithCount] intervals) -> list[IntervalWi break return max_key - for x, start_point, y_type, count_event in events: + for x, start_point, interval in events: + y_type, count_event = interval.type, interval.count y = order.index(y_type) if start_point: y_max = get_y_max() @@ -319,6 +259,46 @@ def merge_adjacent_intervals(intervals: list[IntervalWithCount]) -> list[Interva return merged_intervals + +def intersect_rects(list[Interval] intervals) -> list[Interval]: + cdef double x_min = -np.inf + cdef signed char y_min = SCHAR_MAX + cdef double end_point = np.inf + + if not len(intervals): + return [] + + order = IntervalType.intersection_priority() + events = intervals_to_events(intervals) + + for cur_x, start_point, interval in events: + y_max_type = interval.type + y_max = order.index(y_max_type) + + if start_point: + if end_point < cur_x: + # we already hit an endpoint and here starts a new one, so the intersection is empty + return [] + + if y_max < y_min: + y_min = y_max + + if cur_x > x_min: + # this point is further than the previously found one, so reset the intersection's start point + x_min = cur_x + + else: + # we found and endpoint + if cur_x > end_point: + # this endpoint lies behind another endpoint, we know we can stop + if x_min > end_point - 1: + return [] + return [Interval(lower=x_min, upper=end_point - 1, type=order[y_min])] + end_point = cur_x + + return [Interval(lower=x_min, upper=end_point - 1, 
type=order[y_min])] + + def intersect_interval_lists( left: list[Interval], right: list[Interval] ) -> list[Interval]: @@ -438,15 +418,15 @@ def find_overlapping_windows( active_intervals = 0 for window_event, interval_event in interleaved_events(): if window_event: - time, open_, interval_type = window_event - window_open(time, interval_type) if open_ else window_close(time) + time, open_, interval = window_event + window_open(time, interval.type) if open_ else window_close(time) else: - time, open_, type_ = interval_event + time, open_, interval = interval_event active_intervals += (1 if open_ else -1) if active_intervals == 0: # 1 -> 0 transition interval_unsatisfied(time) elif active_intervals == 1: # 0 -> 1 transition - interval_satisfied(time, type_) + interval_satisfied(time, interval.type) # Return the list of unique intersecting windows return intersecting_windows @@ -492,13 +472,14 @@ def find_rectangles_with_count(all_intervals: list[list[Interval]]) -> list[Inte # counts change. 
counts = dict() previous_time = events[0][0] - for (time, open_, interval_type) in events: + for (time, open_, interval) in events: if previous_time is None: previous_time = time elif not previous_time == time: add_segment(previous_time, time, copy.copy(counts)) previous_time = time + interval_type = interval.type old_count = counts.get(interval_type, 0) counts[interval_type] = old_count + (1 if open_ else -1) diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index b87db0a2..cd15da22 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -5,12 +5,28 @@ import numpy as np from sortedcontainers import SortedDict, SortedList -from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts +from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts, AnyInterval from execution_engine.util.interval import IntervalType MODULE_IMPLEMENTATION = "python" +def intervals_to_events( + intervals: list[AnyInterval], closing_offset: int = 1 +) -> list[tuple[int, bool, AnyInterval]]: + """ + Converts the intervals to a list of events. + + The events are a sorted list of the opening/closing points of all rectangles. + + :param intervals: The intervals. + :return: The events. + """ + events = [ (i.lower, True, i) for i in intervals ] \ + + [ (i.upper + closing_offset, False, i) for i in intervals ] + return sorted(events,key=lambda i: i[0]) + + def union_rects(intervals: list[Interval]) -> list[Interval]: """ Unions the intervals. 
@@ -30,7 +46,8 @@ def union_rects(intervals: list[Interval]) -> list[Interval]: cur_x = -np.inf open_y = SortedList() - for x_min, start_point, y_max in events: + for x_min, start_point, interval in events: + y_max = interval.type if x_min > cur_x: # previously unvisited x cur_x = x_min @@ -99,7 +116,7 @@ def union_rects_with_count( return [] with IntervalType.union_order(): - events = intervals_with_count_to_events(intervals) + events = intervals_to_events(intervals) union = [] @@ -119,7 +136,8 @@ def get_y_max() -> IntervalType | None: break return max_key - for x, start_point, y, count_event in events: + for x, start_point, interval in events: + y, count_event = interval.type, interval.count if start_point: y_max = get_y_max() @@ -249,7 +267,8 @@ def intersect_rects(intervals: list[Interval]) -> list[Interval]: y_min = np.inf end_point = np.inf - for cur_x, start_point, y_max in events: + for cur_x, start_point, interval in events: + y_max = interval.type if start_point: if end_point < cur_x: # we already hit an endpoint and here starts a new one, so the intersection is empty @@ -274,46 +293,6 @@ def intersect_rects(intervals: list[Interval]) -> list[Interval]: return [Interval(lower=x_min, upper=end_point - 1, type=y_min)] -def intervals_to_events( - intervals: list[Interval], closing_offset: int = 1 -) -> list[tuple[int, bool, IntervalType]]: - """ - Converts the intervals to a list of events. - - The events are a sorted list of the opening/closing points of all rectangles. - - :param intervals: The intervals. - :return: The events. - """ - events = [(i.lower, True, i.type) for i in intervals] + [ - (i.upper + closing_offset, False, i.type) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - - -def intervals_with_count_to_events( - intervals: list[IntervalWithCount], -) -> list[tuple[int, bool, IntervalType, int]]: - """ - Converts the intervals to a list of events. 
- - The events are a sorted list of the opening/closing points of all rectangles. - - :param intervals: The intervals. - :return: The events. - """ - events = [(i.lower, True, i.type, i.count) for i in intervals] + [ - (i.upper + 1, False, i.type, i.count) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - - def intersect_interval_lists( left: list[Interval], right: list[Interval] ) -> list[Interval]: @@ -433,15 +412,15 @@ def interleaved_events(): active_intervals = 0 for window_event, interval_event in interleaved_events(): if window_event: - time, open_, interval_type = window_event - window_open(time, interval_type) if open_ else window_close(time) + time, open_, interval = window_event + window_open(time, interval.type) if open_ else window_close(time) else: - time, open_, type_ = interval_event + time, open_, interval = interval_event active_intervals += (1 if open_ else -1) if active_intervals == 0: # 1 -> 0 transition interval_unsatisfied(time) elif active_intervals == 1: # 0 -> 1 transition - interval_satisfied(time, type_) + interval_satisfied(time, interval.type) # Return the list of unique intersecting windows return intersecting_windows @@ -487,13 +466,14 @@ def add_segment(start, end, type_counts): # counts change. 
counts = dict() previous_time = events[0][0] - for (time, open_, interval_type) in events: + for (time, open_, interval) in events: if previous_time is None: previous_time = time elif not previous_time == time: add_segment(previous_time, time, copy.copy(counts)) previous_time = time + interval_type = interval.type old_count = counts.get(interval_type, 0) counts[interval_type] = old_count + (1 if open_ else -1) From 84d0f6146018d6d9ac84b9a24cc4b160e4d05b1c Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 12 Mar 2025 15:40:35 +0100 Subject: [PATCH 04/43] feat: add function find_rectangles in process module --- execution_engine/task/process/rectangle.py | 13 ++ .../task/process/rectangle_cython.pyx | 124 +++++++++++++++++- .../task/process/rectangle_python.py | 124 +++++++++++++++++- 3 files changed, 259 insertions(+), 2 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 71464542..c2565d91 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -799,9 +799,22 @@ def find_overlapping_personal_windows( return result def find_rectangles_with_count(data: list[PersonIntervals]) -> PersonIntervals: + # TODO(jmoringe): can this use _process_interval? if len(data) == 0: return {} else: keys = data[0].keys() return {key: _impl.find_rectangles_with_count([ intervals[key] for intervals in data ]) for key in keys} + +def find_rectangles(data: list[PersonIntervals], interval_constructor: Callable) -> PersonIntervals: + # TODO(jmoringe): can this use _process_interval? 
+ if len(data) == 0: + return {} + else: + keys = set() + for track in data: + keys |= track.keys() + return {key: _impl.find_rectangles([ intervals.get(key, []) for intervals in data ], + interval_constructor) + for key in keys} diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 294155a7..f85b645e 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -1,6 +1,8 @@ import copy import datetime -from functools import reduce +import typing +from functools import reduce, cmp_to_key +from typing import Callable cimport numpy as np @@ -484,3 +486,123 @@ def find_rectangles_with_count(all_intervals: list[list[Interval]]) -> list[Inte counts[interval_type] = old_count + (1 if open_ else -1) return result + +IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] + +def find_rectangles(all_intervals: list[list[AnyInterval]], + interval_constructor: IntervalConstructor) \ + -> list[AnyInterval]: + """For multiple parallel "tracks" of intervals, identify temporal + intervals in which no change occurs on any "track". For each such + interval, call interval_constructor to determine how the interval + should be represented in the overall result. To this end, + interval_constructor receives a list of "active" intervals the + elements of which are either None or an interval from + all_intervals and returns either None or an interval. The returned + None values and intervals are further processed into the overall + return value by merging adjacent intervals without "payload" + change. + + :param all_intervals: A list of intervals that are checked for overlap with the windows. + :param interval_constructor: A callable that accepts a start time, + an end time and a list of "active" + intervals and returns None or an + interval.
The list of active + intervals has the same length as + all_intervals and each element is + either None or an element from the + corresponding list in all_intervals. + :return: A list of intervals computed by interval_constructor such + that adjacent intervals (i.e. without gaps between them) + have different "payloads". + """ + # Convert all intervals into a single list of events sorted by + # time. Multiple events at the same point in time can be a problem + # here: If an interval open event and an interval close event on + # the same track happen at the same time (which happens for + # adjacent intervals on that track), we must order the close event + # before the open event, otherwise our tracking of active + # intervals would get confused. + track_count = len(all_intervals) + events = [ + (time, event, interval, j) + for j, intervals in enumerate(all_intervals) + for (time, event, interval) in intervals_to_events(intervals, closing_offset=0) + ] + def compare_events(event1, event2): + if event1[0] < event2[0]: # event1 is earlier + return -1 + elif event2[0] < event1[0]: # event2 is earlier + return 1 + elif (event1[3] == event2[3] # at the same time and on same track, + and event1[1] == False): # sort close events before open events + return -1 + else: # at the same time, but different tracks => any order is fine + return 1 + events.sort(key = cmp_to_key(compare_events)) + + # The result will be a list of intervals produced by + # interval_constructor.
+ result = [] + previous_end = None + def add_segment(start, end, original_intervals): + nonlocal previous_end + if previous_end == start and len(result) > 0: + result[-1] = result[-1]._replace(upper=previous_end - 1) + interval = interval_constructor(start, end, original_intervals) + if interval is not None: # interval type negative is implicit + result.append(interval) + previous_end = end + + def no_gap_between_points_in_time(end_time, start_time): + # Since points in time for intervals are quantized to whole + # seconds and intervals are closed (inclusive) for both start + # and end points, two adjacent intervals like + # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] + # have no gap between them and can be considered a single + # continuous interval [START_TIME1, END_TIME2]. + return (end_time == start_time) or (end_time == start_time - 1) + + def is_same_result(active_intervals1, active_intervals2): + # When we have to decide whether to extend a result interval + # or start a new one, we compare the state for the existing + # result interval with the new state. The states are derived + # from the respective lists of active intervals by calling + # interval_constructor (with fake points in time) . 
+ return (interval_constructor(0,0,active_intervals1) + == interval_constructor(0,0,active_intervals2)) + + active_intervals = [None] * track_count + def process_event_for_point_in_time(index, point_time): + nonlocal active_intervals + high_time = point_time + any_open = False + for i in range(index, len(events)): + time, open_, interval, track = events[i] + if no_gap_between_points_in_time(point_time, time): + high_time = max(high_time, time) + any_open |= open_ + else: + return i, time, active_intervals.copy(), high_time if any_open else high_time + 1 + active_intervals[track] = interval if open_ else None + return None, None, None, None + + if not len(events) == 0: + # Step through event "clusters" with a common point in time and + # emit result intervals with unchanged interval "payload". + index, time = 0, events[0][0] + interval_start_time = time + index, time, interval_start_state, high_time = process_event_for_point_in_time(index, time) + while index: + new_index, new_time, maybe_end_state, high_time = process_event_for_point_in_time(index, time) + # Diagram for this program point: + # |___potential_result_interval___|| | + # index new_index + # interval_start_time time new_time + # interval_start_state maybe_end_state + # high_time + if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): + add_segment(interval_start_time, time, interval_start_state) + interval_start_time, interval_start_state = high_time, maybe_end_state + index, time = new_index, new_time + return result diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index cd15da22..d6dc1ae4 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -1,6 +1,8 @@ import copy import datetime -from functools import reduce +import typing +from functools import reduce, cmp_to_key +from typing import Callable import numpy as np from 
sortedcontainers import SortedDict, SortedList @@ -478,3 +480,123 @@ def add_segment(start, end, type_counts): counts[interval_type] = old_count + (1 if open_ else -1) return result + +IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] + +def find_rectangles(all_intervals: list[list[AnyInterval]], + interval_constructor: IntervalConstructor) \ + -> list[AnyInterval]: + """For multiple parallel "tracks" of intervals, identify temporal + intervals in which no change occurs on any "track". For each such + interval, call interval_constructor to determine how the interval + should be represented in the overall result. To this end, + interval_constructor receives a list of "active" intervals the + elements of which are either None or an interval from + all_intervals and returns either None or an interval. The returned + None values and intervals are further processed into the overall + return value by merging adjacent intervals without "payload" + change. + + :param all_intervals: A list of intervals that are checked for overlap with the windows. + :param interval_constructor: A callable that accepts a start time, + an end time and a list of "active" + intervals and returns None or an + interval. The list of active + intervals has the same length as + all_intervals and each element is + either None or an element from the + corresponding list in all_intervals. + :return: A list of intervals computed by interval_constructor such + that adjacent intervals (i.e. without gaps between them) + have different "payloads". + """ + # Convert all intervals into a single list of events sorted by + # time. Multiple events at the same point in time can be a problem + # here: If an interval open event and an interval close event on + # the same track happen at the same time (which happens for + # adjacent intervals on that track), we must order the close event + # before the open event, otherwise our tracking of active + # intervals would get confused.
+ track_count = len(all_intervals) + events = [ + (time, event, interval, j) + for j, intervals in enumerate(all_intervals) + for (time, event, interval) in intervals_to_events(intervals, closing_offset=0) + ] + def compare_events(event1, event2): + if event1[0] < event2[0]: # event1 is earlier + return -1 + elif event2[0] < event1[0]: # event2 is earlier + return 1 + elif (event1[3] == event2[3] # at the same time and on same track, + and event1[1] == False): # sort close events before open events + return -1 + else: # at the same time, but different tracks => any order is fine + return 1 + events.sort(key = cmp_to_key(compare_events)) + + # The result will be a list of intervals produced by + # interval_constructor. + result = [] + previous_end = None + def add_segment(start, end, original_intervals): + nonlocal previous_end + if previous_end == start and len(result) > 0: + result[-1] = result[-1]._replace(upper=previous_end - 1) + interval = interval_constructor(start, end, original_intervals) + if interval is not None: # interval type negative is implicit + result.append(interval) + previous_end = end + + def no_gap_between_points_in_time(end_time, start_time): + # Since points in time for intervals are quantized to whole + # seconds and intervals are closed (inclusive) for both start + # and end points, two adjacent intervals like + # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] + # have no gap between them and can be considered a single + # continuous interval [START_TIME1, END_TIME2]. + return (end_time == start_time) or (end_time == start_time - 1) + + def is_same_result(active_intervals1, active_intervals2): + # When we have to decide whether to extend a result interval + # or start a new one, we compare the state for the existing + # result interval with the new state. The states are derived + # from the respective lists of active intervals by calling + # interval_constructor (with fake points in time) . 
+ return (interval_constructor(0,0,active_intervals1) + == interval_constructor(0,0,active_intervals2)) + + active_intervals = [None] * track_count + def process_event_for_point_in_time(index, point_time): + nonlocal active_intervals + high_time = point_time + any_open = False + for i in range(index, len(events)): + time, open_, interval, track = events[i] + if no_gap_between_points_in_time(point_time, time): + high_time = max(high_time, time) + any_open |= open_ + else: + return i, time, active_intervals.copy(), high_time if any_open else high_time + 1 + active_intervals[track] = interval if open_ else None + return None, None, None, None + + if not len(events) == 0: + # Step through event "clusters" with a common point in time and + # emit result intervals with unchanged interval "payload". + index, time = 0, events[0][0] + interval_start_time = time + index, time, interval_start_state, high_time = process_event_for_point_in_time(index, time) + while index: + new_index, new_time, maybe_end_state, high_time = process_event_for_point_in_time(index, time) + # Diagram for this program point: + # |___potential_result_interval___|| | + # index new_index + # interval_start_time time new_time + # interval_start_state maybe_end_state + # high_time + if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): + add_segment(interval_start_time, time, interval_start_state) + interval_start_time, interval_start_state = high_time, maybe_end_state + index, time = new_index, new_time + return result From 9fecc6744ee6f76373f924db37c22b8d84f97701 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Thu, 6 Mar 2025 14:07:34 +0100 Subject: [PATCH 05/43] feat: allow saving additional interval attributes in result_interval table --- execution_engine/omop/db/celida/tables.py | 4 +++- execution_engine/omop/db/celida/views.py | 1 + execution_engine/task/process/__init__.py | 20 +++++++++++++++++++- execution_engine/task/process/rectangle.py | 20 
+++++++------------- execution_engine/task/task.py | 16 +++++++++++++--- execution_engine/util/types.py | 3 ++- 6 files changed, 45 insertions(+), 19 deletions(-) diff --git a/execution_engine/omop/db/celida/tables.py b/execution_engine/omop/db/celida/tables.py index 9f85b407..b434d038 100644 --- a/execution_engine/omop/db/celida/tables.py +++ b/execution_engine/omop/db/celida/tables.py @@ -169,7 +169,9 @@ class ResultInterval(Base): # noqa: D101 interval_start: Mapped[datetime] interval_end: Mapped[datetime] interval_type = mapped_column(IntervalTypeEnum) - + interval_ratio: Mapped[float] = mapped_column( + nullable=True + ) execution_run: Mapped["ExecutionRun"] = relationship( primaryjoin="ResultInterval.run_id == ExecutionRun.run_id", ) diff --git a/execution_engine/omop/db/celida/views.py b/execution_engine/omop/db/celida/views.py index 19f547da..ffafdfa8 100644 --- a/execution_engine/omop/db/celida/views.py +++ b/execution_engine/omop/db/celida/views.py @@ -199,6 +199,7 @@ def interval_result_view() -> Select: rri.c.interval_type, rri.c.interval_start, rri.c.interval_end, + rri.c.interval_ratio, ) .select_from(rri) .outerjoin(pip, (rri.c.pi_pair_id == pip.c.pi_pair_id)) diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index 5f4bf06e..b1ace933 100644 --- a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -1,3 +1,4 @@ +import copy import importlib import os import sys @@ -44,4 +45,21 @@ def get_processing_module( AnyInterval = Interval | IntervalWithCount | IntervalWithTypeCounts GeneralizedInterval = None | AnyInterval -TInterval = TypeVar('TInterval', bound = AnyInterval) +TInterval = TypeVar("TInterval", bound=AnyInterval) + +def interval_like(interval: TInterval, start: int, end: int) -> TInterval: + """ + Return a copy of the given interval with its lower and upper bounds replaced. + + Args: + interval (I): The interval to copy. 
Must be one of Interval, IntervalWithCount, or IntervalWithTypeCounts. + start (datetime): The new lower bound. + end (datetime): The new upper bound. + + Returns: + I: A copy of the interval with updated lower and upper bounds. + """ + + return copy.copy(interval)._replace( + lower=start, upper=end + ) # type: ignore[return-value] diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index c2565d91..d3077a4c 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -2,7 +2,7 @@ import importlib import logging import os -from typing import Callable, cast +from typing import Callable, cast, List import numpy as np import pendulum @@ -12,7 +12,7 @@ from execution_engine.util.interval import IntervalType, interval_datetime from execution_engine.util.types import TimeRange -from . import Interval, IntervalWithCount +from . import Interval, IntervalWithCount, AnyInterval, GeneralizedInterval, interval_like PROCESS_RECTANGLE_VERSION = os.getenv("PROCESS_RECTANGLE_VERSION", "auto") @@ -566,17 +566,11 @@ def mask_intervals( for person_id, intervals in mask.items() } - result = {} - for person_id in data: - # intersect every interval in data with every interval in mask - person_result = _impl.intersect_interval_lists( - data[person_id], person_mask[person_id] - ) - if not person_result: - continue - - result[person_id] = person_result - + def intersection_interval(start: int, end: int, intervals: List[GeneralizedInterval]) -> GeneralizedInterval: + left_interval, right_interval = intervals + if left_interval is not None and right_interval is not None: + return interval_like(right_interval, start, end) + result = find_rectangles([person_mask, data], intersection_interval) return result diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index a9b62cd7..552e0955 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -578,6 +578,18
@@ def store_result_in_db( run_id=bind_params["run_id"], cohort_category=self.category, ) + def interval_data(interval): + data = dict( + interval_start=interval.lower, + interval_end=interval.upper, + interval_type=interval.type, + ) + if isinstance(interval, Interval): + data["interval_ratio"] = None + else: + assert isinstance(interval, IntervalWithCount) + data["interval_ratio"] = interval.count + return data try: with get_engine().begin() as conn: @@ -586,9 +598,7 @@ def store_result_in_db( [ { "person_id": person_id, - "interval_start": normalized_interval.lower, - "interval_end": normalized_interval.upper, - "interval_type": normalized_interval.type, + **interval_data(normalized_interval), **params, } for person_id, intervals in result.items() diff --git a/execution_engine/util/types.py b/execution_engine/util/types.py index cf099700..c085ea79 100644 --- a/execution_engine/util/types.py +++ b/execution_engine/util/types.py @@ -13,8 +13,9 @@ ) from execution_engine.util.value import ValueNumber, ValueNumeric from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod +from execution_engine.task.process import AnyInterval -PersonIntervals = dict[int, Any] +PersonIntervals = dict[int, AnyInterval] class TimeRange(BaseModel): From e71ea223907f9f33c077e9b55e6a423d986b9e22 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 11 Mar 2025 19:47:49 +0100 Subject: [PATCH 06/43] feat: maybe unnecessary change in mask_intervals --- execution_engine/task/process/rectangle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index d3077a4c..92cbacda 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -563,7 +563,7 @@ def mask_intervals( ) for interval in intervals ] - for person_id, intervals in mask.items() + for person_id, intervals in mask.items() if person_id in data } def 
intersection_interval(start: int, end: int, intervals: List[GeneralizedInterval]) -> GeneralizedInterval: From 593e624c74c9a7d0da1884a3dce067bd3f21ede6 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Thu, 13 Mar 2025 14:28:56 +0100 Subject: [PATCH 07/43] feat: in CappedMinCount, produce intervals with "ratio" information Also rewrite CappedMinCount to use find_rectangles. Co-authored-by: Gregor Lichtner --- execution_engine/task/process/__init__.py | 5 +- execution_engine/task/process/rectangle.py | 9 - .../task/process/rectangle_cython.pyx | 59 +---- .../task/process/rectangle_python.py | 59 +---- execution_engine/task/task.py | 51 ++-- .../combination/test_temporal_combination.py | 240 ++++++++++++++++++ 6 files changed, 273 insertions(+), 150 deletions(-) diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index b1ace933..9d997c41 100644 --- a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -40,9 +40,8 @@ def get_processing_module( Interval = namedtuple("Interval", ["lower", "upper", "type"]) IntervalWithCount = namedtuple("IntervalWithCount", ["lower", "upper", "type", "count"]) -IntervalWithTypeCounts = namedtuple("IntervalWithTypeCounts", ["lower", "upper", "counts"]) -AnyInterval = Interval | IntervalWithCount | IntervalWithTypeCounts +AnyInterval = Interval | IntervalWithCount GeneralizedInterval = None | AnyInterval TInterval = TypeVar("TInterval", bound=AnyInterval) @@ -52,7 +51,7 @@ def interval_like(interval: TInterval, start: int, end: int) -> TInterval: Return a copy of the given interval with its lower and upper bounds replaced. Args: - interval (I): The interval to copy. Must be one of Interval, IntervalWithCount, or IntervalWithTypeCounts. + interval (I): The interval to copy. Must be one of Interval or IntervalWithCount. start (datetime): The new lower bound. end (datetime): The new upper bound. 
diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 92cbacda..4b1da00e 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -792,15 +792,6 @@ def find_overlapping_personal_windows( return result -def find_rectangles_with_count(data: list[PersonIntervals]) -> PersonIntervals: - # TODO(jmoringe): can this use _process_interval? - if len(data) == 0: - return {} - else: - keys = data[0].keys() - return {key: _impl.find_rectangles_with_count([ intervals[key] for intervals in data ]) - for key in keys} - def find_rectangles(data: list[PersonIntervals], interval_constructor: Callable) -> PersonIntervals: # TODO(jmoringe): can this use _process_interval? if len(data) == 0: diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index f85b645e..ed7dd47a 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -1,7 +1,5 @@ -import copy -import datetime import typing -from functools import reduce, cmp_to_key +from functools import cmp_to_key from typing import Callable cimport numpy as np @@ -9,7 +7,7 @@ cimport numpy as np import numpy as np from sortedcontainers import SortedDict -from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts, AnyInterval +from execution_engine.task.process import Interval, IntervalWithCount, AnyInterval from execution_engine.util.interval import IntervalType DEF SCHAR_MIN = -128 @@ -433,59 +431,6 @@ def find_overlapping_windows( # Return the list of unique intersecting windows return intersecting_windows -def find_rectangles_with_count(all_intervals: list[list[Interval]]) -> list[IntervalWithTypeCounts]: - """ - For multiple parallel "tracks" of intervals, identify temporal - intervals in which no change occurs on any "track". 
For each such - interval, report the number of active intervals grouped by - interval type across all "tracks". When there is no interval on a - track for a given temporal interval, act as if a negative interval - was present there. - - :param all_intervals: A list of intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. - """ - # Convert all intervals into a list of events sorted by - # time. Multiple events at the same point in time are not a - # problem here: since we simply count the number of "active" - # intervals the result does not depend on the order in which we - # process the events. - track_count = len(all_intervals) - events = reduce(lambda acc, intervals: acc + intervals_to_events(intervals, closing_offset=0), - all_intervals, - []) - events.sort(key=lambda i: i[0]) - - # The result will be a list of intervals - result = [] - def add_segment(start, end, type_counts): - # We consider the period between 23:59:59 of a day and - # 00:00:00 of the following day to be empty. - if not (start == end - 1 and datetime.datetime.fromtimestamp(end).time() == datetime.time(0, 0, 0)): - # Assume implicit negative intervals: increase the count - # for the negative type as needed so that the overall - # count is equal to track_count. - missing = track_count - sum(type_counts.values()) - if missing > 0: - type_counts[IntervalType.NEGATIVE] = type_counts.get(IntervalType.NEGATIVE, 0) + missing - result.append(IntervalWithTypeCounts(start, end, type_counts)) - - # Step through events and emit result intervals whenever the - # counts change. 
- counts = dict() - previous_time = events[0][0] - for (time, open_, interval) in events: - if previous_time is None: - previous_time = time - elif not previous_time == time: - add_segment(previous_time, time, copy.copy(counts)) - previous_time = time - - interval_type = interval.type - old_count = counts.get(interval_type, 0) - counts[interval_type] = old_count + (1 if open_ else -1) - - return result IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index d6dc1ae4..0a2fa187 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -1,13 +1,11 @@ -import copy -import datetime import typing -from functools import reduce, cmp_to_key +from functools import cmp_to_key from typing import Callable import numpy as np from sortedcontainers import SortedDict, SortedList -from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts, AnyInterval +from execution_engine.task.process import Interval, IntervalWithCount, AnyInterval from execution_engine.util.interval import IntervalType MODULE_IMPLEMENTATION = "python" @@ -427,59 +425,6 @@ def interleaved_events(): # Return the list of unique intersecting windows return intersecting_windows -def find_rectangles_with_count(all_intervals: list[list[Interval]]) -> list[IntervalWithTypeCounts]: - """ - For multiple parallel "tracks" of intervals, identify temporal - intervals in which no change occurs on any "track". For each such - interval, report the number of active intervals grouped by - interval type across all "tracks". When there is no interval on a - track for a given temporal interval, act as if a negative interval - was present there. - - :param all_intervals: A list of intervals that are checked for overlap with the windows. 
- :return: A list of windows that have any overlap with the intervals. - """ - # Convert all intervals into a list of events sorted by - # time. Multiple events at the same point in time are not a - # problem here: since we simply count the number of "active" - # intervals the result does not depend on the order in which we - # process the events. - track_count = len(all_intervals) - events = reduce(lambda acc, intervals: acc + intervals_to_events(intervals, closing_offset=0), - all_intervals, - []) - events.sort(key=lambda i: i[0]) - - # The result will be a list of intervals - result = [] - def add_segment(start, end, type_counts): - # We consider the period between 23:59:59 of a day and - # 00:00:00 of the following day to be empty. - if not (start == end - 1 and datetime.datetime.fromtimestamp(end).time() == datetime.time(0, 0, 0)): - # Assume implicit negative intervals: increase the count - # for the negative type as needed so that the overall - # count is equal to track_count. - missing = track_count - sum(type_counts.values()) - if missing > 0: - type_counts[IntervalType.NEGATIVE] = type_counts.get(IntervalType.NEGATIVE, 0) + missing - result.append(IntervalWithTypeCounts(start, end, type_counts)) - - # Step through events and emit result intervals whenever the - # counts change. 
- counts = dict() - previous_time = events[0][0] - for (time, open_, interval) in events: - if previous_time is None: - previous_time = time - elif not previous_time == time: - add_segment(previous_time, time, copy.copy(counts)) - previous_time = time - - interval_type = interval.type - old_count = counts.get(interval_type, 0) - counts[interval_type] = old_count + (1 if open_ else -1) - - return result IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 552e0955..601f65c4 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -3,6 +3,7 @@ import json import logging from enum import Enum, auto +from typing import List from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError @@ -13,7 +14,14 @@ from execution_engine.omop.db.celida.tables import ResultInterval from execution_engine.omop.sqlclient import OMOPSQLClient from execution_engine.settings import get_config -from execution_engine.task.process import Interval, get_processing_module +from execution_engine.task.process import ( + Interval, + IntervalWithCount, + AnyInterval, + GeneralizedInterval, + get_processing_module, + interval_like, +) from execution_engine.util.interval import IntervalType from execution_engine.util.types import PersonIntervals, TimeRange @@ -294,31 +302,26 @@ def handle_binary_logical_operator( max_count=self.expr.count_max, ) elif isinstance(self.expr, logic.CappedCount): - intervals_with_count = process.find_rectangles_with_count(data) - result = dict() - for key, intervals in intervals_with_count.items(): - - key_result = [] - + def interval_counts(start: int, end: int, intervals: List[AnyInterval]) -> GeneralizedInterval: + positive_count = 0 + not_applicable_count = 0 for interval in intervals: - counts = interval.counts - not_applicable_count = counts.get(IntervalType.NOT_APPLICABLE, 0) + if interval is None or 
interval.type == IntervalType.NEGATIVE: + pass + elif interval.type == IntervalType.POSITIVE: + positive_count += 1 + elif interval.type == IntervalType.NOT_APPLICABLE: + not_applicable_count += 1 + # we require at least one positive interval to be present in any case (hence the max(1, ...)) + effective_count_min = max(1, self.expr.count_min - not_applicable_count) + if positive_count >= effective_count_min: + effective_type = IntervalType.POSITIVE + else: + effective_type = IntervalType.NEGATIVE + ratio = positive_count / effective_count_min + return IntervalWithCount(start, end, effective_type, ratio) + return process.find_rectangles(data, interval_counts) - # we require at least one positive interval to be present in any case (hence the max(1, ...)) - effective_count_min = max( - 1, self.expr.count_min - not_applicable_count - ) - positive_count = counts.get(IntervalType.POSITIVE, 0) - effective_type = ( - IntervalType.POSITIVE - if positive_count >= effective_count_min - else IntervalType.NEGATIVE - ) - key_result.append( - Interval(interval.lower, interval.upper, effective_type) - ) - result[key] = key_result - return result elif isinstance(self.expr, logic.AllOrNone): raise NotImplementedError("AllOrNone is not implemented yet.") else: diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 7501c6fa..280fb64a 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -1902,3 +1902,243 @@ def test_at_least_combination_on_database_no_measurements( ) ) assert result_tuples == expected[person.person_id] + + +class TestIntervalRatio(TestCriterionCombinationDatabase): + """This test ensures that counting criteria with minimum count + thresholds adapt to the temporal interval of the population + criterion. 
+ + As a concrete test case, this class applies an intervention + criterion that requires a procedure to be performed in at least + two of three shifts for each day. However, if the hospital stay + ends during a given day, shifts on that day but outside the + hospital stay should not count towards the threshold of the + criterion. + """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + name="observation", start="2025-02-18 13:55:00Z", end="2025-02-23 11:00:00Z" + ) + + def patient_events(self, db_session, visit_occurrence): + c1 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), + condition_end_datetime=pendulum.parse("2025-02-23 02:00:00+01:00"), + ) + # One screen on the 19th + e1_night = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-19 23:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-19 23:01:00+01:00"), + ) + # Two screen on the 20th + e2_morn = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-20 08:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-20 08:01:00+01:00"), + ) + e2_late = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-20 15:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-20 15:01:00+01:00"), + ) + # Three screenings on the 21st + e3_morn = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-21 08:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-21 08:01:00+01:00"), + ) + e3_late = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + 
start_datetime=pendulum.parse("2025-02-21 15:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-21 15:01:00+01:00"), + ) + e3_night = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-21 23:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-21 23:01:00+01:00"), + ) + # Four screenings on the 22st + e4_night_pre = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-22 01:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-22 01:01:00+01:00"), + ) + e4_morn = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-22 08:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-22 08:01:00+01:00"), + ) + e4_late = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-22 15:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-22 15:01:00+01:00"), + ) + e4_night = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-22 23:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-22 23:01:00+01:00"), + ) + db_session.add_all( + [ + c1, + e1_night, + e2_morn, + e2_late, + e3_morn, + e3_late, + e3_night, + e4_night_pre, + e4_morn, + e4_late, + e4_night, + ] + ) + db_session.commit() + + @pytest.mark.parametrize( + "population,intervention,expected", + [ + ( + LogicalCriterionCombination.And(c2), # population + LogicalCriterionCombination.CappedAtLeast( + *[ + FixedWindowTemporalIndicatorCombination.Day( + criterion=shift_class(criterion=delir_screening), + ) + for shift_class in [ + FixedWindowTemporalIndicatorCombination.NightShiftAfterMidnight, + FixedWindowTemporalIndicatorCombination.MorningShift, + 
FixedWindowTemporalIndicatorCombination.AfternoonShift, + FixedWindowTemporalIndicatorCombination.NightShiftBeforeMidnight, + ] + ], + threshold=4, + ), + { + 1: [ + # The criterion should be fulfilled on the day + # before the discharge and on the day of the + # discharge even though the actual number of + # screenings on the latter day is just 1. + ( + IntervalType.NOT_APPLICABLE, + "nan", # workaround, is actually really a nan value + pendulum.parse("2025-02-18 16:55:00Z"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 1 / 3, # one this day, only 3 shifts are possible, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-19 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0.5, + pendulum.parse("2025-02-20 00:00:00+01:00"), + pendulum.parse("2025-02-20 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0.75, + pendulum.parse("2025-02-21 00:00:00+01:00"), + pendulum.parse("2025-02-21 23:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1.0, + pendulum.parse("2025-02-22 00:00:00+01:00"), + pendulum.parse("2025-02-22 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0, + pendulum.parse("2025-02-23 00:00:00+01:00"), + pendulum.parse("2025-02-23 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", # workaround, is actually really a nan value + pendulum.parse("2025-02-23 02:00:01+01:00"), + pendulum.parse("2025-02-23 04:30:00Z"), + ), + ] + }, + ), + ], + ) + def test_interval_ratio_on_database( + self, + person, + db_session, + population, + intervention, + base_criterion, + expected, + observation_window, + criteria, + ): + persons = [person[0]] # only one person + vos = [ + create_visit( + person_id=person.person_id, + visit_start_datetime=observation_window.start + + datetime.timedelta(hours=3), + visit_end_datetime=observation_window.end + - datetime.timedelta(hours=6.5), + visit_concept_id=concepts.INTENSIVE_CARE, + ) + for person in persons + ] + + self.patient_events(db_session, 
vos[0]) + + db_session.add_all(vos) + db_session.commit() + + self.insert_criterion_combination( + db_session, population, intervention, base_criterion, observation_window + ) + + df = self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + for person in persons: + result = df.query(f"person_id=={person.person_id}") + result_tuples = list( + result[ + [ + "interval_type", + "interval_ratio", + "interval_start", + "interval_end", + ] + ] + .fillna("nan") + .itertuples(index=False, name=None) + ) + + for result_tuple, expected_tuple in zip( + result_tuples, expected[person.person_id] + ): + assert result_tuple == expected_tuple From d811f875304aad051a213cf0c9993724b2055525 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 12 Mar 2025 18:33:17 +0100 Subject: [PATCH 08/43] feat: in NoDataPreservingAnd, preserve type and attributes of data intervals Also rewrite NoDataPreservingAnd to use find_rectangles --- execution_engine/task/task.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 601f65c4..cd0a9dba 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -354,8 +354,19 @@ def handle_no_data_preserving_operator( ), "Dependency is not a NoDataPreservingAnd / NoDataPreservingOr expression." 
if isinstance(self.expr, logic.NoDataPreservingAnd): - result = process.intersect_intervals(data) - elif isinstance(self.expr, logic.NoDataPreservingOr): + def intersection_interval(start: int, end: int, intervals: List[GeneralizedInterval]) -> GeneralizedInterval: + with IntervalType.intersection_order(): + min_interval = min(intervals, key = lambda i: IntervalType.NEGATIVE if i is None else i.type) + if min_interval is not None: + return interval_like(min_interval, start, end) + else: + # Explicit representation of negative intervals is + # required here because the database views do not + # understand the implicit representation. + return Interval(start, end, IntervalType.NEGATIVE) + result = process.find_rectangles(data, intersection_interval) + else: + assert isinstance(self.expr, logic.NoDataPreservingOr) result = process.union_intervals(data) # todo: the only difference between this function and handle_binary_logical_operator is the following lines @@ -367,7 +378,17 @@ def handle_no_data_preserving_operator( interval_type=IntervalType.NEGATIVE, ) - result = process.concat_intervals([result, result_negative]) + def union_interval(start: int, end: int, intervals: List[GeneralizedInterval]) -> GeneralizedInterval: + with IntervalType.union_order(): + max_interval = max(intervals, key=lambda i: IntervalType.NEGATIVE if i is None else i.type) + if max_interval is not None: + return interval_like(max_interval, start, end) + else: + # Explicit representation of negative intervals is + # required here because the database views do not + # understand the implicit representation. 
+ return Interval(start, end, IntervalType.NEGATIVE) + result = process.find_rectangles([result, result_negative], union_interval) return result From b659b090ddbbd8943891d535e255d06caff4cdf7 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 12 Mar 2025 18:35:17 +0100 Subject: [PATCH 09/43] feat: in LeftDependentToggle, preserve type and attributes of data intervals Also rewrite LeftDependentToggle to use find_rectangles --- execution_engine/task/task.py | 50 +++++++++++++++-------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index cd0a9dba..f605d519 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -456,36 +456,28 @@ def handle_left_dependent_toggle( # data[0] is the left dependency (i.e. P) # data[1] is the right dependency (i.e. I) - - data_p = process.select_type(left, IntervalType.POSITIVE) - + # observation_window_intervals extends the result to the + # correct temporal range; Its type is not important. 
+ observation_window_intervals = {key: [Interval(observation_window.start.timestamp(), + observation_window.end.timestamp(), + IntervalType.POSITIVE)] + for key in left.keys()} if isinstance(self.expr, logic.LeftDependentToggle): - interval_type = IntervalType.NOT_APPLICABLE - elif isinstance(self.expr, logic.ConditionalFilter): - interval_type = IntervalType.NEGATIVE - - result_not_p = process.complementary_intervals( - data_p, - reference=base_data, - observation_window=observation_window, - interval_type=interval_type, - ) - - result_p_and_i = process.intersect_intervals([data_p, right]) - - result = process.concat_intervals([result_not_p, result_p_and_i]) - - # fill remaining time with NEGATIVE - result_no_data = process.complementary_intervals( - result, - reference=base_data, - observation_window=observation_window, - interval_type=IntervalType.NEGATIVE, - ) - - result = process.concat_intervals([result, result_no_data]) - - return result + fill_type = IntervalType.NOT_APPLICABLE + else: + assert isinstance(self.expr, logic.ConditionalFilter) + fill_type = IntervalType.NEGATIVE + + def new_interval(start: int, end: int, intervals: List[GeneralizedInterval]) -> GeneralizedInterval: + left_interval, right_interval, observation_window_ = intervals + if (left_interval is None) or not left_interval.type == IntervalType.POSITIVE : + # no left_interval or not positive -> use fill type + return Interval(start, end, fill_type) + elif right_interval is not None: + return interval_like(right_interval, start, end) + else: # left_interval but not right_interval -> implicit negative + return None + return process.find_rectangles([left, right, observation_window_intervals], new_interval) def handle_temporal_operator( self, data: list[PersonIntervals], observation_window: TimeRange From 91e37c52a2dd4d044b814f37b91305575720644d Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Thu, 13 Mar 2025 14:15:00 +0100 Subject: [PATCH 10/43] fix: change behavior of TemporalCount for 
partially applicable indicator windows Before this change, an indicator window which was partially within the population interval and did not satisfy the temporal count condition produced a "not applicable" result. With this change, such indicator windows produce a negative result. Also implement TemporalCount via find_rectangles and remove find_overlapping_windows. --- execution_engine/task/process/rectangle.py | 17 -- .../task/process/rectangle_cython.pyx | 105 -------- .../task/process/rectangle_python.py | 105 -------- execution_engine/task/task.py | 44 +++- .../task/process/test_rectangle.py | 232 ------------------ 5 files changed, 43 insertions(+), 460 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 4b1da00e..ffc1d655 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -741,23 +741,6 @@ def add_interval(interval_start, interval_end, interval_type): return intervals -def find_overlapping_windows( - windows: list[Interval], data: PersonIntervals -) -> PersonIntervals: - """ - Returns a list of windows that overlap with any interval in the intervals list. A window is included in the - result if it overlaps in any part with any of the given intervals, not just where they intersect. The entire - window is returned, not just the overlapping segment. - - Note that a single, common list of windows is used for all persons. - - :param windows: A list of windows, where each window is defined as an interval. - :param data: The dict with intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. 
- """ - return {key: _impl.find_overlapping_windows(windows, data[key]) for key in data} - - def find_overlapping_personal_windows( windows: PersonIntervals, data: PersonIntervals ) -> PersonIntervals: diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index ed7dd47a..4e497449 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -327,111 +327,6 @@ def union_interval_lists( return union_rects(left + right) -def find_overlapping_windows( - windows: list[Interval], intervals: list[Interval] -) -> list[Interval]: - """ - Returns a list of windows that overlap with any interval in the intervals list. A window is included in the - result if it overlaps in any part with any of the given intervals, not just where they intersect. The entire - window is returned, not just the overlapping segment. - - :param windows: A list of windows, where each window is defined as an interval. - :param intervals: A list of intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. - """ - # Convert all intervals and windows into events - window_events = intervals_to_events(windows, closing_offset=0) - interval_events = intervals_to_events(intervals, closing_offset=0) - - # Here we collect interval for the intersecting windows - intersecting_windows = [] - def add_segment(start, end, interval_type): - intersecting_windows.append(Interval(start, end, interval_type)) - - # State and "event handler" functions for state transitions: - # inside/not inside window, inside/not inside at least one - # interval. 
- previous_event = None - window_state = False - any_satisfied_in_window = False - satisfied_interval_type = None - def window_open(event_time, interval_type): - nonlocal previous_event, window_state, any_satisfied_in_window - assert not window_state - window_state = interval_type - if satisfied_interval_type is not None: - any_satisfied_in_window = satisfied_interval_type - previous_event = event_time - def window_close(event_time): - nonlocal previous_event, window_state, any_satisfied_in_window - assert window_state - if window_state == IntervalType.NOT_APPLICABLE: - interval_type = IntervalType.NOT_APPLICABLE - add_segment(previous_event, event_time, interval_type) - elif any_satisfied_in_window == False: - pass - else: - interval_type = any_satisfied_in_window - add_segment(previous_event, event_time, interval_type) - window_state = False - any_satisfied_in_window = False - previous_event = event_time - def interval_satisfied(event_time, interval_type): - nonlocal satisfied_interval_type, any_satisfied_in_window - # If we are inside a window, remember that we saw an interval - # of type interval_type. Do not overwrite previously seen - # higher priority types with lower priority types - if not (satisfied_interval_type in [IntervalType.POSITIVE, IntervalType.NEGATIVE]): - satisfied_interval_type = interval_type - if window_state: - # Priorities: POSITIVE > NEGATIVE > NO_DATA or NOT_APPLICABLE > no value - if any_satisfied_in_window == IntervalType.POSITIVE: - pass - elif any_satisfied_in_window == IntervalType.NEGATIVE: - if interval_type == IntervalType.POSITIVE: - any_satisfied_in_window = interval_type - else: - any_satisfied_in_window = interval_type - def interval_unsatisfied(event_time): - nonlocal satisfied_interval_type - satisfied_interval_type = None - - # Use two indices to traverse the two sorted lists of events in - # parallel. Call event handler functions for state transitions. 
- def interleaved_events(): - w_idx, i_idx = 0, 0 - while True: - window_event = window_events[w_idx] if w_idx < len(window_events) else None - interval_event = interval_events[i_idx] if i_idx < len(interval_events) else None - # When tied in terms of event time, use the following - # priority for reporting events: - # window open > interval open > interval close > window close - if window_event and (not interval_event or window_event[0] < interval_event[0] or ( - window_event[0] == interval_event[0] and window_event[1])): - w_idx += 1 - yield window_event, None - elif interval_event: - i_idx += 1 - yield None, interval_event - else: - break - active_intervals = 0 - for window_event, interval_event in interleaved_events(): - if window_event: - time, open_, interval = window_event - window_open(time, interval.type) if open_ else window_close(time) - else: - time, open_, interval = interval_event - active_intervals += (1 if open_ else -1) - if active_intervals == 0: # 1 -> 0 transition - interval_unsatisfied(time) - elif active_intervals == 1: # 0 -> 1 transition - interval_satisfied(time, interval.type) - - # Return the list of unique intersecting windows - return intersecting_windows - - IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] def find_rectangles(all_intervals: list[list[AnyInterval]], diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 0a2fa187..b4d39af6 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -321,111 +321,6 @@ def union_interval_lists(left: list[Interval], right: list[Interval]) -> list[In return union_rects(left + right) -def find_overlapping_windows( - windows: list[Interval], intervals: list[Interval] -) -> list[Interval]: - """ - Returns a list of windows that overlap with any interval in the intervals list. 
A window is included in the - result if it overlaps in any part with any of the given intervals, not just where they intersect. The entire - window is returned, not just the overlapping segment. - - :param windows: A list of windows, where each window is defined as an interval. - :param intervals: A list of intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. - """ - # Convert all intervals and windows into events - window_events = intervals_to_events(windows, closing_offset=0) - interval_events = intervals_to_events(intervals, closing_offset=0) - - # Here we collect interval for the intersecting windows - intersecting_windows = [] - def add_segment(start, end, interval_type): - intersecting_windows.append(Interval(start, end, interval_type)) - - # State and "event handler" functions for state transitions: - # inside/not inside window, inside/not inside at least one - # interval. - previous_event = None - window_state = False - any_satisfied_in_window = False - satisfied_interval_type = None - def window_open(event_time, interval_type): - nonlocal previous_event, window_state, any_satisfied_in_window - assert not window_state - window_state = interval_type - if satisfied_interval_type is not None: - any_satisfied_in_window = satisfied_interval_type - previous_event = event_time - def window_close(event_time): - nonlocal previous_event, window_state, any_satisfied_in_window - assert window_state - if window_state == IntervalType.NOT_APPLICABLE: - interval_type = IntervalType.NOT_APPLICABLE - add_segment(previous_event, event_time, interval_type) - elif any_satisfied_in_window == False: - pass - else: - interval_type = any_satisfied_in_window - add_segment(previous_event, event_time, interval_type) - window_state = False - any_satisfied_in_window = False - previous_event = event_time - def interval_satisfied(event_time, interval_type): - nonlocal satisfied_interval_type, 
any_satisfied_in_window - # If we are inside a window, remember that we saw an interval - # of type interval_type. Do not overwrite previously seen - # higher priority types with lower priority types - if not (satisfied_interval_type in [IntervalType.POSITIVE, IntervalType.NEGATIVE]): - satisfied_interval_type = interval_type - if window_state: - # Priorities: POSITIVE > NEGATIVE > NO_DATA or NOT_APPLICABLE > no value - if any_satisfied_in_window == IntervalType.POSITIVE: - pass - elif any_satisfied_in_window == IntervalType.NEGATIVE: - if interval_type == IntervalType.POSITIVE: - any_satisfied_in_window = interval_type - else: - any_satisfied_in_window = interval_type - def interval_unsatisfied(event_time): - nonlocal satisfied_interval_type - satisfied_interval_type = None - - # Use two indices to traverse the two sorted lists of events in - # parallel. Call event handler functions for state transitions. - def interleaved_events(): - w_idx, i_idx = 0, 0 - while True: - window_event = window_events[w_idx] if w_idx < len(window_events) else None - interval_event = interval_events[i_idx] if i_idx < len(interval_events) else None - # When tied in terms of event time, use the following - # priority for reporting events: - # window open > interval open > interval close > window close - if window_event and (not interval_event or window_event[0] < interval_event[0] or ( - window_event[0] == interval_event[0] and window_event[1])): - w_idx += 1 - yield window_event, None - elif interval_event: - i_idx += 1 - yield None, interval_event - else: - break - active_intervals = 0 - for window_event, interval_event in interleaved_events(): - if window_event: - time, open_, interval = window_event - window_open(time, interval.type) if open_ else window_close(time) - else: - time, open_, interval = interval_event - active_intervals += (1 if open_ else -1) - if active_intervals == 0: # 1 -> 0 transition - interval_unsatisfied(time) - elif active_intervals == 1: # 0 -> 1 transition - 
interval_satisfied(time, interval.type) - - # Return the list of unique intersecting windows - return intersecting_windows - - IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] def find_rectangles(all_intervals: list[list[AnyInterval]], diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index f605d519..3339809f 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -551,7 +551,49 @@ def get_start_end_from_interval_type( timezone=get_config().timezone, ) - result = process.find_overlapping_windows(indicator_windows, data_p) + # Create a "temporary window interval" for each window + # interval. Associate with each temporary window interval + # all data intervals that overlap it. The association + # works by assigning a unique id to each temporary window + # interval. + ids = dict() # window_interval -> unique id + infos = dict() # unique id -> list of overlapping data intervals + def temporary_window_interval(start: int, end: int, intervals: List[AnyInterval]): + window_interval, data_interval = intervals + if window_interval is None or window_interval.type == IntervalType.NOT_APPLICABLE: + return Interval(start, end, IntervalType.NOT_APPLICABLE) + else: + window_id = ids.get(window_interval, len(ids)) + ids[window_interval] = window_id + info = infos.get(window_id, set()) + infos[window_id] = info + data_interval_type = data_interval.type if data_interval is not None else IntervalType.NEGATIVE + info.add(data_interval_type) + return IntervalWithCount(start, end, IntervalType.POSITIVE, window_id) + person_indicator_windows = { key: indicator_windows for key in data_p.keys() } + result = process.find_rectangles([ person_indicator_windows, data_p], temporary_window_interval) + # Turn the temporary window intervals into the final + # intervals by computing the interval types based on the + # respective overlapping data intervals. 
+ def finalize_interval(interval): + if isinstance(interval, IntervalWithCount): + window_id = interval.count + data_intervals = infos[window_id] + # TODO(jmoringe): there should be a way to implement this with max(data_intervals) + if IntervalType.POSITIVE in data_intervals: + interval_type = IntervalType.POSITIVE + elif IntervalType.NEGATIVE in data_intervals: + interval_type = IntervalType.NEGATIVE + elif IntervalType.NOT_APPLICABLE in data_intervals: + interval_type = IntervalType.NOT_APPLICABLE + else: + assert IntervalType.NO_DATA in data_intervals + interval_type = IntervalType.NO_DATA + return Interval(interval.lower, interval.upper, interval_type) + else: + return interval + result = { key: [ finalize_interval(i) for i in intervals ] + for key, intervals in result.items() } return result diff --git a/tests/execution_engine/task/process/test_rectangle.py b/tests/execution_engine/task/process/test_rectangle.py index 97fd2c32..6e8526d4 100644 --- a/tests/execution_engine/task/process/test_rectangle.py +++ b/tests/execution_engine/task/process/test_rectangle.py @@ -3287,235 +3287,3 @@ def test_spanning_midnight_naive(self, timezone): intervals[1].upper == pendulum.parse("2023-07-03 04:00:00", tz=timezone).timestamp() ) - - -class TestFindOverlappingWindows(ProcessTest): - def test_no_intersection(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(1, 5, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [] - - def test_full_overlap(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(10, 20, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_partial_overlap_start(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(5, 15, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def 
test_partial_overlap_end(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(15, 25, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_multiple_overlapping_intervals(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(5, 15, T.POSITIVE), Interval(15, 25, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_interval_exactly_on_window_boundaries(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(20, 30, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_point_interval(self): - windows = [Interval(10, 20, T.POSITIVE)] - - intervals = {1: [Interval(9, 9, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [] - - intervals = {1: [Interval(10, 10, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(10, 10, T.POSITIVE)] * 10} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(11, 11, T.POSITIVE)] * 10} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20, 20, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20, 20, T.POSITIVE)] * 10} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(21, 21, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [] - - def test_multiple_opening_closing(self): - windows = 
[Interval(10, 20, T.POSITIVE), Interval(25, 35, T.POSITIVE)] - - intervals = {1: [Interval(10, 10 + i, T.POSITIVE) for i in range(5)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20, 20 + i, T.POSITIVE) for i in range(4)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(10 - i, 10, T.POSITIVE) for i in range(5)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20 - i, 20, T.POSITIVE) for i in range(5)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = { - 1: [Interval(20 - i, 20, T.POSITIVE) for i in range(5)] - + [Interval(21, 24, T.POSITIVE)] - } - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = { - 1: [Interval(20 - i, 20 + i, T.POSITIVE) for i in range(5)] - + [Interval(21, 25, T.POSITIVE)] - } - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE), - Interval(25, 35, T.POSITIVE), - ] - - @pytest.mark.parametrize( - "windows,intervals,expected", - [ - ([], [], []), - ([], [(10, 20)], []), - ([(10, 20)], [], []), - ([(10, 20)], [(5, 10)], [(10, 20)]), - ([(10, 20)], [(5, 15)], [(10, 20)]), - ([(10, 20)], [(5, 25)], [(10, 20)]), - ([(10, 20)], [(15, 25)], [(10, 20)]), - ([(10, 20)], [(15, 25), (5, 10)], [(10, 20)]), - ([(10, 20)], [(15, 25), (5, 10), (20, 30)], [(10, 20)]), - ([(10, 20), (30, 40)], [(15, 25), (5, 10), (20, 30)], [(10, 20), (30, 40)]), - ( - [(10, 20), (30, 40)], - [(15, 25), (5, 10), (20, 30), (35, 45)], - [(10, 20), (30, 40)], - ), - ([(10, 20), (21, 30)], [(5, 30)], [(10, 20), (21, 30)]), - ([(10, 20), (21, 30), (31, 40), (41, 50)], [(5, 30)], [(10, 20), (21, 30)]), - 
], - ) - def test_intervals(self, windows, intervals, expected): - - windows = [Interval(w[0], w[1], T.POSITIVE) for w in windows] - intervals = {1: [Interval(w[0], w[1], T.POSITIVE) for w in intervals]} - expected = [Interval(w[0], w[1], T.POSITIVE) for w in expected] - - assert self.process.find_overlapping_windows(windows, intervals)[1] == expected - - @pytest.mark.parametrize("seed", [None, 42, 99, 123]) - def test_multiple_windows(self, seed): - - windows = [Interval(i + 5, i + 8, "positive") for i in range(0, 150, 10)] - intervals = [ - Interval(2, 20, "positive"), - Interval(26, 30, "positive"), - Interval(32, 37, "positive"), - Interval(46, 47, "positive"), - Interval(65, 69, "positive"), - Interval(75, 80, "positive"), - Interval(85, 86, "positive"), - Interval(95, 95, "positive"), - Interval(106, 106, "positive"), - Interval(116, 118, "positive"), - Interval(122, 128, "positive"), - Interval(138, 138, "positive"), - ] - - expected = [ - Interval(lower=5, upper=8, type="positive"), - Interval(lower=15, upper=18, type="positive"), - Interval(lower=25, upper=28, type="positive"), - Interval(lower=35, upper=38, type="positive"), - Interval(lower=45, upper=48, type="positive"), - Interval(lower=65, upper=68, type="positive"), - Interval(lower=75, upper=78, type="positive"), - Interval(lower=85, upper=88, type="positive"), - Interval(lower=95, upper=98, type="positive"), - Interval(lower=105, upper=108, type="positive"), - Interval(lower=115, upper=118, type="positive"), - Interval(lower=125, upper=128, type="positive"), - Interval(lower=135, upper=138, type="positive"), - ] - - # Shuffle the windows and intervals list - if seed is not None: - random.seed(seed) # Seed the random number generator for reproducibility - random.shuffle(windows) - random.shuffle(intervals) - - assert ( - self.process.find_overlapping_windows(windows, {1: intervals})[1] - == expected - ) - - @pytest.mark.parametrize("seed", [None, 42, 99, 123]) - def 
test_multiple_windows_with_multiple_patients(self, seed): - # Define windows - windows = [Interval(i + 5, i + 8, T.POSITIVE) for i in range(0, 150, 10)] - - # Define intervals for multiple patients - patient_intervals = { - 1: [Interval(2, 20, T.POSITIVE), Interval(46, 47, T.POSITIVE)], - 2: [Interval(75, 80, T.POSITIVE), Interval(85, 86, T.POSITIVE)], - 3: [Interval(95, 95, T.POSITIVE), Interval(106, 106, T.POSITIVE)], - } - - # Expected results for each patient - expected = { - 1: [ - Interval(5, 8, T.POSITIVE), - Interval(15, 18, T.POSITIVE), - Interval(45, 48, T.POSITIVE), - ], - 2: [Interval(75, 78, T.POSITIVE), Interval(85, 88, T.POSITIVE)], - 3: [Interval(95, 98, T.POSITIVE), Interval(105, 108, T.POSITIVE)], - } - - # Optionally shuffle the windows and intervals list - if seed is not None: - random.seed(seed) - random.shuffle(windows) - for intervals in patient_intervals.values(): - random.shuffle(intervals) - - # Run the test - result = self.process.find_overlapping_windows(windows, patient_intervals) - - # Sort results for comparison - for pid in expected: - result[pid].sort(key=lambda x: x.lower) - expected[pid].sort(key=lambda x: x.lower) - - assert result == expected From 4b4ae70c8077e033e9c9a0dc3800037d40f8e504 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 18 Mar 2025 14:16:22 +0100 Subject: [PATCH 11/43] refactor: avoid intervals_to_events in find_rectangles This avoids sorting twice and some other work --- execution_engine/task/process/rectangle_cython.pyx | 3 ++- execution_engine/task/process/rectangle_python.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 4e497449..71fad0db 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -367,7 +367,8 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], events = [ (time, event, interval, 
j) for j, intervals in enumerate(all_intervals) - for (time, event, interval) in intervals_to_events(intervals, closing_offset=0) + for interval in intervals # intervals_to_events(intervals, closing_offset=0) + for (time,event) in [(interval.lower, True), (interval.upper, False)] ] def compare_events(event1, event2): if event1[0] < event2[0]: # event1 is earlier diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index b4d39af6..082b181c 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -361,7 +361,8 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], events = [ (time, event, interval, j) for j, intervals in enumerate(all_intervals) - for (time, event, interval) in intervals_to_events(intervals, closing_offset=0) + for interval in intervals # intervals_to_events(intervals, closing_offset=0) + for (time,event) in [(interval.lower, True), (interval.upper, False)] ] def compare_events(event1, event2): if event1[0] < event2[0]: # event1 is earlier From 364497347d679c55a774fecc23e664c887cc3c31 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 18 Mar 2025 14:18:52 +0100 Subject: [PATCH 12/43] refactor: micro-optimizations in find_rectangles --- .../task/process/rectangle_cython.pyx | 32 +++++++++---------- .../task/process/rectangle_python.py | 32 +++++++++---------- 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 71fad0db..f1b02d5b 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -381,6 +381,7 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], else: # at the same time, but different tracks => any order is fine return 1 events.sort(key = cmp_to_key(compare_events)) + event_count = len(events) # The result will be 
a list of intervals produced by # interval_constructor. @@ -395,15 +396,6 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], result.append(interval) previous_end = end - def no_gap_between_points_in_time(end_time, start_time): - # Since points in time for intervals are quantized to whole - # seconds and intervals are closed (inclusive) for both start - # and end points, two adjacent intervals like - # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] - # have no gap between them and can be considered a single - # continuous interval [START_TIME1, END_TIME2]. - return (end_time == start_time) or (end_time == start_time - 1) - def is_same_result(active_intervals1, active_intervals2): # When we have to decide whether to extend a result interval # or start a new one, we compare the state for the existing @@ -414,28 +406,34 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], == interval_constructor(0,0,active_intervals2)) active_intervals = [None] * track_count - def process_event_for_point_in_time(index, point_time): - nonlocal active_intervals + def process_events_for_point_in_time(index, point_time): high_time = point_time any_open = False - for i in range(index, len(events)): + for i in range(index, event_count): time, open_, interval, track = events[i] - if no_gap_between_points_in_time(point_time, time): - high_time = max(high_time, time) + # Since points in time for intervals are quantized to whole + # seconds and intervals are closed (inclusive) for both start + # and end points, two adjacent intervals like + # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] + # have no gap between them and can be considered a single + # continuous interval [START_TIME1, END_TIME2]. 
+ if (point_time == time) or (point_time == time - 1): + if time > high_time: + high_time = time any_open |= open_ else: return i, time, active_intervals.copy(), high_time if any_open else high_time + 1 active_intervals[track] = interval if open_ else None return None, None, None, None - if not len(events) == 0: + if not event_count == 0: # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". index, time = 0, events[0][0] interval_start_time = time - index, time, interval_start_state, high_time = process_event_for_point_in_time(index, time) + index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) while index: - new_index, new_time, maybe_end_state, high_time = process_event_for_point_in_time(index, time) + new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) # Diagram for this program point: # |___potential_result_interval___|| | # index new_index diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 082b181c..b82f9039 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -375,6 +375,7 @@ def compare_events(event1, event2): else: # at the same time, but different tracks => any order is fine return 1 events.sort(key = cmp_to_key(compare_events)) + event_count = len(events) # The result will be a list of intervals produced by # interval_constructor. 
@@ -389,15 +390,6 @@ def add_segment(start, end, original_intervals): result.append(interval) previous_end = end - def no_gap_between_points_in_time(end_time, start_time): - # Since points in time for intervals are quantized to whole - # seconds and intervals are closed (inclusive) for both start - # and end points, two adjacent intervals like - # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] - # have no gap between them and can be considered a single - # continuous interval [START_TIME1, END_TIME2]. - return (end_time == start_time) or (end_time == start_time - 1) - def is_same_result(active_intervals1, active_intervals2): # When we have to decide whether to extend a result interval # or start a new one, we compare the state for the existing @@ -408,28 +400,34 @@ def is_same_result(active_intervals1, active_intervals2): == interval_constructor(0,0,active_intervals2)) active_intervals = [None] * track_count - def process_event_for_point_in_time(index, point_time): - nonlocal active_intervals + def process_events_for_point_in_time(index, point_time): high_time = point_time any_open = False - for i in range(index, len(events)): + for i in range(index, event_count): time, open_, interval, track = events[i] - if no_gap_between_points_in_time(point_time, time): - high_time = max(high_time, time) + # Since points in time for intervals are quantized to whole + # seconds and intervals are closed (inclusive) for both start + # and end points, two adjacent intervals like + # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] + # have no gap between them and can be considered a single + # continuous interval [START_TIME1, END_TIME2]. 
+ if (point_time == time) or (point_time == time - 1): + if time > high_time: + high_time = time any_open |= open_ else: return i, time, active_intervals.copy(), high_time if any_open else high_time + 1 active_intervals[track] = interval if open_ else None return None, None, None, None - if not len(events) == 0: + if not event_count == 0: # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". index, time = 0, events[0][0] interval_start_time = time - index, time, interval_start_state, high_time = process_event_for_point_in_time(index, time) + index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) while index: - new_index, new_time, maybe_end_state, high_time = process_event_for_point_in_time(index, time) + new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) # Diagram for this program point: # |___potential_result_interval___|| | # index new_index From 405f3de7782d318f3dda17cc3b62b5cf57882503 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 18 Mar 2025 14:21:26 +0100 Subject: [PATCH 13/43] refactor: avoid copy operations in find_rectangles --- execution_engine/task/process/rectangle_cython.pyx | 6 ++++-- execution_engine/task/process/rectangle_python.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index f1b02d5b..769e7731 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -422,7 +422,7 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], high_time = time any_open |= open_ else: - return i, time, active_intervals.copy(), high_time if any_open else high_time + 1 + return i, time, active_intervals, high_time if any_open else high_time + 1 active_intervals[track] = interval if open_ else None return None, 
None, None, None @@ -432,6 +432,7 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], index, time = 0, events[0][0] interval_start_time = time index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) + interval_start_state = interval_start_state.copy() if interval_start_state is not None else None while index: new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) # Diagram for this program point: @@ -442,6 +443,7 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], # high_time if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): add_segment(interval_start_time, time, interval_start_state) - interval_start_time, interval_start_state = high_time, maybe_end_state + interval_start_time = high_time + interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None index, time = new_index, new_time return result diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index b82f9039..8a2376cf 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -416,7 +416,7 @@ def process_events_for_point_in_time(index, point_time): high_time = time any_open |= open_ else: - return i, time, active_intervals.copy(), high_time if any_open else high_time + 1 + return i, time, active_intervals, high_time if any_open else high_time + 1 active_intervals[track] = interval if open_ else None return None, None, None, None @@ -426,6 +426,7 @@ def process_events_for_point_in_time(index, point_time): index, time = 0, events[0][0] interval_start_time = time index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) + interval_start_state = interval_start_state.copy() if interval_start_state is not None else None while index: new_index, new_time, maybe_end_state, high_time = 
process_events_for_point_in_time(index, time) # Diagram for this program point: @@ -436,6 +437,7 @@ def process_events_for_point_in_time(index, point_time): # high_time if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): add_segment(interval_start_time, time, interval_start_state) - interval_start_time, interval_start_state = high_time, maybe_end_state + interval_start_time = high_time + interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None index, time = new_index, new_time return result From b4e60f8e8685fbd922b8f5eabac84b15ebf38c8d Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 18 Mar 2025 15:01:35 +0100 Subject: [PATCH 14/43] feat: add is_same_result parameter to find_rectangles --- execution_engine/task/process/rectangle.py | 4 +-- .../task/process/rectangle_cython.pyx | 26 ++++++++++++------- .../task/process/rectangle_python.py | 26 ++++++++++++------- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index ffc1d655..380cff5a 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -775,7 +775,7 @@ def find_overlapping_personal_windows( return result -def find_rectangles(data: list[PersonIntervals], interval_constructor: Callable) -> PersonIntervals: +def find_rectangles(data: list[PersonIntervals], interval_constructor: Callable, is_same_result = None) -> PersonIntervals: # TODO(jmoringe): can this use _process_interval? 
if len(data) == 0: return {} @@ -784,5 +784,5 @@ def find_rectangles(data: list[PersonIntervals], interval_constructor: Callable) for track in data: keys |= track.keys() return {key: _impl.find_rectangles([ intervals.get(key, []) for intervals in data ], - interval_constructor) + interval_constructor, is_same_result=is_same_result) for key in keys} diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 769e7731..1865dedc 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -329,8 +329,20 @@ def union_interval_lists( IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] +def default_is_same_result(interval_constructor): + def is_same_result(active_intervals1, active_intervals2): + # When we have to decide whether to extend a result interval + # or start a new one, we compare the state for the existing + # result interval with the new state. The states are derived + # from the respective lists of active intervals by calling + # interval_constructor (with fake points in time) . + return (interval_constructor(0, 0, active_intervals1) + == interval_constructor(0, 0, active_intervals2)) + return is_same_result + def find_rectangles(all_intervals: list[list[AnyInterval]], - interval_constructor: IntervalConstructor) \ + interval_constructor: IntervalConstructor, + is_same_result = None) \ -> list[AnyInterval]: """For multiple parallel "tracks" of intervals, identify temporal intervals in which no change occurs on any "track". For each such @@ -356,6 +368,9 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], that adjacent intervals (i.e. without gaps between them) have different "payloads". """ + if is_same_result is None: + is_same_result = default_is_same_result(interval_constructor) + # Convert all intervals into a single list of events sorted by # time. 
Multiple events at the same point in time can be problem # here: If an interval open event and an interval close event on @@ -396,15 +411,6 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], result.append(interval) previous_end = end - def is_same_result(active_intervals1, active_intervals2): - # When we have to decide whether to extend a result interval - # or start a new one, we compare the state for the existing - # result interval with the new state. The states are derived - # from the respective lists of active intervals by calling - # interval_constructor (with fake points in time) . - return (interval_constructor(0,0,active_intervals1) - == interval_constructor(0,0,active_intervals2)) - active_intervals = [None] * track_count def process_events_for_point_in_time(index, point_time): high_time = point_time diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 8a2376cf..05f85683 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -323,8 +323,20 @@ def union_interval_lists(left: list[Interval], right: list[Interval]) -> list[In IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] +def default_is_same_result(interval_constructor): + def is_same_result(active_intervals1, active_intervals2): + # When we have to decide whether to extend a result interval + # or start a new one, we compare the state for the existing + # result interval with the new state. The states are derived + # from the respective lists of active intervals by calling + # interval_constructor (with fake points in time) . 
+ return (interval_constructor(0, 0, active_intervals1) + == interval_constructor(0, 0, active_intervals2)) + return is_same_result + def find_rectangles(all_intervals: list[list[AnyInterval]], - interval_constructor: IntervalConstructor) \ + interval_constructor: IntervalConstructor, + is_same_result = None) \ -> list[AnyInterval]: """For multiple parallel "tracks" of intervals, identify temporal intervals in which no change occurs on any "track". For each such @@ -350,6 +362,9 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], that adjacent intervals (i.e. without gaps between them) have different "payloads". """ + if is_same_result is None: + is_same_result = default_is_same_result(interval_constructor) + # Convert all intervals into a single list of events sorted by # time. Multiple events at the same point in time can be problem # here: If an interval open event and an interval close event on @@ -390,15 +405,6 @@ def add_segment(start, end, original_intervals): result.append(interval) previous_end = end - def is_same_result(active_intervals1, active_intervals2): - # When we have to decide whether to extend a result interval - # or start a new one, we compare the state for the existing - # result interval with the new state. The states are derived - # from the respective lists of active intervals by calling - # interval_constructor (with fake points in time) . 
- return (interval_constructor(0,0,active_intervals1) - == interval_constructor(0,0,active_intervals2)) - active_intervals = [None] * track_count def process_events_for_point_in_time(index, point_time): high_time = point_time From fcf98d09725544640e803bcccb4064670c0fa4f1 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 18 Mar 2025 17:57:03 +0100 Subject: [PATCH 15/43] refactor: maybe faster result interval construction in find_rectangles --- .../task/process/rectangle_cython.pyx | 29 +++++++++---------- .../task/process/rectangle_python.py | 29 +++++++++---------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 1865dedc..71f34ed1 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -397,20 +397,6 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], return 1 events.sort(key = cmp_to_key(compare_events)) event_count = len(events) - - # The result will be a list of intervals produced by - # interval_constructor. 
- result = [] - previous_end = None - def add_segment(start, end, original_intervals): - nonlocal previous_end - if previous_end == start and len(result) > 0: - result[-1] = result[-1]._replace(upper=previous_end - 1) - interval = interval_constructor(start, end, original_intervals) - if interval is not None: # interval type negative is implicit - result.append(interval) - previous_end = end - active_intervals = [None] * track_count def process_events_for_point_in_time(index, point_time): high_time = point_time @@ -431,7 +417,7 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], return i, time, active_intervals, high_time if any_open else high_time + 1 active_intervals[track] = interval if open_ else None return None, None, None, None - + result_intervals = [] if not event_count == 0: # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". @@ -448,8 +434,19 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], # interval_start_state maybe_end_state # high_time if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): - add_segment(interval_start_time, time, interval_start_state) + # Add info for one result interval. + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) + result_intervals.append((interval_start_time, time, interval_start_state)) + # Update interval start info. 
interval_start_time = high_time interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None index, time = new_index, new_time + result = [] + for (start, end, intervals) in result_intervals: + interval = interval_constructor(start, end, intervals) + if interval is not None: + result.append(interval) return result diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 05f85683..247aa1c3 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -391,20 +391,6 @@ def compare_events(event1, event2): return 1 events.sort(key = cmp_to_key(compare_events)) event_count = len(events) - - # The result will be a list of intervals produced by - # interval_constructor. - result = [] - previous_end = None - def add_segment(start, end, original_intervals): - nonlocal previous_end - if previous_end == start and len(result) > 0: - result[-1] = result[-1]._replace(upper=previous_end - 1) - interval = interval_constructor(start, end, original_intervals) - if interval is not None: # interval type negative is implicit - result.append(interval) - previous_end = end - active_intervals = [None] * track_count def process_events_for_point_in_time(index, point_time): high_time = point_time @@ -425,7 +411,7 @@ def process_events_for_point_in_time(index, point_time): return i, time, active_intervals, high_time if any_open else high_time + 1 active_intervals[track] = interval if open_ else None return None, None, None, None - + result_intervals = [] if not event_count == 0: # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". 
@@ -442,8 +428,19 @@ def process_events_for_point_in_time(index, point_time): # interval_start_state maybe_end_state # high_time if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): - add_segment(interval_start_time, time, interval_start_state) + # Add info for one result interval. + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) + result_intervals.append((interval_start_time, time, interval_start_state)) + # Update interval start info. interval_start_time = high_time interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None index, time = new_index, new_time + result = [] + for (start, end, intervals) in result_intervals: + interval = interval_constructor(start, end, intervals) + if interval is not None: + result.append(interval) return result From 102cd544df0940384cb057f4859c7523070b9749 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 18 Mar 2025 13:26:52 +0100 Subject: [PATCH 16/43] refactor: optimize implementation of TemporalCount --- execution_engine/task/task.py | 92 ++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 39 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 3339809f..5831d8a2 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -3,7 +3,7 @@ import json import logging from enum import Enum, auto -from typing import List +from typing import List, Any from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError @@ -551,49 +551,63 @@ def get_start_end_from_interval_type( timezone=get_config().timezone, ) - # Create a "temporary window interval" for each window - # interval. Associate with each temporary window interval - # all data intervals that overlap it. 
The association - # works by assigning a unique id to each temporary window + # Incrementally compute the interval type for each window # interval. - ids = dict() # window_interval -> unique id - infos = dict() # unique id -> list of overlapping data intervals - def temporary_window_interval(start: int, end: int, intervals: List[AnyInterval]): + window_types: dict[AnyInterval, Any] = dict() # window interval -> interval type + def update_window_type(window_interval, data_interval): + window_type = window_types.get(window_interval.lower, None) + if data_interval is None or data_interval.type is IntervalType.NEGATIVE: + if window_type is not IntervalType.POSITIVE: + window_type = IntervalType.NEGATIVE + elif data_interval.type is IntervalType.POSITIVE: + window_type = IntervalType.POSITIVE + elif data_interval.type is IntervalType.NOT_APPLICABLE: + if window_type is None: + window_type = IntervalType.NOT_APPLICABLE + else: + assert data_interval.type is IntervalType.NO_DATA + if window_type is None: + window_type = IntervalType.NO_DATA + window_types[window_interval.lower] = window_type + return window_type + # The boundaries of the result intervals are identical to + # those of the window intervals. In addition, update the + # result interval window types based on the data + # intervals. 
+ def is_same_interval(left_intervals, right_intervals): + left_window_interval, left_data_interval = left_intervals + right_window_interval, right_data_interval = right_intervals + if right_window_interval is None: + if left_window_interval is None: + return True + else: + update_window_type(left_window_interval, left_data_interval) + return False + else: + update_window_type(right_window_interval, right_data_interval) + if left_window_interval is None: + return False + else: + if left_window_interval is right_window_interval: + return True + else: + update_window_type(left_window_interval, left_data_interval) + return False + # Create result intervals based on the computed interval + # types. + def result_interval(start: int, end: int, intervals: List[AnyInterval]): window_interval, data_interval = intervals - if window_interval is None or window_interval.type == IntervalType.NOT_APPLICABLE: + if window_interval is None or window_interval.type is IntervalType.NOT_APPLICABLE: return Interval(start, end, IntervalType.NOT_APPLICABLE) else: - window_id = ids.get(window_interval, len(ids)) - ids[window_interval] = window_id - info = infos.get(window_id, set()) - infos[window_id] = info - data_interval_type = data_interval.type if data_interval is not None else IntervalType.NEGATIVE - info.add(data_interval_type) - return IntervalWithCount(start, end, IntervalType.POSITIVE, window_id) + window_type = window_types.get(window_interval.lower, None) + if window_type is None: + window_type = update_window_type(window_interval, data_interval) + return Interval(start, end, window_type) person_indicator_windows = { key: indicator_windows for key in data_p.keys() } - result = process.find_rectangles([ person_indicator_windows, data_p], temporary_window_interval) - # Turn the temporary window intervals into the final - # intervals by computing the interval types based on the - # respective overlapping data intervals. 
- def finalize_interval(interval): - if isinstance(interval, IntervalWithCount): - window_id = interval.count - data_intervals = infos[window_id] - # TODO(jmoringe): there should be a way to implement this with max(data_intervals) - if IntervalType.POSITIVE in data_intervals: - interval_type = IntervalType.POSITIVE - elif IntervalType.NEGATIVE in data_intervals: - interval_type = IntervalType.NEGATIVE - elif IntervalType.NOT_APPLICABLE in data_intervals: - interval_type = IntervalType.NOT_APPLICABLE - else: - assert IntervalType.NO_DATA in data_intervals - interval_type = IntervalType.NO_DATA - return Interval(interval.lower, interval.upper, interval_type) - else: - return interval - result = { key: [ finalize_interval(i) for i in intervals ] - for key, intervals in result.items() } + result = process.find_rectangles([ person_indicator_windows, data_p], + result_interval, + is_same_result=is_same_interval) return result From c321690d3f92db0774f89fcf7c41f4a473e1cf36 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 18 Mar 2025 17:29:41 +0100 Subject: [PATCH 17/43] refactor: do not copy in interval_like The _replace method already returns a new instance. --- execution_engine/task/process/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index 9d997c41..4f56b5d9 100644 --- a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -1,4 +1,3 @@ -import copy import importlib import os import sys @@ -59,6 +58,4 @@ def interval_like(interval: TInterval, start: int, end: int) -> TInterval: I: A copy of the interval with updated lower and upper bounds. 
""" - return copy.copy(interval)._replace( - lower=start, upper=end - ) # type: ignore[return-value]m + return interval._replace(lower=start, upper=end) # type: ignore[return-value] From fb6b6350c0707dd6fc1571f92b8278d8b063272d Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Sat, 22 Mar 2025 13:01:01 +0100 Subject: [PATCH 18/43] tests: adapt test to new ee structure --- .../omop/criterion/combination/test_temporal_combination.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 6f4acb26..72f8c660 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -2008,7 +2008,7 @@ def test_interval_ratio_on_database( db_session.add_all(vos) db_session.commit() - self.insert_criterion_combination( + self.insert_expression( db_session, population, intervention, base_criterion, observation_window ) From b400698335d93c7ef5dd0bbb259276f8e5f740ee Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Mon, 24 Mar 2025 12:41:01 +0100 Subject: [PATCH 19/43] refactor: improvements in task.py --- execution_engine/task/task.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index a512f754..40063de1 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -335,7 +335,7 @@ def interval_counts( ratio = positive_count / effective_count_min return IntervalWithCount(start, end, effective_type, ratio) - return process.find_rectangles(data, interval_counts) + result = process.find_rectangles(data, interval_counts) elif isinstance(self.expr, logic.AllOrNone): raise NotImplementedError("AllOrNone is not implemented yet.") @@ -408,18 +408,13 @@ def 
handle_left_dependent_toggle( # data[0] is the left dependency (i.e. P) # data[1] is the right dependency (i.e. I) - # observation_window_intervals extends the result to the - # correct temporal range; Its type is not important. - observation_window_intervals = { - key: [ - Interval( - observation_window.start.timestamp(), - observation_window.end.timestamp(), - IntervalType.POSITIVE, - ) - ] - for key in left.keys() - } + # window_intervals extends the result to the correct temporal + # range; Its type is not important. + windows = [ Interval(observation_window.start.timestamp(), + observation_window.end.timestamp(), + IntervalType.POSITIVE) ] + window_intervals = { key: windows for key in left.keys() } + if isinstance(self.expr, logic.LeftDependentToggle): fill_type = IntervalType.NOT_APPLICABLE else: @@ -441,7 +436,7 @@ def new_interval( return None return process.find_rectangles( - [left, right, observation_window_intervals], new_interval + [left, right, window_intervals], new_interval ) def handle_temporal_operator( From a6d8796ecf09bde57a066535b2a8ac5e101b305d Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Mon, 24 Mar 2025 12:41:33 +0100 Subject: [PATCH 20/43] refactor: simplify Task.insert_negative_intervals Add the negative intervals in a single find_rectangles call instead of computing negative intervals and then merging. --- execution_engine/task/task.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 40063de1..0c097613 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -611,33 +611,28 @@ def insert_negative_intervals( :param observation_window: The observation window. :return: A DataFrame with the merged intervals. 
""" + # window_intervals extends the result to the correct temporal + # range and forces results to be computed for patients that + # are not represented in data; The interval types in + # window_intervals are not important. + windows = [ Interval(observation_window.start.timestamp(), + observation_window.end.timestamp(), + IntervalType.POSITIVE) ] + all_keys = data.keys() | base_data.keys() + window_intervals = { key: windows for key in all_keys } - data_negative = process.complementary_intervals( - data, - reference=base_data, - observation_window=observation_window, - interval_type=IntervalType.NEGATIVE, - ) - - def union_interval( - start: int, end: int, intervals: List[GeneralizedInterval] + def create_interval( + start: int, end: int, intervals: List[GeneralizedInterval] ) -> GeneralizedInterval: - with IntervalType.union_order(): - max_interval = max( - intervals, - key=lambda i: IntervalType.NEGATIVE if i is None else i.type, - ) - if max_interval is not None: - return interval_like(max_interval, start, end) + interval, window_interval = intervals + if interval is not None: + return interval_like(interval, start, end) else: # Explicit representation of negative intervals is # required here because the database views do not # understand the implicit representation. return Interval(start, end, IntervalType.NEGATIVE) - - result = process.find_rectangles([data, data_negative], union_interval) - - return result + return process.find_rectangles([data, window_intervals], create_interval) def store_result_in_db( self, From 297bbcb12dde557f5ac2667946af6e1536244d7c Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Mon, 24 Mar 2025 12:43:19 +0100 Subject: [PATCH 21/43] feat: compute "interval ratio" in logical count operators Also implement the operators via find_rectangles and remove from the process module the functions count_intervals and filter_count_intervals. 
--- execution_engine/task/process/rectangle.py | 104 ------------------ .../task/process/rectangle_python.py | 84 -------------- execution_engine/task/task.py | 50 ++++++++- 3 files changed, 44 insertions(+), 194 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 0063e27d..11f31992 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -428,110 +428,6 @@ def union_intervals(data: list[PersonIntervals]) -> PersonIntervals: return _process_intervals(data, _impl.union_interval_lists) -def interval_to_interval_with_count(interval: Interval) -> IntervalWithCount: - """ - Converts an Interval to an IntervalWithCount. - """ - return IntervalWithCount(interval.lower, interval.upper, interval.type, 1) - - -def intervals_to_intervals_with_count( - intervals: list[Interval], -) -> list[IntervalWithCount]: - """ - Converts a list of Intervals to a list of IntervalWithCount. - """ - return [interval_to_interval_with_count(interval) for interval in intervals] - - -def count_intervals(data: list[PersonIntervals]) -> PersonIntervalsWithCount: - """ - Counts the intervals per dict key in the list. - - :param data: A list of dict of intervals. - :return: A dict with the unioned intervals. 
- """ - if not len(data): - return dict() - - # assert dfs is a list of dataframes - assert isinstance(data, list) and all( - isinstance(arr, dict) for arr in data - ), "data must be a list of dicts" - - result = {} - - for arr in data: - if not len(arr): - # if the operation is union, an empty dataframe can be ignored - continue - - for group_keys, intervals in arr.items(): - intervals_with_count = intervals_to_intervals_with_count(intervals) - intervals_with_count = _impl.union_rects_with_count(intervals_with_count) - if group_keys not in result: - result[group_keys] = intervals_with_count - else: - result[group_keys] = _impl.union_rects_with_count( - result[group_keys] + intervals_with_count - ) - - return result - - -def filter_count_intervals( - data: PersonIntervalsWithCount, - min_count: int | None, - max_count: int | None, - keep_no_data: bool = True, - keep_not_applicable: bool = True, -) -> PersonIntervals: - """ - Filters the intervals per dict key in the list by count. - - :param data: A list of dict of intervals. - :param min_count: The minimum count of the intervals. - :param max_count: The maximum count of the intervals. - :param keep_no_data: Whether to keep NO_DATA intervals (irrespective of the count). - :param keep_not_applicable: Whether to keep NOT_APPLICABLE intervals (irrespective of the count). - :return: A dict with the unioned intervals. 
- """ - - result: PersonIntervals = {} - - interval_filter = [] - - if keep_no_data: - interval_filter.append(IntervalType.NO_DATA) - if keep_not_applicable: - interval_filter.append(IntervalType.NOT_APPLICABLE) - - if min_count is None and max_count is None: - raise ValueError("min_count and max_count cannot both be None") - elif min_count is not None and max_count is not None: - for person_id in data: - result[person_id] = [ - Interval(interval.lower, interval.upper, interval.type) - for interval in data[person_id] - if min_count <= interval.count <= max_count - or interval.type in interval_filter - ] - elif min_count is not None: - for person_id in data: - result[person_id] = [ - Interval(interval.lower, interval.upper, interval.type) - for interval in data[person_id] - if min_count <= interval.count or interval.type in interval_filter - ] - elif max_count is not None: - for person_id in data: - result[person_id] = [ - Interval(interval.lower, interval.upper, interval.type) - for interval in data[person_id] - if interval.count <= max_count or interval.type in interval_filter - ] - - return result def intersect_intervals(data: list[PersonIntervals]) -> PersonIntervals: diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 247aa1c3..929701ef 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -105,90 +105,6 @@ def union_rects(intervals: list[Interval]) -> list[Interval]: return union -def union_rects_with_count( - intervals: list[IntervalWithCount], -) -> list[IntervalWithCount]: - """ - Unions the intervals while keeping track of the count of overlapping intervals of the same type. 
- """ - - if not len(intervals): - return [] - - with IntervalType.union_order(): - events = intervals_to_events(intervals) - - union = [] - - last_x_start = -np.inf # holds the x_min of the currently open output rectangle - last_x_end = events[0][ - 0 - ] # x variable of the last closed interval (we start with the first x, so we - # don't close the first rectangle at the first x) - previous_x_visited = -np.inf - open_y = SortedDict() - - def get_y_max() -> IntervalType | None: - max_key = None - for key in reversed(open_y): - if open_y[key] > 0: - max_key = key - break - return max_key - - for x, start_point, interval in events: - y, count_event = interval.type, interval.count - if start_point: - y_max = get_y_max() - - if x > previous_x_visited and y_max is None: - # no currently open rectangles - last_x_start = x # start new output rectangle - elif y >= y_max: - if x == last_x_end or x == last_x_start: - # we already closed a rectangle at this x, so we don't need to start a new one - open_y[y] = open_y.get(y, 0) + count_event - continue - - union.append( - IntervalWithCount( - lower=last_x_start, - upper=x - 1, - type=y_max, - count=open_y[y_max], - ) - ) - last_x_end = x - last_x_start = x - - open_y[y] = open_y.get(y, 0) + count_event - - else: - open_y[y] = max(open_y.get(y, 0) - count_event, 0) - - y_max = get_y_max() - - if (y_max is None or (open_y and y_max <= y)) and x > last_x_end: - if y_max is None or y_max < y: - # the closing rectangle has a higher y_max than the currently open ones - count = count_event - else: - # the closing rectangle has the same y_max as the currently open ones - count = open_y[y] + count_event - - union.append( - IntervalWithCount( - lower=last_x_start, upper=x - 1, type=y, count=count - ) - ) # close the previous rectangle at y_max - last_x_end = x - last_x_start = x # start new output rectangle - - previous_x_visited = x - - return merge_adjacent_intervals(union) - - def merge_adjacent_intervals( intervals: 
list[IntervalWithCount], ) -> list[IntervalWithCount]: diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 0c097613..a3c33400 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -306,12 +306,50 @@ def handle_binary_logical_operator( elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): result = process.union_intervals(data) elif isinstance(self.expr, logic.Count): - result = process.count_intervals(data) - result = process.filter_count_intervals( - result, - min_count=self.expr.count_min, - max_count=self.expr.count_max, - ) + # result = process.count_intervals(data) + # result = process.filter_count_intervals( + # result, + # min_count=self.expr.count_min, + # max_count=self.expr.count_max, + # ) + count_min = self.expr.count_min + count_max = self.expr.count_max + if count_min is None and count_max is None: + raise ValueError("count_min and count_max cannot both be None") + def interval_counts( + start: int, end: int, intervals: List[AnyInterval] + ) -> GeneralizedInterval: + positive_count, negative_count, not_applicable_count, no_data_count = 0, 0, 0, 0 + for interval in intervals: + if interval is None or interval.type is IntervalType.NEGATIVE: + negative_count += 1 + elif interval.type is IntervalType.POSITIVE: + positive_count += 1 + elif interval.type is IntervalType.NOT_APPLICABLE: + not_applicable_count += 1 + elif interval.type is IntervalType.NO_DATA: + no_data_count += 1 + # + if positive_count > 0: + if count_min is None: + interval_type = IntervalType.POSITIVE if ( + positive_count <= count_max) else IntervalType.NEGATIVE + return Interval(start, end, interval_type) + else: + min_good = count_min <= positive_count + max_good = (count_max is None) or (positive_count <= count_max) + interval_type = IntervalType.POSITIVE if (min_good and max_good) else IntervalType.NEGATIVE + ratio = positive_count / count_min + return IntervalWithCount(start, end, interval_type, ratio) + if 
no_data_count > 0: + return Interval(start, end, IntervalType.NO_DATA) + if not_applicable_count > 0: + return Interval(start, end, IntervalType.NOT_APPLICABLE) + if negative_count > 0: + return Interval(start, end, IntervalType.NEGATIVE) + + result = process.find_rectangles(data, interval_counts) + elif isinstance(self.expr, logic.CappedCount): def interval_counts( From 920801d43810c1a88c3255a7c24c8c210997ce8e Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 24 Mar 2025 21:02:20 +0100 Subject: [PATCH 22/43] fix: temporal operator handling --- execution_engine/task/process/rectangle.py | 30 ++++-- execution_engine/task/task.py | 115 ++++++++++++++------- 2 files changed, 97 insertions(+), 48 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 11f31992..7b07402a 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -428,8 +428,6 @@ def union_intervals(data: list[PersonIntervals]) -> PersonIntervals: return _process_intervals(data, _impl.union_interval_lists) - - def intersect_intervals(data: list[PersonIntervals]) -> PersonIntervals: """ Intersects the intervals per dict key in the list. @@ -519,7 +517,7 @@ def create_time_intervals( end_time: datetime.time, interval_type: IntervalType, timezone: pytz.tzinfo.DstTzInfo | str, -) -> list[Interval]: +) -> tuple[Interval, ...]: """ Constructs a list of time intervals within a specified date range, each defined by daily start and end times. 
@@ -574,7 +572,7 @@ def create_time_intervals( end_datetime = end_datetime.in_timezone(timezone) # Prepare to collect intervals - intervals = [] + intervals: list[Interval] = [] previous_end = None def add_interval( @@ -656,7 +654,8 @@ def add_interval( # Move to the next day current_date += datetime.timedelta(days=1) - return intervals + # use a tuple for windows to make sure it is immutable (and can be shared by all persons) + return tuple(intervals) def find_overlapping_personal_windows( @@ -698,6 +697,7 @@ def find_rectangles( data: list[PersonIntervals], interval_constructor: Callable, is_same_result: Callable | None = None, + reset: Callable | None = None, ) -> PersonIntervals: """ Iterates over intervals for each person across all items in `data` and constructs new intervals @@ -717,11 +717,21 @@ def find_rectangles( keys: Set[int] = set() for track in data: keys |= track.keys() - return { - key: _impl.find_rectangles( - [intervals.get(key, []) for intervals in data], + result = {} + + for key in keys: + + if reset: + reset() + + intervals_for_person: list[list[Interval]] = [ + intervals.get(key, []) for intervals in data + ] + intervals = _impl.find_rectangles( + intervals_for_person, interval_constructor, is_same_result=is_same_result, ) - for key in keys - } + result[key] = intervals + + return result diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index a3c33400..04445f08 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -2,6 +2,7 @@ import datetime import json import logging +from collections import Counter from enum import Enum, auto from typing import List @@ -101,19 +102,16 @@ def find_base_task(task: Task) -> Task: return find_base_task(self) - def select_predecessor_result( - self, expr: logic.BaseExpr, data: list[PersonIntervals] - ) -> PersonIntervals: + def get_predecessor_data_index(self, expr: logic.BaseExpr) -> int: """ - Select the result results of the predecessor task from the 
given expression. + Get the index of the predecessor data from the given expression. This is required in expressions where order is important, e.g. in BinaryNonCommutativeOperator. As the nx.DiGraph (and by inheritance, ExecutionGraph) does not store the order of the predecessors, we need to find the predecessor task by its expression and select the result from the data. :param expr: The expression of the predecessor task. - :param data: The input data. - :return: The result of the predecessor task. + :return: The index of expr in the data of predecessor results """ if len(self.dependencies) == 0: raise ValueError("Task has no dependencies.") @@ -125,7 +123,23 @@ def select_predecessor_result( f"Task with expression '{str(expr)}' not found in dependencies." ) - return data[idx] + return idx + + def select_predecessor_result( + self, expr: logic.BaseExpr, data: list[PersonIntervals] + ) -> PersonIntervals: + """ + Select the result of the predecessor task from the given expression. + + This is required in expressions where order is important, e.g. in BinaryNonCommutativeOperator. + As the nx.DiGraph (and by inheritance, ExecutionGraph) does not store the order of the predecessors, + we need to find the predecessor task by its expression and select the result from the data. + + :param expr: The expression of the predecessor task. + :param data: The input data. + :return: The result of the predecessor task. + """ + return data[self.get_predecessor_data_index(expr)] + def run( self, data: list[PersonIntervals], @@ -275,6 +289,7 @@ def handle_unary_logical_operator( :return: A DataFrame with the inverted intervals. """ assert self.expr.is_Not, "Dependency is not a Not expression." 
+ assert len(data) == 1, "Unary operators require only one input" result = process.invert_intervals( data[0], @@ -316,29 +331,37 @@ def handle_binary_logical_operator( count_max = self.expr.count_max if count_min is None and count_max is None: raise ValueError("count_min and count_max cannot both be None") + def interval_counts( - start: int, end: int, intervals: List[AnyInterval] + start: int, end: int, intervals: List[AnyInterval] ) -> GeneralizedInterval: - positive_count, negative_count, not_applicable_count, no_data_count = 0, 0, 0, 0 - for interval in intervals: - if interval is None or interval.type is IntervalType.NEGATIVE: - negative_count += 1 - elif interval.type is IntervalType.POSITIVE: - positive_count += 1 - elif interval.type is IntervalType.NOT_APPLICABLE: - not_applicable_count += 1 - elif interval.type is IntervalType.NO_DATA: - no_data_count += 1 - # + + counts = Counter( + (interval.type if interval else IntervalType.NEGATIVE) + for interval in intervals + ) + + positive_count = counts[IntervalType.POSITIVE] + negative_count = counts[IntervalType.NEGATIVE] + not_applicable_count = counts[IntervalType.NOT_APPLICABLE] + no_data_count = counts[IntervalType.NO_DATA] + if positive_count > 0: if count_min is None: - interval_type = IntervalType.POSITIVE if ( - positive_count <= count_max) else IntervalType.NEGATIVE + interval_type = ( + IntervalType.POSITIVE + if (positive_count <= count_max) # type:ignore [operator] + else IntervalType.NEGATIVE + ) return Interval(start, end, interval_type) else: min_good = count_min <= positive_count max_good = (count_max is None) or (positive_count <= count_max) - interval_type = IntervalType.POSITIVE if (min_good and max_good) else IntervalType.NEGATIVE + interval_type = ( + IntervalType.POSITIVE + if (min_good and max_good) + else IntervalType.NEGATIVE + ) ratio = positive_count / count_min return IntervalWithCount(start, end, interval_type, ratio) if no_data_count > 0: @@ -348,6 +371,8 @@ def interval_counts( 
if negative_count > 0: return Interval(start, end, IntervalType.NEGATIVE) + raise ValueError("No intervals of any kind found") + result = process.find_rectangles(data, interval_counts) elif isinstance(self.expr, logic.CappedCount): @@ -444,14 +469,17 @@ def handle_left_dependent_toggle( self.expr, (logic.LeftDependentToggle, logic.ConditionalFilter) ), "Dependency is not a LeftDependentToggle or ConditionalFilter expression." - # data[0] is the left dependency (i.e. P) - # data[1] is the right dependency (i.e. I) # window_intervals extends the result to the correct temporal # range; Its type is not important. - windows = [ Interval(observation_window.start.timestamp(), - observation_window.end.timestamp(), - IntervalType.POSITIVE) ] - window_intervals = { key: windows for key in left.keys() } + # use a tuple for windows to make sure it is immutable (and can be shared by all persons) + windows = ( + Interval( + observation_window.start.timestamp(), + observation_window.end.timestamp(), + IntervalType.POSITIVE, + ), + ) + window_intervals = {key: windows for key in left.keys()} if isinstance(self.expr, logic.LeftDependentToggle): fill_type = IntervalType.NOT_APPLICABLE @@ -473,9 +501,7 @@ def new_interval( else: # left_interval but not right_interval -> implicit negative return None - return process.find_rectangles( - [left, right, window_intervals], new_interval - ) + return process.find_rectangles([left, right, window_intervals], new_interval) def handle_temporal_operator( self, data: list[PersonIntervals], observation_window: TimeRange @@ -491,7 +517,7 @@ def handle_temporal_operator( :return: A DataFrame with the merged intervals. 
""" - data_p = data[0] + data_p = self.select_predecessor_result(self.expr.args[0], data) # data_p = process.select_type(data[0], IntervalType.POSITIVE) # data_p = {key: val for key, val in data_p.items() if val} @@ -517,7 +543,10 @@ def get_start_end_from_interval_type( assert ( len(data) >= 2 ), "TemporalCount with indicator criterion requires at least two inputs" - data, indicator_personal_windows = data[:-1], data[-1] + + indicator_personal_windows = data.pop( + self.get_predecessor_data_index(self.expr.interval_criterion) + ) result = process.find_overlapping_personal_windows( indicator_personal_windows, data_p @@ -559,6 +588,9 @@ def get_start_end_from_interval_type( dict() ) # window interval -> interval type + def reset_window_types() -> None: + window_types.clear() + def update_window_type( window_interval: AnyInterval, data_interval: AnyInterval ) -> IntervalType: @@ -628,6 +660,7 @@ def result_interval( [person_indicator_windows, data_p], result_interval, is_same_result=is_same_interval, + reset=reset_window_types, ) return result @@ -653,14 +686,19 @@ def insert_negative_intervals( # range and forces results to be computed for patients that # are not represented in data; The interval types in # window_intervals are not important. 
- windows = [ Interval(observation_window.start.timestamp(), - observation_window.end.timestamp(), - IntervalType.POSITIVE) ] + # use a tuple for windows to make sure it is immutable (and can be shared by all persons) + windows = ( + Interval( + observation_window.start.timestamp(), + observation_window.end.timestamp(), + IntervalType.POSITIVE, + ), + ) all_keys = data.keys() | base_data.keys() - window_intervals = { key: windows for key in all_keys } + window_intervals = {key: windows for key in all_keys} def create_interval( - start: int, end: int, intervals: List[GeneralizedInterval] + start: int, end: int, intervals: List[GeneralizedInterval] ) -> GeneralizedInterval: interval, window_interval = intervals if interval is not None: @@ -670,6 +708,7 @@ def create_interval( # required here because the database views do not # understand the implicit representation. return Interval(start, end, IntervalType.NEGATIVE) + return process.find_rectangles([data, window_intervals], create_interval) def store_result_in_db( From c4968ad42f9ab6d69915176f203587bba36eeb44 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 24 Mar 2025 21:36:57 +0100 Subject: [PATCH 23/43] feat: add sum(count) in logic.Or handling --- execution_engine/task/task.py | 89 +++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 9 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 04445f08..832c8ef6 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -4,7 +4,7 @@ import logging from collections import Counter from enum import Enum, auto -from typing import List +from typing import Callable, List, Type from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError @@ -28,6 +28,12 @@ process = get_processing_module() +COUNT_TYPES = ( + logic.MinCount, + logic.ExactCount, + logic.CappedMinCount, +) + def get_engine() -> OMOPSQLClient: """ @@ -141,6 +147,29 @@ def 
select_predecessor_result( """ return data[self.get_predecessor_data_index(expr)] + def receives_only_count_inputs(self) -> bool: + """ + Indicates whether this task only receives inputs from expressions that perform counting and thus return + IntervalWithCount. + """ + # all arguments are count types + if all(isinstance(parent, COUNT_TYPES) for parent in self.expr.args): + return True + + # all arguments are logic.BinaryNonCommutativeOperator, and all of their "right" children are count types + if all( + isinstance(parent, logic.BinaryNonCommutativeOperator) + and isinstance(parent.right, logic.Expr) + for parent in self.expr.args + ) and all( + isinstance(grandparent, COUNT_TYPES) + for parent in self.expr.args + for grandparent in parent.right.args + ): + return True + + return False + def run( self, data: list[PersonIntervals], @@ -319,7 +348,35 @@ def handle_binary_logical_operator( if isinstance(self.expr, (logic.And, logic.NonSimplifiableAnd)): result = process.intersect_intervals(data) elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): - result = process.union_intervals(data) + # result = process.union_intervals(data) + if self.receives_only_count_inputs(): + with IntervalType.custom_union_priority_order( + IntervalType.union_priority() + ): + with IntervalType.union_order(): + + def interval_union_with_count( + start: int, end: int, intervals: List[IntervalWithCount] + ) -> IntervalWithCount: + # todo: @moringenj can we improve the "interval is not None" checks here? 
+ interval_type = max( + interval.type + for interval in intervals + if interval is not None + ) + count = sum( + interval.count + for interval in intervals + if interval is not None + and interval.type == interval_type + and interval.count is not None + ) + return IntervalWithCount(start, end, interval_type, count) + + return process.find_rectangles(data, interval_union_with_count) + else: + result = process.union_intervals(data) + elif isinstance(self.expr, logic.Count): # result = process.count_intervals(data) # result = process.filter_count_intervals( @@ -365,11 +422,11 @@ def interval_counts( ratio = positive_count / count_min return IntervalWithCount(start, end, interval_type, ratio) if no_data_count > 0: - return Interval(start, end, IntervalType.NO_DATA) + return IntervalWithCount(start, end, IntervalType.NO_DATA, 0) if not_applicable_count > 0: - return Interval(start, end, IntervalType.NOT_APPLICABLE) + return IntervalWithCount(start, end, IntervalType.NOT_APPLICABLE, 0) if negative_count > 0: - return Interval(start, end, IntervalType.NEGATIVE) + return IntervalWithCount(start, end, IntervalType.NEGATIVE, 0) raise ValueError("No intervals of any kind found") @@ -487,15 +544,28 @@ def handle_left_dependent_toggle( assert isinstance(self.expr, logic.ConditionalFilter) fill_type = IntervalType.NEGATIVE + interval_type: ( + Callable[[int, int, IntervalType], IntervalWithCount] | Type[Interval] + ) + + # if all incoming data of "right" are count types, we will create a count type as well, to + # allow summing in the next layer + if isinstance(self.expr.right, logic.Expr) and all( + isinstance(parent, COUNT_TYPES) for parent in self.expr.right.args + ): + interval_type = lambda start, end, fill_type: IntervalWithCount( + start, end, fill_type, None + ) + else: + interval_type = Interval + def new_interval( start: int, end: int, intervals: List[GeneralizedInterval] ) -> GeneralizedInterval: left_interval, right_interval, observation_window_ = intervals - if ( - 
left_interval is None - ) or not left_interval.type == IntervalType.POSITIVE: + if (left_interval is None) or left_interval.type != IntervalType.POSITIVE: # no left_interval or not positive -> use fill type - return Interval(start, end, fill_type) + return interval_type(start, end, fill_type) elif right_interval is not None: return interval_like(right_interval, start, end) else: # left_interval but not right_interval -> implicit negative @@ -588,6 +658,7 @@ def get_start_end_from_interval_type( dict() ) # window interval -> interval type + # todo: @moringenj - is this additional function really a good solution? def reset_window_types() -> None: window_types.clear() From e9ce71e3609035cbc23b0a3b2bdd2c4541766725 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Mon, 24 Mar 2025 21:37:22 +0100 Subject: [PATCH 24/43] feat: IntervalType.union_order changed (NEGATIVE > N/A) --- .../util/interval/typed_interval.py | 7 ++++++- tests/execution_engine/util/test_interval.py | 17 +++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/execution_engine/util/interval/typed_interval.py b/execution_engine/util/interval/typed_interval.py index aa35696f..d9d89cb0 100644 --- a/execution_engine/util/interval/typed_interval.py +++ b/execution_engine/util/interval/typed_interval.py @@ -82,9 +82,14 @@ class IntervalType(StrEnum): criterion/combination of criteria is/are not satisfied. """ - __union_priority_order: list[str] = [POSITIVE, NO_DATA, NOT_APPLICABLE, NEGATIVE] + # UNION PRIORITY ORDER + # from Mar 24, 2025: changed order of NEGATIVE and NOT_APPLICABLE, because when OR combining two population/inter- + # vention expressions, and one has NOT_APPLICABLE whereas the other has NEGATIVE, the overall result should be + # NEGATIVE, not NOT_APPLICABLE. 
+ __union_priority_order: list[str] = [POSITIVE, NO_DATA, NEGATIVE, NOT_APPLICABLE] """Union priority order starting with the highest priority.""" + # INTERSECTION PRIORITY ORDER # until Jan 13, 2024: # POSITIVE has higher priority than NO_DATA, as in measurements we return NO_DATA intervals for all intervals # inbetween measurements (and outside), and these are &-ed with the POSITIVE intervals for e.g. conditions. diff --git a/tests/execution_engine/util/test_interval.py b/tests/execution_engine/util/test_interval.py index 4a80ef15..945b2ab9 100644 --- a/tests/execution_engine/util/test_interval.py +++ b/tests/execution_engine/util/test_interval.py @@ -198,14 +198,17 @@ def test_union_interval_same_type(self, type_): [ (T.POSITIVE, T.NEGATIVE, T.POSITIVE), (T.NO_DATA, T.NEGATIVE, T.NO_DATA), - (T.NOT_APPLICABLE, T.NEGATIVE, T.NOT_APPLICABLE), + (T.NEGATIVE, T.NOT_APPLICABLE, T.NEGATIVE), (T.POSITIVE, T.NO_DATA, T.POSITIVE), (T.POSITIVE, T.NOT_APPLICABLE, T.POSITIVE), (T.NO_DATA, T.NOT_APPLICABLE, T.NO_DATA), ], ) def test_union_interval_different_type(self, type1, type2, overlapping_type): - # type1 is the higher priority type + + with IntervalType.union_order(): + # type1 is the higher priority type + assert type1 > type2 assert interval_int(1, 2, type1) | interval_int(2, 3, type2) == IntInterval( interval_int(1, 2, type1), @@ -472,8 +475,8 @@ def test_union_priority(self): assert IntervalType.union_priority() == [ IntervalType.POSITIVE, IntervalType.NO_DATA, - IntervalType.NOT_APPLICABLE, IntervalType.NEGATIVE, + IntervalType.NOT_APPLICABLE, ] def test_intersection_priority(self): @@ -505,8 +508,8 @@ def test_custom_union_priority_order(self): assert IntervalType.union_priority() == [ IntervalType.POSITIVE, IntervalType.NO_DATA, - IntervalType.NOT_APPLICABLE, IntervalType.NEGATIVE, + IntervalType.NOT_APPLICABLE, ] def test_custom_intersection_priority_order(self): @@ -596,12 +599,10 @@ def test_or(self): # NOT_APPLICABLE | others assert ( - 
IntervalType.NOT_APPLICABLE | IntervalType.NEGATIVE - == IntervalType.NOT_APPLICABLE + IntervalType.NOT_APPLICABLE | IntervalType.NEGATIVE == IntervalType.NEGATIVE ) assert ( - IntervalType.NEGATIVE | IntervalType.NOT_APPLICABLE - == IntervalType.NOT_APPLICABLE + IntervalType.NEGATIVE | IntervalType.NOT_APPLICABLE == IntervalType.NEGATIVE ) # same types From a9d17b8e1cefe2c24527c5b8b5c511a08890d41d Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Tue, 25 Mar 2025 14:45:10 +0100 Subject: [PATCH 25/43] fix: threshold reduction in CappedMinCount --- execution_engine/task/task.py | 4 +++- execution_engine/util/logic.py | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 832c8ef6..6052fd37 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -447,7 +447,9 @@ def interval_counts( elif interval.type == IntervalType.NOT_APPLICABLE: not_applicable_count += 1 # we require at least one positive interval to be present in any case (hence the max(1, ...)) - effective_count_min = max(1, self.expr.count_min - not_applicable_count) # type: ignore[attr-defined] + effective_count_min = min( + self.expr.count_min, max(1, len(intervals) - not_applicable_count) # type: ignore[attr-defined] + ) if positive_count >= effective_count_min: effective_type = IntervalType.POSITIVE else: diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index 0d11f587..84c82696 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -1011,6 +1011,14 @@ class LeftDependentToggle(BinaryNonCommutativeOperator): """ A LeftDependentToggle object represents a logical AND operation if the left operand is positive, otherwise it returns NOT_APPLICABLE. 
+ + | left | right | Result | + |----------|----------|----------| + | NEGATIVE | * | NOT_APPLICABLE | + | NO_DATA | * | NOT_APPLICABLE | + | POSITIVE | POSITIVE | POSITIVE | + | POSITIVE | NEGATIVE | NEGATIVE | + | POSITIVE | NO_DATA | NO_DATA | """ @@ -1019,7 +1027,6 @@ class ConditionalFilter(BinaryNonCommutativeOperator): A ConditionalFilter object returns the right operand if the left operand is POSITIVE, and NEGATIVE otherwise - A conditional filter returns `right` iff `left` is POSITIVE, otherwise NEGATIVE. | left | right | Result | From 89efdbde684f90575603c5dfb5839e57ac88f607 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 25 Mar 2025 11:08:26 +0100 Subject: [PATCH 26/43] refactor: fix window_types problem without reset callback And add test. --- execution_engine/task/process/rectangle.py | 26 +--- execution_engine/task/task.py | 28 ++-- .../combination/test_temporal_combination.py | 143 ++++++++++++++++++ 3 files changed, 164 insertions(+), 33 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 7b07402a..09036bda 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -517,7 +517,7 @@ def create_time_intervals( end_time: datetime.time, interval_type: IntervalType, timezone: pytz.tzinfo.DstTzInfo | str, -) -> tuple[Interval, ...]: +) -> list[Interval]: """ Constructs a list of time intervals within a specified date range, each defined by daily start and end times. 
@@ -654,8 +654,7 @@ def add_interval( # Move to the next day current_date += datetime.timedelta(days=1) - # use a tuple for windows to make sure it is immutable (and can be shared by all persons) - return tuple(intervals) + return intervals def find_overlapping_personal_windows( @@ -697,7 +696,6 @@ def find_rectangles( data: list[PersonIntervals], interval_constructor: Callable, is_same_result: Callable | None = None, - reset: Callable | None = None, ) -> PersonIntervals: """ Iterates over intervals for each person across all items in `data` and constructs new intervals @@ -717,21 +715,11 @@ def find_rectangles( keys: Set[int] = set() for track in data: keys |= track.keys() - result = {} - - for key in keys: - - if reset: - reset() - - intervals_for_person: list[list[Interval]] = [ - intervals.get(key, []) for intervals in data - ] - intervals = _impl.find_rectangles( - intervals_for_person, + return { + key: _impl.find_rectangles( + [intervals.get(key, []) for intervals in data], interval_constructor, is_same_result=is_same_result, ) - result[key] = intervals - - return result + for key in keys + } diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 6052fd37..b168b3ce 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -1,4 +1,5 @@ import base64 +import copy import datetime import json import logging @@ -655,19 +656,13 @@ def get_start_end_from_interval_type( ) # Incrementally compute the interval type for each window - # interval. - window_types: dict[AnyInterval, IntervalType] = ( - dict() - ) # window interval -> interval type - - # todo: @moringenj - is this additional function really a good solution? - def reset_window_types() -> None: - window_types.clear() + # interval. 
Maps id of window interval -> interval type + window_types: dict[int, IntervalType] = dict() def update_window_type( window_interval: AnyInterval, data_interval: AnyInterval ) -> IntervalType: - window_type = window_types.get(window_interval.lower, None) + window_type = window_types.get(id(window_interval), None) if data_interval is None or data_interval.type is IntervalType.NEGATIVE: if window_type is not IntervalType.POSITIVE: @@ -681,7 +676,7 @@ def update_window_type( assert data_interval.type is IntervalType.NO_DATA if window_type is None: window_type = IntervalType.NO_DATA - window_types[window_interval.lower] = window_type + window_types[id(window_interval)] = window_type return window_type @@ -690,7 +685,7 @@ def update_window_type( # result interval window types based on the data # intervals. def is_same_interval( - left_intervals: tuple[AnyInterval], right_intervals: tuple[AnyInterval] + left_intervals: List[AnyInterval], right_intervals: List[AnyInterval] ) -> bool: left_window_interval, left_data_interval = left_intervals right_window_interval, right_data_interval = right_intervals @@ -723,17 +718,22 @@ def result_interval( ): return Interval(start, end, IntervalType.NOT_APPLICABLE) else: - window_type = window_types.get(window_interval.lower, None) + window_type = window_types.get(id(window_interval), None) if window_type is None: window_type = update_window_type(window_interval, data_interval) return Interval(start, end, window_type) - person_indicator_windows = {key: indicator_windows for key in data_p.keys()} + # Make separate copies of the intervals for each person so + # that the object identity of each interval is unique and + # can be used as a dictionary key. 
+ person_indicator_windows = { + key: [ copy.copy(window) for window in indicator_windows ] + for key in data_p.keys() + } result = process.find_rectangles( [person_indicator_windows, data_p], result_interval, is_same_result=is_same_interval, - reset=reset_window_types, ) return result diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 72f8c660..527201ff 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -2038,3 +2038,146 @@ def test_interval_ratio_on_database( result_tuples, expected[person.person_id] ): assert result_tuple == expected_tuple + +class TestIndicatorWindowsMulitplePatients(TestCriterionCombinationDatabase): + """ + This test ensures that the data TemporalCount operator works + independently between persons within a PersonIntervals data set. + + This is mostly a regression test since at one point the exact + problem of cross-talk between the data structures for different + persons caused the operator to return incorrect results. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + name="observation", start="2025-02-18 14:55:00+01:00", end="2025-02-22 12:00:00+01:00" + ) + + def patient_events(self, db_session, visit_occurrence): + person_id = visit_occurrence.person_id + events = [] + c1 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), + condition_end_datetime=pendulum.parse("2025-02-21 02:00:00+01:00"), + ) + events.append(c1) + if person_id == 1: + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-19 18:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-19 18:01:00+01:00"), + ) + events.append(e1) + db_session.add_all(events) + db_session.commit() + + @pytest.mark.parametrize( + "population,intervention,expected", + [ + ( + logic.And(c2), # population + temporal_logic_util.Day(criterion=delir_screening), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-19 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + pendulum.parse("2025-02-20 00:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + # If cross-talk between the data structures + # for different persons occurs, parts of the + # following interval may turn positive because + # of the results for the first person. 
+ ( + IntervalType.NEGATIVE, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + ), + ], + ) + def test_multiple_patients_on_database( + self, + person, + db_session, + population, + intervention, + base_criterion, + expected, + observation_window, + criteria, + ): + persons = person[:2] + vos = [] + for person in persons: + visit = create_visit( + person_id=person.person_id, + visit_start_datetime=observation_window.start + + datetime.timedelta(hours=3), + visit_end_datetime=observation_window.end + - datetime.timedelta(hours=6.5), + visit_concept_id=concepts.INTENSIVE_CARE, + ) + vos.append(visit) + self.patient_events(db_session, visit) + + db_session.add_all(vos) + db_session.commit() + + self.insert_expression( + db_session, population, intervention, base_criterion, observation_window + ) + + df = self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + for person in persons: + result = df.query(f"person_id=={person.person_id}") + result_tuples = list( + result[ [ "interval_type", "interval_start", "interval_end" ] ] + .fillna("nan") + .itertuples(index=False, name=None) + ) + + for result_tuple, expected_tuple in zip( + result_tuples, expected[person.person_id] + ): + assert result_tuple == expected_tuple From 08eeecf4372d9342d35fa86525c2c2ed9f7dfcc0 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Mon, 24 Mar 2025 18:11:42 +0100 Subject: [PATCH 27/43] refactor: remove obsolete tests The tested functions were removed in commit 297bbcb12dde557f5ac2667946af6e1536244d7c. 
--- .../task/process/test_rectangle.py | 1515 ----------------- 1 file changed, 1515 deletions(-) diff --git a/tests/execution_engine/task/process/test_rectangle.py b/tests/execution_engine/task/process/test_rectangle.py index 6e8526d4..cd94c7b2 100644 --- a/tests/execution_engine/task/process/test_rectangle.py +++ b/tests/execution_engine/task/process/test_rectangle.py @@ -1262,406 +1262,6 @@ def test_union_rect_timezones(self): ), "Failed: Mixed intervals not handled correctly" -class TestUnionRectWithCount(ProcessTest): - def test_union_rect_with_count_negative_duration(self): - intervals = [ - IntervalWithCount(lower=5, upper=3, type=T.POSITIVE, count=1), - ] - expected_intervals = [] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - def test_union_special_cases(self): - intervals = [ - IntervalWithCount(lower=180, upper=190, type=T.POSITIVE, count=1), - IntervalWithCount(lower=180, upper=180, type=T.POSITIVE, count=2), - IntervalWithCount(lower=190, upper=200, type=T.POSITIVE, count=1), - IntervalWithCount(lower=180, upper=189, type=T.POSITIVE, count=1), - IntervalWithCount(lower=190, upper=190, type=T.POSITIVE, count=2), - IntervalWithCount(lower=191, upper=200, type=T.POSITIVE, count=1), - ] - - expected_intervals = [ - IntervalWithCount(lower=180, upper=180, type=T.POSITIVE, count=4), - IntervalWithCount(lower=181, upper=189, type=T.POSITIVE, count=2), - IntervalWithCount(lower=190, upper=190, type=T.POSITIVE, count=4), - IntervalWithCount(lower=191, upper=200, type=T.POSITIVE, count=2), - ] - - result = self.process._impl.union_rects_with_count(intervals) - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1), - 
IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=2), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - @pytest.mark.parametrize("factor", [1, 2, 3]) - def test_union_rect_with_count_adjacent(self, factor): - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - @pytest.mark.parametrize("factor", [1, 2, 3]) - def test_union_rect_with_count_one(self, factor): - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - result = 
self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=2 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=4, 
type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * 
factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple intervals with the same start and end - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=4 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple intervals with the same start, different end - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=5, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=4 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - 
- assert result == expected_intervals - - # multiple intervals with the same end, different start, with other types inbetween - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=8, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=4 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=6, upper=6, type=T.NEGATIVE, count=2 * factor), - IntervalWithCount(lower=7, upper=8, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple intervals with the same end, different start - intervals = [ - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=6, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, 
count=4 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple starts and end at the same position inbetween - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=6 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=6, upper=6, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=7, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple starts and end at the same position inbetween, other types inbetween - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.POSITIVE, count=1 * 
factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=6 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=6, upper=6, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=7, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple starts and end at the same position inbetween, other types inbetween (random order) - intervals = [ - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - 
IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=6 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=6, upper=6, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=7, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - class TestMergeAdjacentIntervals(ProcessTest): def test_empty_list(self): """Test that an empty list returns an empty list.""" @@ -2172,1118 +1772,3 @@ def test_union_intervals_no_data_negative(self): result = self.intervals_to_df(result, ["person_id"]) pd.testing.assert_frame_equal(result, expected_df) - - -class TestCountIntervals(ProcessTest): - def test_count_intervals_empty_dataframe_list(self): - result = self.process.count_intervals([]) - assert ( - not result - ), "Failed: Empty list of DataFrames should return an empty DataFrame" - - def test_count_intervals_single_dataframe(self): - df = pd.DataFrame( - { - "person_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2020-01-01 12:00:00+00:00", "2020-01-02 12:00:00+00:00"] - ), - "interval_end": pd.to_datetime( - ["2020-01-02 18:00:00+00:00", "2020-01-03 18:00:00+00:00"] - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A", "A", "A"], - "interval_start": pd.to_datetime( - [ - "2020-01-01 12:00:00+00:00", - "2020-01-02 12:00:00+00:00", - "2020-01-02 18:00:01+00:00", - ] - ), - "interval_end": pd.to_datetime( - [ - "2020-01-02 11:59:59+00:00", - "2020-01-02 18:00:00+00:00", - "2020-01-03 18:00:00+00:00", - ] - ), - "interval_type": [T.POSITIVE, T.POSITIVE, T.POSITIVE], - "interval_count": [1, 2, 1], - } - ) - - result = 
self.process.count_intervals([df_to_person_interval_tuple(df)]) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_overlapping_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-05 18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-04 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-06 12:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A", "A", "A"], - "interval_start": pd.to_datetime( - [ - "2020-01-01 12:00:00+00:00", - "2020-01-04 12:00:00+00:00", - "2020-01-05 18:00:01+00:00", - ] - ), - "interval_end": pd.to_datetime( - [ - "2020-01-04 11:59:59+00:00", - "2020-01-05 18:00:00+00:00", - "2020-01-06 12:00:00+00:00", - ] - ), - "interval_type": [T.POSITIVE, T.POSITIVE, T.POSITIVE], - "interval_count": [1, 2, 1], - } - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_non_overlapping_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-03 13:30:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-03 13:31:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-04 18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = ( - pd.concat([df1, df2]).reset_index(drop=True).assign(interval_count=1) - ) - - result 
= self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_group_by_multiple_columns(self): - df1 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-02 12:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-02 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-03 12:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "group1": ["A", "A", "A"], - "group2": ["B", "B", "B"], - "interval_start": pd.to_datetime( - [ - "2020-01-01 12:00:00+00:00", - "2020-01-02 12:00:00+00:00", - "2020-01-02 12:00:01+00:00", - ] - ), - "interval_end": pd.to_datetime( - [ - "2020-01-02 11:59:59+00:00", - "2020-01-02 12:00:00+00:00", - "2020-01-03 12:00:00+00:00", - ] - ), - "interval_type": [T.POSITIVE, T.POSITIVE, T.POSITIVE], - "interval_count": [1, 2, 1], - } - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["group1", "group2"]), - df_to_person_interval_tuple(df2, by=["group1", "group2"]), - ] - ) - result = self.intervals_to_df(result, ["group1", "group2"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_adjacent_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-03 13:30:59+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-03 13:31:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-04 
18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-04 18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - "interval_count": [1], - } - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_with_timezone(self): - data1 = { - "person_id": [1, 1], - "concept_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2023-01-01T00:00:00Z", "2023-01-02T00:00:00Z"], utc=True - ), - "interval_end": pd.to_datetime( - ["2023-01-01T12:00:00Z", "2023-01-02T12:00:00Z"], utc=True - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - data2 = { - "person_id": [1, 1], - "concept_id": ["A", "B"], - "interval_start": pd.to_datetime( - ["2023-01-01T06:00:00Z", "2023-01-03T00:00:00Z"], utc=True - ), - "interval_end": pd.to_datetime( - ["2023-01-01T18:00:00Z", "2023-01-03T12:00:00Z"], utc=True - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2) - - expected_data = { - "person_id": [1, 1, 1, 1, 1], - "concept_id": ["A", "A", "A", "A", "B"], - "interval_start": pd.to_datetime( - [ - "2023-01-01T00:00:00Z", - "2023-01-01T06:00:00Z", - "2023-01-01T12:00:01Z", - "2023-01-02T00:00:00Z", - "2023-01-03T00:00:00Z", - ], - utc=True, - ), - "interval_end": pd.to_datetime( - [ - "2023-01-01T05:59:59Z", - "2023-01-01T12:00:00Z", - "2023-01-01T18:00:00Z", - "2023-01-02T12:00:00Z", - "2023-01-03T12:00:00Z", - ], - utc=True, - ), - "interval_type": [ - T.POSITIVE, - T.POSITIVE, - T.POSITIVE, - T.POSITIVE, - T.POSITIVE, - ], - "interval_count": [1, 2, 1, 1, 1], - } - expected_df = pd.DataFrame(expected_data) - - result_df = 
self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id", "concept_id"]), - df_to_person_interval_tuple(df2, by=["person_id", "concept_id"]), - ] - ) - result_df = self.intervals_to_df(result_df, ["person_id", "concept_id"]) - - pd.testing.assert_frame_equal(result_df, expected_df) - - def test_count_intervals_group_by_multiple_columns_complex_data2(self): - data1 = """ - group1 group2 interval_start interval_end interval_type - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - B 1 2024-01-01 17:00:00 2024-01-01 18:00:00 POSITIVE - B 1 2024-01-01 19:00:00 2024-01-01 20:00:00 POSITIVE - """ - df1 = df_from_str(data1) - - data2 = """ - group1 group2 interval_start interval_end interval_type - B 1 2024-01-01 18:00:00 2024-01-01 19:00:00 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - """ - df2 = df_from_str(data2) - - expected_data = """ - group1 group2 interval_start interval_end interval_type interval_count - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE 2 - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE 2 - B 1 2024-01-01 17:00:00 2024-01-01 17:59:59 POSITIVE 1 - B 1 2024-01-01 18:00:00 2024-01-01 18:00:00 POSITIVE 2 - B 1 2024-01-01 18:00:01 2024-01-01 18:59:59 POSITIVE 1 - B 1 2024-01-01 19:00:00 2024-01-01 19:00:00 POSITIVE 2 - B 1 2024-01-01 19:00:01 2024-01-01 20:00:00 POSITIVE 1 - """ - expected_df = ( - df_from_str(expected_data) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["group1", "group2"]), - df_to_person_interval_tuple(df2, by=["group1", "group2"]), - ] - ) - result = ( - self.intervals_to_df(result, ["group1", "group2"]) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - 
pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_group_by_multiple_columns_complex_data(self): - data1 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-01 12:00:00 2023-01-02 12:00:00 POSITIVE - A 1 2023-01-02 05:00:00 2023-01-03 05:00:00 POSITIVE - A 1 2023-01-03 06:00:00 2023-01-03 12:00:00 POSITIVE - A 1 2023-01-03 13:00:00 2023-01-04 12:00:00 POSITIVE - A 1 2023-01-04 18:00:00 2023-01-04 20:00:00 POSITIVE - A 1 2023-01-05 06:00:00 2023-01-05 23:59:00 POSITIVE - A 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - A 2 2023-02-01 12:59:01 2023-02-01 12:59:01 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - B 1 2024-01-01 17:00:00 2024-01-01 18:00:00 POSITIVE - B 1 2024-01-01 19:00:00 2024-01-01 20:00:00 POSITIVE - """ - df1 = df_from_str(data1) - - data2 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 12:00:01 2023-01-03 12:59:59 POSITIVE - B 2 2024-02-01 12:00:00 2024-02-01 13:00:00 POSITIVE - B 2 2024-02-01 13:00:00 2024-02-01 14:00:00 POSITIVE - B 2 2024-02-01 15:00:00 2024-02-01 16:00:00 POSITIVE - B 1 2024-01-01 18:00:00 2024-01-01 19:00:00 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - """ - df2 = df_from_str(data2) - - data3 = """ - group1 group2 interval_start interval_end interval_type - A 2 2023-02-01 06:00:00 2023-02-01 06:00:00 POSITIVE - A 2 2023-02-01 06:00:02 2023-02-01 12:58:58 POSITIVE - A 1 2023-01-04 22:00:00 2023-01-05 02:00:00 POSITIVE - B 3 2023-03-04 22:00:00 2023-03-05 02:00:00 POSITIVE - B 2 2024-02-01 15:00:00 2024-02-01 16:00:01 POSITIVE - """ - df3 = df_from_str(data3) - - expected_data = """ - group1 group2 interval_start interval_end interval_type interval_count - A 1 2023-01-01 12:00:00 2023-01-02 04:59:59 POSITIVE 1 - A 1 2023-01-02 05:00:00 2023-01-02 12:00:00 POSITIVE 2 - A 1 2023-01-02 
12:00:01 2023-01-03 05:00:00 POSITIVE 1 - A 1 2023-01-03 06:00:00 2023-01-04 12:00:00 POSITIVE 1 - A 1 2023-01-04 18:00:00 2023-01-04 20:00:00 POSITIVE 1 - A 1 2023-01-04 22:00:00 2023-01-05 02:00:00 POSITIVE 1 - A 1 2023-01-05 06:00:00 2023-01-05 23:59:00 POSITIVE 1 - A 2 2023-02-01 12:59:00 2023-02-01 12:59:01 POSITIVE 1 - A 2 2023-02-01 06:00:00 2023-02-01 06:00:00 POSITIVE 1 - A 2 2023-02-01 06:00:02 2023-02-01 12:58:58 POSITIVE 1 - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE 2 - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE 2 - B 1 2024-01-01 17:00:00 2024-01-01 17:59:59 POSITIVE 1 - B 1 2024-01-01 18:00:00 2024-01-01 18:00:00 POSITIVE 2 - B 1 2024-01-01 18:00:01 2024-01-01 18:59:59 POSITIVE 1 - B 1 2024-01-01 19:00:00 2024-01-01 19:00:00 POSITIVE 2 - B 1 2024-01-01 19:00:01 2024-01-01 20:00:00 POSITIVE 1 - B 2 2024-02-01 12:00:00 2024-02-01 12:59:59 POSITIVE 1 - B 2 2024-02-01 13:00:00 2024-02-01 13:00:00 POSITIVE 2 - B 2 2024-02-01 13:00:01 2024-02-01 14:00:00 POSITIVE 1 - B 2 2024-02-01 15:00:00 2024-02-01 16:00:00 POSITIVE 2 - B 2 2024-02-01 16:00:01 2024-02-01 16:00:01 POSITIVE 1 - B 3 2023-03-04 22:00:00 2023-03-05 02:00:00 POSITIVE 1 - """ - expected_df = ( - df_from_str(expected_data) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["group1", "group2"]), - df_to_person_interval_tuple(df2, by=["group1", "group2"]), - df_to_person_interval_tuple(df3, by=["group1", "group2"]), - ] - ) - result = ( - self.intervals_to_df(result, ["group1", "group2"]) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_edge_case(self): - data1 = """ - person_id interval_start interval_end interval_type - 30833 2023-03-02 13:00:01+00:00 2023-03-02 14:00:00+00:00 POSITIVE - 30833 
2023-03-02 14:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE - """ - df1 = df_from_str(data1) - - data2 = """ - person_id interval_start interval_end interval_type - 30833 2023-03-02 14:00:01+00:00 2023-03-02 15:00:00+00:00 POSITIVE - """ - df2 = df_from_str(data2) - - expected_data = """ - person_id interval_start interval_end interval_type interval_count - 30833 2023-03-02 13:00:01+00:00 2023-03-02 15:00:00+00:00 POSITIVE 1 - 30833 2023-03-02 15:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE 1 - """ - expected_df = df_from_str(expected_data) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - data1 = """ - person_id interval_start interval_end interval_type - 30833 2023-03-02 13:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE - """ - df1 = df_from_str(data1) - - data2 = """ - person_id interval_start interval_end interval_type - 30833 2023-03-02 13:00:01+00:00 2023-03-02 14:00:00+00:00 POSITIVE - 30833 2023-03-02 14:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE - """ - df2 = df_from_str(data2) - - expected_data = """ - person_id interval_start interval_end interval_type interval_count - 30833 2023-03-02 13:00:01+00:00 2023-03-02 14:00:00+00:00 POSITIVE 1 - 30833 2023-03-02 14:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE 2 - """ - expected_df = df_from_str(expected_data) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_edge_case_int(self): - intervals1 = { - 1: [ - Interval(lower=1, upper=4, type=T.NEGATIVE), - ] - } - - intervals2 = { - 1: [ - Interval(lower=1, upper=2, type=T.POSITIVE), - 
Interval(lower=3, upper=4, type=T.NEGATIVE), - ] - } - - expected_intervals = { - 1: [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=2), - ] - } - - result = self.process.count_intervals([intervals1, intervals2]) - - assert len(result) == len(expected_intervals) - assert result == expected_intervals - - def test_count_intervals_no_data_negative_int(self): - intervals1 = [ - Interval(lower=1, upper=2, type=T.NO_DATA), - Interval(lower=3, upper=4, type=T.POSITIVE), - Interval(lower=5, upper=6, type=T.NO_DATA), - ] - - intervals2 = [ - Interval(lower=1, upper=2, type=T.NO_DATA), - Interval(lower=3, upper=4, type=T.NEGATIVE), - Interval(lower=5, upper=6, type=T.NO_DATA), - ] - - expected_intervals = { - 1: [ - IntervalWithCount(lower=1, upper=2, type=T.NO_DATA, count=2), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1), - IntervalWithCount(lower=5, upper=6, type=T.NO_DATA, count=2), - ] - } - - result = self.process.count_intervals([{1: intervals1}, {1: intervals2}]) - - assert len(result) == len(expected_intervals) - assert result == expected_intervals - - def test_count_intervals_no_data_negative(self): - data1 = """ - person_id interval_start interval_end interval_type - 30748 2023-02-26 07:00:00+00:00 2023-03-02 12:59:59+00:00 NO_DATA - 30748 2023-03-02 13:00:00+00:00 2023-03-02 14:00:00+00:00 POSITIVE - 30748 2023-03-02 14:00:01+00:00 2023-04-03 23:00:00+00:00 NO_DATA - """ - df1 = df_from_str(data1) - - data2 = """ - person_id interval_start interval_end interval_type - 30748 2023-02-26 07:00:00+00:00 2023-03-02 12:59:59+00:00 NO_DATA - 30748 2023-03-02 13:00:00+00:00 2023-03-02 14:00:00+00:00 NEGATIVE - 30748 2023-03-02 14:00:01+00:00 2023-04-03 23:00:00+00:00 NO_DATA - """ - - df2 = df_from_str(data2) - - expected_data = """ - person_id interval_start interval_end interval_type interval_count - 30748 2023-02-26 07:00:00+00:00 2023-03-02 12:59:59+00:00 NO_DATA 2 - 
30748 2023-03-02 13:00:00+00:00 2023-03-02 14:00:00+00:00 POSITIVE 1 - 30748 2023-03-02 14:00:01+00:00 2023-04-03 23:00:00+00:00 NO_DATA 2 - """ - expected_df = df_from_str(expected_data) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - -class TestIntersectIntervals(ProcessTest): - def test_intersect_intervals_empty_dataframe_list(self): - result = self.process.intersect_intervals([]) - assert ( - not result - ), "Failed: Empty list of DataFrames should return an empty DataFrame" - - def test_intersect_intervals_single_dataframe(self): - df = pd.DataFrame( - { - "person_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2020-01-01 08:00:00+00:00", "2020-01-02 09:00:00+00:00"] - ), - "interval_end": pd.to_datetime( - ["2020-01-01 10:00:00+00:00", "2020-01-02 11:00:00+00:00"] - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - ) - by = ["person_id"] - result = self.process.intersect_intervals( - [df_to_person_interval_tuple(df, by=by)] - ) - result = self.intervals_to_df(result, by=by) - expected_df = df.copy() - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_intersecting_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 08:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 10:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 11:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 
10:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - - by = ["person_id"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_no_intersecting_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 08:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 10:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 11:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - columns=["person_id", "interval_start", "interval_end", "interval_type"] - ) - - by = ["person_id"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_group_by_multiple_columns(self): - df1 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 08:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 08:30:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:30:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 08:30:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - 
- by = ["group1", "group2"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_interval_with_timezone(self): - # Prepare test data with datetime64[ns, UTC] dtype - data1 = { - "person_id": [1, 1], - "concept_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2023-01-01T00:00:00Z", "2023-01-02T00:00:00Z"], utc=True - ), - "interval_end": pd.to_datetime( - ["2023-01-01T12:00:00Z", "2023-01-02T12:00:00Z"], utc=True - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - data2 = { - "person_id": [1, 1], - "concept_id": ["A", "B"], - "interval_start": pd.to_datetime( - ["2023-01-01T06:00:00Z", "2023-01-03T00:00:00Z"], utc=True - ), - "interval_end": pd.to_datetime( - ["2023-01-01T18:00:00Z", "2023-01-03T12:00:00Z"], utc=True - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2) - - # Call the function - by = ["person_id", "concept_id"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - # Define expected output - expected_data = { - "person_id": [1], - "concept_id": ["A"], - "interval_start": pd.to_datetime(["2023-01-01T06:00:00Z"], utc=True), - "interval_end": pd.to_datetime(["2023-01-01T12:00:00Z"], utc=True), - "interval_type": [T.POSITIVE], - } - expected_df = pd.DataFrame(expected_data) - - # Assert - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_group_by_multiple_columns_complex_data(self): - data1 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-01 12:00:00 2023-01-02 12:00:00 POSITIVE - A 1 2023-01-02 05:00:00 2023-01-03 05:00:00 POSITIVE - A 1 2023-01-03 06:00:00 2023-01-03 12:30:00 
POSITIVE - A 1 2023-01-03 13:00:00 2023-01-04 12:00:00 POSITIVE - A 1 2023-01-04 18:00:00 2023-01-04 20:00:00 POSITIVE - A 1 2023-01-05 06:00:00 2023-01-05 23:59:00 POSITIVE - A 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - A 2 2023-02-01 12:59:01 2023-02-01 12:59:01 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - B 2 2023-02-01 12:00:00 2023-02-01 12:59:00 POSITIVE - B 3 2024-03-01 16:00:00 2024-03-01 20:00:00 POSITIVE - B 4 2024-04-01 12:00:00 2024-04-02 12:00:00 POSITIVE - """ - df1 = df_from_str(data1) - - data2 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 12:00:01 2023-01-03 12:59:59 POSITIVE - A 2 2023-02-01 06:00:00 2023-02-02 12:00:00 POSITIVE - B 2 2024-02-01 13:00:00 2024-02-01 14:00:00 POSITIVE - B 2 2024-02-01 15:00:00 2024-02-01 16:00:00 POSITIVE - B 1 2024-01-01 18:00:00 2024-01-01 19:00:00 POSITIVE - B 1 2024-01-01 13:10:00 2024-01-01 13:20:00 POSITIVE - B 1 2024-01-01 13:30:00 2024-01-01 13:40:00 POSITIVE - B 1 2024-01-01 13:50:00 2024-01-01 14:00:00 POSITIVE - B 2 2023-02-01 12:59:00 2023-02-01 14:00:00 POSITIVE - B 3 2024-03-01 16:00:00 2024-03-01 18:00:00 POSITIVE - B 4 2024-04-01 12:00:00 2024-04-02 12:00:00 POSITIVE - B 5 2024-05-01 16:00:00 2024-05-01 18:00:00 POSITIVE - """ - df2 = df_from_str(data2) - - data3 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 11:00:01 2023-01-03 12:59:59 POSITIVE - A 2 2023-02-01 06:00:02 2023-02-01 12:59:00 POSITIVE - A 1 2023-01-04 22:00:00 2023-01-05 02:00:00 POSITIVE - B 3 2023-03-04 22:00:00 2023-03-05 02:00:00 POSITIVE - B 1 2024-01-01 00:00:00 2024-01-01 13:25:00 POSITIVE - B 1 2024-01-01 13:45:00 2024-01-01 13:55:00 POSITIVE - B 1 2024-01-01 13:51:00 2024-01-01 13:53:00 POSITIVE - B 2 2023-02-01 12:00:00 2023-02-01 14:00:00 POSITIVE - B 3 2024-03-01 18:00:00 2024-03-01 20:00:00 POSITIVE - B 5 2024-05-01 18:00:00 2024-05-01 20:00:00 POSITIVE - """ - df3 = 
df_from_str(data3) - - expected_data = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 12:00:01 2023-01-03 12:30:00 POSITIVE - A 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - B 1 2024-01-01 13:10:00 2024-01-01 13:20:00 POSITIVE - B 1 2024-01-01 13:50:00 2024-01-01 13:55:00 POSITIVE - B 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - B 3 2024-03-01 18:00:00 2024-03-01 18:00:00 POSITIVE - """ - expected_df = df_from_str(expected_data) - - by = ["group1", "group2"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - df_to_person_interval_tuple(df3, by=by), - ] - ) - result = ( - self.intervals_to_df(result, by=by) - .sort_values(by=by) - .reset_index(drop=True) - ) - - pd.testing.assert_frame_equal(result, expected_df) - - -class TestIntervalFilling(ProcessTest): - def assert_equal(self, data, expected, observation_window=None): - def to_df(data): - df = pd.DataFrame( - data, - columns=[ - "person_id", - "interval_start", - "interval_end", - "interval_type", - ], - ) - df["interval_start"] = pd.to_datetime(df["interval_start"]) - df["interval_end"] = pd.to_datetime(df["interval_end"], utc=True) - - return df - - result = self.process.forward_fill( - df_to_person_interval_tuple(to_df(data), by=["person_id"]), - observation_window, - ) - df_result = self.intervals_to_df(result, ["person_id"]) - df_expected = to_df(expected) - - pd.testing.assert_frame_equal(df_result, df_expected, check_dtype=False) - - def test_single_row(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - ] - - self.assert_equal(data, expected) - - def test_empty(self): - data = [] - expected = [] - - self.assert_equal(data, expected) - - def test_single_type_per_person(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 
08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - - self.assert_equal(data, expected) - - def test_last_row_different(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "NEGATIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "NEGATIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 09:09:59+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "NEGATIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 13:59:59+00:00", "POSITIVE"), - (2, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "NEGATIVE"), - ] - - self.assert_equal(data, expected) - - def test_forward_fill(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 14:00:00+00:00", "2023-03-01 
15:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 10:59:59+00:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:59:59+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - self.assert_equal(data, expected) - - data = [ - (1, "2021-01-01 08:00:00+00:00", "2021-01-01 09:00:00+00:00", "POSITIVE"), - (1, "2021-01-01 09:00:00+00:00", "2021-01-01 10:00:00+00:00", "POSITIVE"), - (2, "2021-01-02 10:00:00+00:00", "2021-01-02 10:15:00+00:00", "NEGATIVE"), - (2, "2021-01-02 10:30:00+00:00", "2021-01-02 11:00:00+00:00", "POSITIVE"), - (2, "2021-01-02 11:30:00+00:00", "2021-01-02 12:00:00+00:00", "NEGATIVE"), - (3, "2021-01-03 12:00:00+00:00", "2021-01-03 12:30:00+00:00", "POSITIVE"), - (3, "2021-01-03 12:45:00+00:00", "2021-01-03 13:00:00+00:00", "NEGATIVE"), - ] - - expected = [ - (1, "2021-01-01 08:00:00+00:00", "2021-01-01 10:00:00+00:00", "POSITIVE"), - (2, "2021-01-02 10:00:00+00:00", "2021-01-02 10:29:59+00:00", "NEGATIVE"), - (2, "2021-01-02 10:30:00+00:00", "2021-01-02 11:29:59+00:00", "POSITIVE"), - (2, "2021-01-02 11:30:00+00:00", "2021-01-02 12:00:00+00:00", "NEGATIVE"), - (3, "2021-01-03 12:00:00+00:00", "2021-01-03 12:44:59+00:00", "POSITIVE"), - (3, "2021-01-03 12:45:00+00:00", "2021-01-03 13:00:00+00:00", "NEGATIVE"), - ] - - self.assert_equal(data, expected) - - def test_forward_fill_with_observation_window(self): - observation_window = TimeRange.from_tuple( - ("2023-03-01 08:00:00+00:00", "2023-03-15 15:00:00+00:00") - ) - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 14:00:00+00:00", 
"2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 10:59:59+00:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:59:59+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-15 15:00:00+00:00", "POSITIVE"), - ] - self.assert_equal(data, expected, observation_window) - - # with timezone - observation_window = TimeRange.from_tuple( - ("2023-03-01 08:00:00+01:00", "2023-04-15 15:00:00+02:00") - ) - data = [ - (1, "2023-03-01 08:00:00+01:00", "2023-03-01 08:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+01:00", "2023-03-01 09:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+01:00", "2023-03-01 10:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+01:00", "2023-03-01 11:00:00+01:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+01:00", "2023-03-01 13:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 14:00:00+01:00", "2023-03-01 15:00:00+01:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+01:00", "2023-03-01 10:59:59+01:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+01:00", "2023-03-01 11:59:59+01:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+01:00", "2023-04-15 15:00:00+02:00", "POSITIVE"), - ] - self.assert_equal(data, expected, observation_window) - - -class TestCreateTimeIntervals(ProcessTest): - # Helper to create timezone-aware datetime objects using pendulum - def tz_aware_datetime(self, date_str, timezone): - return pendulum.parse(date_str, tz=timezone) - - @pytest.mark.parametrize("timezone", ["America/New_York", "Europe/Berlin", "UTC"]) - def test_naive_datetimes(self, timezone): - start_datetime = pendulum.parse("2023-07-01 12:00:00").naive() - end_datetime = pendulum.parse("2023-07-03 12:00:00").naive() - start_time = time(9, 0) - end_time = time(17, 0) - intervals = self.process.create_time_intervals( - start_datetime, - end_datetime, - start_time, - end_time, - interval_type=T.POSITIVE, - timezone=timezone, - ) - # Ignore intervals of type 
NOT_APPLICABLE at the boundary of the period - intervals = [i for i in intervals if i.type == T.POSITIVE] - assert len(intervals) == 3 # Expecting intervals for July 1st and 2nd - assert ( - intervals[0].lower - == pendulum.parse("2023-07-01 12:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[0].upper - == pendulum.parse("2023-07-01 17:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].lower - == pendulum.parse("2023-07-02 09:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].upper - == pendulum.parse("2023-07-02 17:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[2].lower - == pendulum.parse("2023-07-03 09:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[2].upper - == pendulum.parse("2023-07-03 12:00:00", tz=timezone).timestamp() - ) - - def test_aware_datetimes(self): - tz = "America/New_York" - start_datetime = self.tz_aware_datetime("2023-07-01 12:00:00", tz) - end_datetime = self.tz_aware_datetime("2023-07-03 12:00:00", tz) - start_time = time(22, 0) - end_time = time(6, 0) - intervals = self.process.create_time_intervals( - start_datetime, - end_datetime, - start_time, - end_time, - interval_type=T.POSITIVE, - timezone=tz, - ) - # Ignore intervals of type NOT_APPLICABLE at the boundary of the period - intervals = [ i for i in intervals if i.type == T.POSITIVE ] - assert ( - len(intervals) == 2 - ) # Expecting intervals for the nights of July 1st and 2nd - assert ( - intervals[0].lower - == self.tz_aware_datetime("2023-07-01 22:00:00", tz).timestamp() - ) - assert ( - intervals[0].upper - == self.tz_aware_datetime("2023-07-02 06:00:00", tz).timestamp() - ) - assert ( - intervals[1].lower - == self.tz_aware_datetime("2023-07-02 22:00:00", tz).timestamp() - ) - assert ( - intervals[1].upper - == self.tz_aware_datetime("2023-07-03 06:00:00", tz).timestamp() - ) - - @pytest.mark.parametrize("timezone", ["America/New_York", "Europe/Berlin", "UTC"]) - def test_spanning_midnight_naive(self, timezone): - 
start_datetime = pendulum.parse("2023-07-01 12:00:00", tz=timezone) - end_datetime = pendulum.parse("2023-07-03 04:00:00", tz=timezone) - start_time = time(22, 0) - end_time = time(6, 0) - intervals = self.process.create_time_intervals( - start_datetime, - end_datetime, - start_time, - end_time, - interval_type=T.POSITIVE, - timezone=timezone, - ) - # Ignore intervals of type NOT_APPLICABLE at the boundary of the period - intervals = [i for i in intervals if i.type == T.POSITIVE] - assert ( - len(intervals) == 2 - ) # Expecting intervals for the nights of July 1st and 2nd - assert ( - intervals[0].lower - == pendulum.parse("2023-07-01 22:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[0].upper - == pendulum.parse("2023-07-02 06:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].lower - == pendulum.parse("2023-07-02 22:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].upper - == pendulum.parse("2023-07-03 04:00:00", tz=timezone).timestamp() - ) From e8f1ab2d466226b71323ecbe7bfce386a476e002 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Mon, 24 Mar 2025 18:11:49 +0100 Subject: [PATCH 28/43] refactor: improvements for logic.Count implementation in Task * Cleanup and comments * Avoid some unnecessary work * Add two test cases --- execution_engine/task/task.py | 32 +- tests/_fixtures/concept.py | 20 +- .../combination/test_temporal_combination.py | 293 +++++++++++++++++- 3 files changed, 304 insertions(+), 41 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index b168b3ce..b74bf544 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -379,39 +379,31 @@ def interval_union_with_count( result = process.union_intervals(data) elif isinstance(self.expr, logic.Count): - # result = process.count_intervals(data) - # result = process.filter_count_intervals( - # result, - # min_count=self.expr.count_min, - # max_count=self.expr.count_max, - # ) count_min = 
self.expr.count_min count_max = self.expr.count_max if count_min is None and count_max is None: raise ValueError("count_min and count_max cannot both be None") def interval_counts( - start: int, end: int, intervals: List[AnyInterval] + start: int, end: int, intervals: List[GeneralizedInterval] ) -> GeneralizedInterval: + # Count the different interval types. None represents + # implicit negative intervals and is counted as such. counts = Counter( (interval.type if interval else IntervalType.NEGATIVE) for interval in intervals ) + # The interval type with the highest "union priority" + # determines the result. positive_count = counts[IntervalType.POSITIVE] - negative_count = counts[IntervalType.NEGATIVE] - not_applicable_count = counts[IntervalType.NOT_APPLICABLE] - no_data_count = counts[IntervalType.NO_DATA] - if positive_count > 0: if count_min is None: - interval_type = ( - IntervalType.POSITIVE - if (positive_count <= count_max) # type:ignore [operator] - else IntervalType.NEGATIVE - ) - return Interval(start, end, interval_type) + if positive_count <= count_max: + return Interval(start, end, IntervalType.POSITIVE) + else: + return None # Implicit negative interval else: min_good = count_min <= positive_count max_good = (count_max is None) or (positive_count <= count_max) @@ -422,11 +414,11 @@ def interval_counts( ) ratio = positive_count / count_min return IntervalWithCount(start, end, interval_type, ratio) - if no_data_count > 0: + if counts[IntervalType.NO_DATA] > 0: return IntervalWithCount(start, end, IntervalType.NO_DATA, 0) - if not_applicable_count > 0: + if counts[IntervalType.NOT_APPLICABLE] > 0: return IntervalWithCount(start, end, IntervalType.NOT_APPLICABLE, 0) - if negative_count > 0: + if counts[IntervalType.NEGATIVE] > 0: return IntervalWithCount(start, end, IntervalType.NEGATIVE, 0) raise ValueError("No intervals of any kind found") diff --git a/tests/_fixtures/concept.py b/tests/_fixtures/concept.py index fd14e338..b5ef855d 100644 --- 
a/tests/_fixtures/concept.py +++ b/tests/_fixtures/concept.py @@ -51,14 +51,14 @@ def unit_concept(): invalid_reason=None, ) -concept_delir_screening = Concept( # TODO(jmoringe): copied from above (concept_artificial_respiration) - concept_id=4230167, +concept_delir_screening = Concept( + concept_id=4196006, # TODO(jmoringe): made-up id and code concept_name="Delir Screening", - domain_id="Procedure", + domain_id="Measurement", vocabulary_id="SNOMED", concept_class_id="Procedure", standard_concept="S", - concept_code="40617009", + concept_code="431182000", invalid_reason=None, ) @@ -283,6 +283,18 @@ def unit_concept(): invalid_reason=None, ) +concept_body_height: Concept = Concept( + concept_id=3036277, + concept_name="Body height", + domain_id="Measurement", + vocabulary_id="LOINC", + concept_class_id="Clinical Observation", + standard_concept="S", + concept_code="8302-2", + invalid_reason=None, +) + + """ The following list of concepts are heparin drugs and all of them directly map to heparin as ingredient (via ancestor, not relationship !) 
diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 527201ff..3944993d 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -29,6 +29,9 @@ concept_surgical_procedure, concept_unit_kg, concept_unit_mg, + concept_body_height, + concept_unit_cm, + concept_tidal_volume, ) from tests._testdata import concepts from tests.execution_engine.omop.criterion.test_criterion import TestCriterion @@ -200,7 +203,7 @@ def test_expr_contains_criteria(self, mock_criteria): concept=concept_covid19, ) -c3 = ProcedureOccurrence( +artificial_respiration = ProcedureOccurrence( concept=concept_artificial_respiration, ) @@ -208,6 +211,10 @@ def test_expr_contains_criteria(self, mock_criteria): concept=concept_surgical_procedure, ) +delir_screening = ProcedureOccurrence( + concept=concept_delir_screening, +) + bodyweight_measurement_without_forward_fill = Measurement( concept=concept_body_weight, value=ValueNumber.parse("<=110", unit=concept_unit_kg), @@ -221,8 +228,30 @@ def test_expr_contains_criteria(self, mock_criteria): static=False, ) -delir_screening = ProcedureOccurrence( - concept=concept_delir_screening, +body_height_measurement_without_forward_fill = Measurement( + concept=concept_body_height, + value=ValueNumber.parse("<=110", unit=concept_unit_cm), + static=False, + forward_fill=False, +) + +body_height_measurement_with_forward_fill = Measurement( + concept=concept_body_height, + value=ValueNumber.parse("<=110", unit=concept_unit_cm), + static=False, +) + +tidal_volume_measurement_without_forward_fill = Measurement( + concept=concept_tidal_volume, + value=ValueNumber.parse("<=110", unit=concept_unit_cm), # TODO(jmoringe): copied; does not make sense + static=False, + forward_fill=False +) + 
+tidal_volume_measurement_with_forward_fill = Measurement( + concept=concept_tidal_volume, + value=ValueNumber.parse("<=110", unit=concept_unit_cm), # TODO(jmoringe): copied; does not make sense + static=False, ) @@ -242,9 +271,14 @@ def criteria(self, db_session): criteria = [ c1, c2, - c3, + artificial_respiration, + c4, bodyweight_measurement_without_forward_fill, bodyweight_measurement_with_forward_fill, + body_height_measurement_without_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_without_forward_fill, + tidal_volume_measurement_with_forward_fill, delir_screening, ] for i, c in enumerate(criteria): @@ -384,7 +418,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.Day( - c3, + artificial_respiration, ), { 1: { @@ -461,7 +495,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.Presence( - c3, + artificial_respiration, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), ), @@ -469,7 +503,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.Presence( - c3, + artificial_respiration, start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), ), @@ -523,7 +557,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.MorningShift( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -570,7 +604,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.AfternoonShift( - c3, + artificial_respiration, ), { 1: { @@ -622,7 +656,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.NightShift( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -665,7 +699,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.NightShiftBeforeMidnight( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -704,7 +738,7 @@ def patient_events(self, db_session, person_visit): ), ( 
temporal_logic_util.NightShiftAfterMidnight( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -817,7 +851,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.Day( - c3, + artificial_respiration, ), { 1: { @@ -909,7 +943,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.Presence( - c3, + artificial_respiration, start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), ), @@ -1011,7 +1045,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.MorningShift( - c3, + artificial_respiration, ), { 1: { @@ -1115,7 +1149,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.AfternoonShift( - c3, + artificial_respiration, ), { 1: { @@ -1219,7 +1253,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.NightShift( - c3, + artificial_respiration, ), { 1: { @@ -2181,3 +2215,228 @@ def test_multiple_patients_on_database( result_tuples, expected[person.person_id] ): assert result_tuple == expected_tuple + + +class TestCountOnIndicatorWindows(TestCriterionCombinationDatabase): + """ + This test checks the behavior of the logical Count operator for + different thresholds and different kinds of inputs. Of particular + interest is the computed count attribute of the result intervals + and the behavior for edge cases regarding the count thresholds. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + name="observation", start="2025-02-18 14:55:00+01:00", end="2025-02-22 12:00:00+01:00" + ) + + def patient_events(self, db_session, visit_occurrence): + person_id = visit_occurrence.person_id + events = [create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), + condition_end_datetime=pendulum.parse("2025-02-21 02:00:00+01:00"), + )] + if person_id == 1: + events.append(create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_body_weight.concept_id, + measurement_datetime=pendulum.parse("2025-02-19 18:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_kg.concept_id, + )) + events.append(create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_body_height.concept_id, + measurement_datetime=pendulum.parse("2025-02-20 07:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_cm.concept_id, + )) + events.append(create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_tidal_volume.concept_id, + measurement_datetime=pendulum.parse("2025-02-20 18:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_cm.concept_id, + )) + db_session.add_all(events) + db_session.commit() + + @pytest.mark.parametrize( + "population,intervention,expected", + [ + ( + logic.And(c2), # population + logic.MinCount( + temporal_logic_util.AnyTime(bodyweight_measurement_without_forward_fill), + temporal_logic_util.AnyTime(body_height_measurement_without_forward_fill), + temporal_logic_util.AnyTime(tidal_volume_measurement_without_forward_fill), + threshold=1, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 3, + pendulum.parse("2025-02-19 08:00:00+01:00"), + 
pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NO_DATA, + 0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + ), + ( + logic.And(c2), # population + logic.MinCount( + bodyweight_measurement_with_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_with_forward_fill, + threshold=2, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NO_DATA, + 0.0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-19 17:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0.5, + pendulum.parse("2025-02-19 18:00:00+01:00"), + pendulum.parse("2025-02-20 06:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1, + pendulum.parse("2025-02-20 07:00:00+01:00"), + pendulum.parse("2025-02-20 17:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1.5, + pendulum.parse("2025-02-20 18:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NO_DATA, + 0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + 
IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + ), + ], + ) + def test_combination_on_database( + self, + person, + db_session, + population, + intervention, + base_criterion, + expected, + observation_window, + criteria, + ): + persons = person[:2] + vos = [] + for person in persons: + visit = create_visit( + person_id=person.person_id, + visit_start_datetime=observation_window.start + + datetime.timedelta(hours=3), + visit_end_datetime=observation_window.end + - datetime.timedelta(hours=6.5), + visit_concept_id=concepts.INTENSIVE_CARE, + ) + vos.append(visit) + self.patient_events(db_session, visit) + + db_session.add_all(vos) + db_session.commit() + + self.insert_expression( + db_session, population, intervention, base_criterion, observation_window + ) + + df = self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + for person in persons: + result = df.query(f"person_id=={person.person_id}") + result_tuples = list( + result[ [ "interval_type", "interval_ratio", "interval_start", "interval_end" ] ] + .fillna("nan") + .itertuples(index=False, name=None) + ) + + for result_tuple, expected_tuple in zip( + result_tuples, expected[person.person_id] + ): + assert result_tuple == expected_tuple From 6f69a39c6300fa5b7443de53c3f84f38360c845e Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 25 Mar 2025 15:33:46 +0100 Subject: [PATCH 29/43] refactor: rewrite implementation of logic.Or in Task * Use explicit priorities for interval types --- execution_engine/task/task.py | 50 +++++++++++++++++------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index b74bf544..52792965 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -349,32 +349,32 @@ def 
handle_binary_logical_operator( if isinstance(self.expr, (logic.And, logic.NonSimplifiableAnd)): result = process.intersect_intervals(data) elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): - # result = process.union_intervals(data) if self.receives_only_count_inputs(): - with IntervalType.custom_union_priority_order( - IntervalType.union_priority() - ): - with IntervalType.union_order(): - - def interval_union_with_count( - start: int, end: int, intervals: List[IntervalWithCount] - ) -> IntervalWithCount: - # todo: @moringenj can we improve the "interval is not None" checks here? - interval_type = max( - interval.type - for interval in intervals - if interval is not None - ) - count = sum( - interval.count - for interval in intervals - if interval is not None - and interval.type == interval_type - and interval.count is not None - ) - return IntervalWithCount(start, end, interval_type, count) - - return process.find_rectangles(data, interval_union_with_count) + def interval_union_with_count( + start: int, end: int, intervals: List[GeneralizedInterval] + ) -> IntervalWithCount: + result_type = None + result_count = 0 + for interval in intervals: + if interval is None: + interval_type, interval_count = IntervalType.NEGATIVE, 0 + else: + interval_type, interval_count = interval.type, interval.count + if ((interval_type is IntervalType.POSITIVE + and not result_type is IntervalType.POSITIVE) + or (interval_type is IntervalType.NO_DATA + and not result_type is IntervalType.POSITIVE + and not result_type is IntervalType.NO_DATA) + or (interval_type is IntervalType.NEGATIVE + and (result_type is IntervalType.NOT_APPLICABLE + or result_type is None)) + or (interval_type is IntervalType.NOT_APPLICABLE + and result_type is None)): + result_type = interval_type + result_count = 0 + result_count += interval_count + return IntervalWithCount(start, end, result_type, result_count) + result = process.find_rectangles(data, interval_union_with_count) else: result = 
process.union_intervals(data) From a962449b3cc4d405874e3fd25a33ce19f013c346 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Tue, 25 Mar 2025 17:07:06 +0100 Subject: [PATCH 30/43] fix: improve logic.Count for count_min = 0 Do not require a positive number of intervals of type POSITIVE if count_min is 0. Add a test case. --- execution_engine/task/task.py | 11 ++-- .../combination/test_temporal_combination.py | 57 +++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 52792965..e5c1ce7f 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -383,6 +383,8 @@ def interval_union_with_count( count_max = self.expr.count_max if count_min is None and count_max is None: raise ValueError("count_min and count_max cannot both be None") + if count_min is None: + count_min = 0 def interval_counts( start: int, end: int, intervals: List[GeneralizedInterval] @@ -395,11 +397,12 @@ def interval_counts( for interval in intervals ) - # The interval type with the highest "union priority" - # determines the result. + # Either the count constraints or the interval type + # with the highest "union priority" determines the + # result. 
positive_count = counts[IntervalType.POSITIVE] - if positive_count > 0: - if count_min is None: + if positive_count > 0 or count_min == 0: + if count_min == 0: if positive_count <= count_max: return Interval(start, end, IntervalType.POSITIVE) else: diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 3944993d..bd248a1d 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -2387,6 +2387,63 @@ def patient_events(self, db_session, visit_occurrence): ], }, ), + ( + logic.And(c2), # population + logic.MaxCount( + bodyweight_measurement_with_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_with_forward_fill, + threshold=2, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 'nan', + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-20 17:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 'nan', + pendulum.parse("2025-02-20 18:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 'nan', + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + 'nan', + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + ), ], ) def test_combination_on_database( From 
06597470a6710416f9cc5150f6d83229a1213c48 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Tue, 25 Mar 2025 20:09:16 +0100 Subject: [PATCH 31/43] feat: allow custom counting functions in logic.Or --- execution_engine/task/task.py | 87 ++++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index e5c1ce7f..3bc9fee6 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -5,7 +5,7 @@ import logging from collections import Counter from enum import Enum, auto -from typing import Callable, List, Type +from typing import Callable, List, Type, cast from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError @@ -36,6 +36,41 @@ ) +def default_interval_union_with_count( + start: int, end: int, intervals: List[GeneralizedInterval] +) -> IntervalWithCount: + """ + Default interval counting function to be used in logic.Or + """ + result_type = None + result_count = 0 + for interval in intervals: + if interval is None: + interval_type, interval_count = IntervalType.NEGATIVE, 0 + else: + interval_type, interval_count = interval.type, cast(int, interval.count) + if ( + ( + interval_type is IntervalType.POSITIVE + and result_type is not IntervalType.POSITIVE + ) + or ( + interval_type is IntervalType.NO_DATA + and result_type is not IntervalType.POSITIVE + and result_type is not IntervalType.NO_DATA + ) + or ( + interval_type is IntervalType.NEGATIVE + and (result_type is IntervalType.NOT_APPLICABLE or result_type is None) + ) + or (interval_type is IntervalType.NOT_APPLICABLE and result_type is None) + ): + result_type = interval_type + result_count = 0 + result_count += interval_count + return IntervalWithCount(start, end, result_type, result_count) + + def get_engine() -> OMOPSQLClient: """ Returns a OMOPSQLClient object. 
@@ -160,12 +195,8 @@ def receives_only_count_inputs(self) -> bool: # all arguments are logic.BinaryNonCommutativeOperator, and all of their "right" children are count types if all( isinstance(parent, logic.BinaryNonCommutativeOperator) - and isinstance(parent.right, logic.Expr) + and isinstance(parent.right, COUNT_TYPES) for parent in self.expr.args - ) and all( - isinstance(grandparent, COUNT_TYPES) - for parent in self.expr.args - for grandparent in parent.right.args ): return True @@ -350,31 +381,14 @@ def handle_binary_logical_operator( result = process.intersect_intervals(data) elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): if self.receives_only_count_inputs(): - def interval_union_with_count( - start: int, end: int, intervals: List[GeneralizedInterval] - ) -> IntervalWithCount: - result_type = None - result_count = 0 - for interval in intervals: - if interval is None: - interval_type, interval_count = IntervalType.NEGATIVE, 0 - else: - interval_type, interval_count = interval.type, interval.count - if ((interval_type is IntervalType.POSITIVE - and not result_type is IntervalType.POSITIVE) - or (interval_type is IntervalType.NO_DATA - and not result_type is IntervalType.POSITIVE - and not result_type is IntervalType.NO_DATA) - or (interval_type is IntervalType.NEGATIVE - and (result_type is IntervalType.NOT_APPLICABLE - or result_type is None)) - or (interval_type is IntervalType.NOT_APPLICABLE - and result_type is None)): - result_type = interval_type - result_count = 0 - result_count += interval_count - return IntervalWithCount(start, end, result_type, result_count) - result = process.find_rectangles(data, interval_union_with_count) + # we check if there are custom data preparatory and interval counting functions and use these + prepare_func = getattr(self.expr, "prepare_data", None) + if prepare_func: + data = prepare_func(self, data) + func = getattr( + self.expr, "count_intervals", default_interval_union_with_count + ) + result = 
process.find_rectangles(data, func) else: result = process.union_intervals(data) @@ -403,10 +417,10 @@ def interval_counts( positive_count = counts[IntervalType.POSITIVE] if positive_count > 0 or count_min == 0: if count_min == 0: - if positive_count <= count_max: + if positive_count <= count_max: # type: ignore[operator] return Interval(start, end, IntervalType.POSITIVE) else: - return None # Implicit negative interval + return None # Implicit negative interval else: min_good = count_min <= positive_count max_good = (count_max is None) or (positive_count <= count_max) @@ -655,7 +669,7 @@ def get_start_end_from_interval_type( window_types: dict[int, IntervalType] = dict() def update_window_type( - window_interval: AnyInterval, data_interval: AnyInterval + window_interval: AnyInterval | None, data_interval: AnyInterval | None ) -> IntervalType: window_type = window_types.get(id(window_interval), None) @@ -680,7 +694,8 @@ def update_window_type( # result interval window types based on the data # intervals. def is_same_interval( - left_intervals: List[AnyInterval], right_intervals: List[AnyInterval] + left_intervals: List[AnyInterval | None], + right_intervals: List[AnyInterval | None], ) -> bool: left_window_interval, left_data_interval = left_intervals right_window_interval, right_data_interval = right_intervals @@ -722,7 +737,7 @@ def result_interval( # that the object identity of each interval is unique and # can be used as a dictionary key. 
person_indicator_windows = { - key: [ copy.copy(window) for window in indicator_windows ] + key: [copy.copy(window) for window in indicator_windows] for key in data_p.keys() } result = process.find_rectangles( From b23b1c2720da1c45446d13c006eceeb18d88df7d Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 26 Mar 2025 10:38:40 +0100 Subject: [PATCH 32/43] refactor: simplify a few type annotations --- execution_engine/task/process/rectangle.py | 2 +- execution_engine/task/task.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 09036bda..c0e16414 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -473,7 +473,7 @@ def mask_intervals( def intersection_interval( start: int, end: int, intervals: List[GeneralizedInterval] - ) -> GeneralizedInterval | None: + ) -> GeneralizedInterval: left_interval, right_interval = intervals diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 3bc9fee6..4db6e97c 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -669,7 +669,7 @@ def get_start_end_from_interval_type( window_types: dict[int, IntervalType] = dict() def update_window_type( - window_interval: AnyInterval | None, data_interval: AnyInterval | None + window_interval: GeneralizedInterval, data_interval: GeneralizedInterval ) -> IntervalType: window_type = window_types.get(id(window_interval), None) @@ -694,8 +694,8 @@ def update_window_type( # result interval window types based on the data # intervals. 
def is_same_interval( - left_intervals: List[AnyInterval | None], - right_intervals: List[AnyInterval | None], + left_intervals: List[GeneralizedInterval], + right_intervals: List[GeneralizedInterval], ) -> bool: left_window_interval, left_data_interval = left_intervals right_window_interval, right_data_interval = right_intervals From 160a47b900af5d6cb4b20525168d304f2248aee2 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 26 Mar 2025 10:39:07 +0100 Subject: [PATCH 33/43] fix: in find_rectangles, fix event sorting and adjacent interval merging The comparison of events with the same timestamp and the same "track" was too simplistic: for an open-and-close event pair for a single (point-shaped) interval, the open event should go before the close event. However, for the close event from one interval and the open event from a different, adjacent interval, the close event should go before the open event. With the new sorting in place, the processing of event clusters can restrict the "1 second lookahead" to open events (instead of all events). The mechanism is intended to detect adjacent intervals by possibly "pulling in" the open event of a following interval if that open event happens in the next second after the close event of the previous interval.
--- execution_engine/task/process/rectangle_cython.pyx | 10 ++++++---- execution_engine/task/process/rectangle_python.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 71f34ed1..5edba776 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -390,9 +390,11 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], return -1 elif event2[0] < event1[0]: # event2 is earlier return 1 - elif (event1[3] == event2[3] # at the same time and on same track, - and event1[1] == False): # sort close events before open events - return -1 + elif event1[3] == event2[3]: # at the same time and on same track, + if event1[2] is event2[2]: # same interval + return -1 if (event1[1] is True) else 1 # sort open events before open events + else: # different intervals + return -1 if (event1[1] is False) else 1 # sort close events before open events else: # at the same time, but different tracks => any order is fine return 1 events.sort(key = cmp_to_key(compare_events)) @@ -409,7 +411,7 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] # have no gap between them and can be considered a single # continuous interval [START_TIME1, END_TIME2]. 
- if (point_time == time) or (point_time == time - 1): + if (point_time == time) or (open_ and (point_time == time - 1)): if time > high_time: high_time = time any_open |= open_ diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 929701ef..2917398c 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -300,9 +300,11 @@ def compare_events(event1, event2): return -1 elif event2[0] < event1[0]: # event2 is earlier return 1 - elif (event1[3] == event2[3] # at the same time and on same track, - and event1[1] == False): # sort close events before open events - return -1 + elif event1[3] == event2[3]: # at the same time and on same track, + if event1[2] is event2[2]: # same interval + return -1 if (event1[1] is True) else 1 # sort open events before open events + else: # different intervals + return -1 if (event1[1] is False) else 1 # sort close events before open events else: # at the same time, but different tracks => any order is fine return 1 events.sort(key = cmp_to_key(compare_events)) @@ -319,7 +321,7 @@ def process_events_for_point_in_time(index, point_time): # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] # have no gap between them and can be considered a single # continuous interval [START_TIME1, END_TIME2]. - if (point_time == time) or (point_time == time - 1): + if (point_time == time) or (open_ and (point_time == time - 1)): if time > high_time: high_time = time any_open |= open_ From ddde0991582036beac16bce34cdb732181e618c0 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 26 Mar 2025 11:46:48 +0100 Subject: [PATCH 34/43] refactor: omit empty results in rectangle.find_rectangles And handle empty inputs slightly better in rectangle_[pc]ython.find_rectangles. 
--- execution_engine/task/process/rectangle.py | 10 ++-- .../task/process/rectangle_cython.pyx | 55 ++++++++++--------- .../task/process/rectangle_python.py | 55 ++++++++++--------- 3 files changed, 62 insertions(+), 58 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index c0e16414..20c27081 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -713,13 +713,15 @@ def find_rectangles( return {} else: keys: Set[int] = set() + result: Dict[int, List[GeneralizedInterval]] = dict() for track in data: keys |= track.keys() - return { - key: _impl.find_rectangles( + for key in keys: + key_result = _impl.find_rectangles( [intervals.get(key, []) for intervals in data], interval_constructor, is_same_result=is_same_result, ) - for key in keys - } + if len(key_result) > 0: + result[key] = key_result + return result diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 5edba776..73c1898e 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -385,6 +385,9 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], for interval in intervals # intervals_to_events(intervals, closing_offset=0) for (time,event) in [(interval.lower, True), (interval.upper, False)] ] + event_count = len(events) + if event_count == 0: + return [] def compare_events(event1, event2): if event1[0] < event2[0]: # event1 is earlier return -1 @@ -398,7 +401,6 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], else: # at the same time, but different tracks => any order is fine return 1 events.sort(key = cmp_to_key(compare_events)) - event_count = len(events) active_intervals = [None] * track_count def process_events_for_point_in_time(index, point_time): high_time = point_time @@ -420,32 +422,31 @@ def find_rectangles(all_intervals: 
list[list[AnyInterval]], active_intervals[track] = interval if open_ else None return None, None, None, None result_intervals = [] - if not event_count == 0: - # Step through event "clusters" with a common point in time and - # emit result intervals with unchanged interval "payload". - index, time = 0, events[0][0] - interval_start_time = time - index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) - interval_start_state = interval_start_state.copy() if interval_start_state is not None else None - while index: - new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) - # Diagram for this program point: - # |___potential_result_interval___|| | - # index new_index - # interval_start_time time new_time - # interval_start_state maybe_end_state - # high_time - if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): - # Add info for one result interval. - if len(result_intervals) > 0: - previous_result = result_intervals[-1] - if previous_result[1] == interval_start_time: - result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) - result_intervals.append((interval_start_time, time, interval_start_state)) - # Update interval start info. - interval_start_time = high_time - interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None - index, time = new_index, new_time + # Step through event "clusters" with a common point in time and + # emit result intervals with unchanged interval "payload". 
+ index, time = 0, events[0][0] + interval_start_time = time + index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) + interval_start_state = interval_start_state.copy() if interval_start_state is not None else None + while index: + new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) + # Diagram for this program point: + # |___potential_result_interval___|| | + # index new_index + # interval_start_time time new_time + # interval_start_state maybe_end_state + # high_time + if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): + # Add info for one result interval. + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) + result_intervals.append((interval_start_time, time, interval_start_state)) + # Update interval start info. 
+ interval_start_time = high_time + interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None + index, time = new_index, new_time result = [] for (start, end, intervals) in result_intervals: interval = interval_constructor(start, end, intervals) diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 2917398c..9ab329fb 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -295,6 +295,9 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], for interval in intervals # intervals_to_events(intervals, closing_offset=0) for (time,event) in [(interval.lower, True), (interval.upper, False)] ] + event_count = len(events) + if event_count == 0: + return [] def compare_events(event1, event2): if event1[0] < event2[0]: # event1 is earlier return -1 @@ -308,7 +311,6 @@ def compare_events(event1, event2): else: # at the same time, but different tracks => any order is fine return 1 events.sort(key = cmp_to_key(compare_events)) - event_count = len(events) active_intervals = [None] * track_count def process_events_for_point_in_time(index, point_time): high_time = point_time @@ -330,32 +332,31 @@ def process_events_for_point_in_time(index, point_time): active_intervals[track] = interval if open_ else None return None, None, None, None result_intervals = [] - if not event_count == 0: - # Step through event "clusters" with a common point in time and - # emit result intervals with unchanged interval "payload". 
- index, time = 0, events[0][0] - interval_start_time = time - index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) - interval_start_state = interval_start_state.copy() if interval_start_state is not None else None - while index: - new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) - # Diagram for this program point: - # |___potential_result_interval___|| | - # index new_index - # interval_start_time time new_time - # interval_start_state maybe_end_state - # high_time - if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): - # Add info for one result interval. - if len(result_intervals) > 0: - previous_result = result_intervals[-1] - if previous_result[1] == interval_start_time: - result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) - result_intervals.append((interval_start_time, time, interval_start_state)) - # Update interval start info. - interval_start_time = high_time - interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None - index, time = new_index, new_time + # Step through event "clusters" with a common point in time and + # emit result intervals with unchanged interval "payload". + index, time = 0, events[0][0] + interval_start_time = time + index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) + interval_start_state = interval_start_state.copy() if interval_start_state is not None else None + while index: + new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) + # Diagram for this program point: + # |___potential_result_interval___|| | + # index new_index + # interval_start_time time new_time + # interval_start_state maybe_end_state + # high_time + if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): + # Add info for one result interval. 
+ if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) + result_intervals.append((interval_start_time, time, interval_start_state)) + # Update interval start info. + interval_start_time = high_time + interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None + index, time = new_index, new_time result = [] for (start, end, intervals) in result_intervals: interval = interval_constructor(start, end, intervals) From 1d30db7024cab288408d49b752736a359e3a6c42 Mon Sep 17 00:00:00 2001 From: Jan Moringen Date: Wed, 26 Mar 2025 13:55:39 +0100 Subject: [PATCH 35/43] feat: add TemporalCount.result_for_not_applicable --- execution_engine/task/task.py | 14 +++----------- execution_engine/util/logic.py | 5 +++++ .../combination/test_temporal_combination.py | 9 ++++++++- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 4db6e97c..2c055f08 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -668,25 +668,17 @@ def get_start_end_from_interval_type( # interval. 
Maps id of window interval -> interval type window_types: dict[int, IntervalType] = dict() + result_for_not_applicable = cast(logic.TemporalCount, self.expr).result_for_not_applicable def update_window_type( window_interval: GeneralizedInterval, data_interval: GeneralizedInterval ) -> IntervalType: - window_type = window_types.get(id(window_interval), None) - + window_type = window_types.get(id(window_interval), result_for_not_applicable) if data_interval is None or data_interval.type is IntervalType.NEGATIVE: - if window_type is not IntervalType.POSITIVE: + if window_type is IntervalType.NOT_APPLICABLE: window_type = IntervalType.NEGATIVE elif data_interval.type is IntervalType.POSITIVE: window_type = IntervalType.POSITIVE - elif data_interval.type is IntervalType.NOT_APPLICABLE: - if window_type is None: - window_type = IntervalType.NOT_APPLICABLE - else: - assert data_interval.type is IntervalType.NO_DATA - if window_type is None: - window_type = IntervalType.NO_DATA window_types[id(window_interval)] = window_type - return window_type # The boundaries of the result intervals are identical to diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index 84c82696..5dc83378 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, Iterator, Self, cast from execution_engine.util.enum import TimeIntervalType +from execution_engine.util.interval import IntervalType from execution_engine.util.serializable import Serializable, SerializableABC @@ -616,6 +617,7 @@ class TemporalCount(CountOperator, SerializableABC): end_time: time | None = None interval_type: TimeIntervalType | None = None interval_criterion: BaseExpr | None = None + result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE def __new__( cls, @@ -626,6 +628,7 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, interval_criterion: BaseExpr | None = None, + 
result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> Self: """ @@ -660,6 +663,7 @@ def __new__( ) self.interval_type = interval_type self.interval_criterion = interval_criterion + self.result_for_not_applicable = result_for_not_applicable return self @@ -763,6 +767,7 @@ def dict(self, include_id: bool = False) -> dict: if self.interval_criterion else None ), + "result_for_not_applicable": self.result_for_not_applicable, } ) return data diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index bd248a1d..d09794d1 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -83,6 +83,7 @@ def test_criterion_combination_dict(self, mock_criteria): "end_time": "16:00:00", "interval_type": None, "interval_criterion": None, + "result_for_not_applicable": IntervalType.NOT_APPLICABLE, "args": [criterion.dict() for criterion in mock_criteria], }, } @@ -103,11 +104,13 @@ def test_criterion_combination_from_dict(self, mock_criteria): expr = logic.Expr.from_dict(expr_dict) + assert isinstance(expr, logic.TemporalMinCount) assert len(expr.args) == len(mock_criteria) assert expr.start_time == datetime.time(8, 0) assert expr.end_time == datetime.time(16, 0) assert expr.interval_type is None assert expr.interval_criterion is None + assert expr.result_for_not_applicable is IntervalType.NOT_APPLICABLE for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) @@ -126,11 +129,13 @@ def test_criterion_combination_from_dict(self, mock_criteria): expr = logic.Expr.from_dict(expr_dict) + assert isinstance(expr, logic.TemporalMinCount) assert len(expr.args) == len(mock_criteria) assert expr.start_time is None assert expr.end_time is None assert expr.interval_type == 
TimeIntervalType.MORNING_SHIFT assert expr.interval_criterion is None + assert expr.result_for_not_applicable is IntervalType.NOT_APPLICABLE for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) @@ -150,6 +155,7 @@ def test_repr(self, mock_criteria): " end_time=None,\n" " interval_type=TimeIntervalType.MORNING_SHIFT,\n" " interval_criterion=None,\n" + " result_for_not_applicable=NOT_APPLICABLE,\n" " threshold=1\n" ")" ) @@ -170,6 +176,7 @@ def test_repr(self, mock_criteria): " end_time='16:00:00',\n" " interval_type=None,\n" " interval_criterion=None,\n" + " result_for_not_applicable=NOT_APPLICABLE,\n" " threshold=1\n" ")" ) @@ -2304,7 +2311,7 @@ def patient_events(self, db_session, visit_occurrence): pendulum.parse("2025-02-19 07:59:59+01:00"), ), ( - IntervalType.NO_DATA, + IntervalType.NOT_APPLICABLE, 0, pendulum.parse("2025-02-19 08:00:00+01:00"), pendulum.parse("2025-02-21 02:00:00+01:00"), From 701c2b31cc1b86d8826c69e2368593f99b27dc61 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Wed, 26 Mar 2025 16:44:30 +0100 Subject: [PATCH 36/43] feat: make Presence use result_for_not_applicable=NEGATIVE --- execution_engine/util/logic.py | 6 + execution_engine/util/temporal_logic_util.py | 2 + .../combination/test_temporal_combination.py | 387 ++++++++++-------- 3 files changed, 215 insertions(+), 180 deletions(-) diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index 5dc83378..715dfa23 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -802,6 +802,7 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, interval_criterion: BaseExpr | None = None, + result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> "TemporalMinCount": """ @@ -818,6 +819,7 @@ def __new__( end_time=end_time, interval_type=interval_type, interval_criterion=interval_criterion, + 
result_for_not_applicable=result_for_not_applicable, ), ) return self @@ -844,6 +846,7 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, interval_criterion: BaseExpr | None = None, + result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> "TemporalMaxCount": """ @@ -860,6 +863,7 @@ def __new__( end_time=end_time, interval_type=interval_type, interval_criterion=interval_criterion, + result_for_not_applicable=result_for_not_applicable, ), ) return self @@ -886,6 +890,7 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, interval_criterion: BaseExpr | None = None, + result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> "TemporalExactCount": """ @@ -902,6 +907,7 @@ def __new__( end_time=end_time, interval_type=interval_type, interval_criterion=interval_criterion, + result_for_not_applicable=result_for_not_applicable, ), ) return self diff --git a/execution_engine/util/temporal_logic_util.py b/execution_engine/util/temporal_logic_util.py index 42c06e51..d583a681 100644 --- a/execution_engine/util/temporal_logic_util.py +++ b/execution_engine/util/temporal_logic_util.py @@ -2,6 +2,7 @@ from execution_engine.util import logic from execution_engine.util.enum import TimeIntervalType +from execution_engine.util.interval import IntervalType def Presence( @@ -22,6 +23,7 @@ def Presence( start_time=start_time, end_time=end_time, interval_criterion=interval_criterion, + result_for_not_applicable=IntervalType.NEGATIVE, ) diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index d09794d1..1af0cddb 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -22,16 +22,16 @@ from 
execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( concept_artificial_respiration, + concept_body_height, concept_body_weight, concept_covid19, concept_delir_screening, concept_heparin_ingredient, concept_surgical_procedure, + concept_tidal_volume, + concept_unit_cm, concept_unit_kg, concept_unit_mg, - concept_body_height, - concept_unit_cm, - concept_tidal_volume, ) from tests._testdata import concepts from tests.execution_engine.omop.criterion.test_criterion import TestCriterion @@ -140,9 +140,6 @@ def test_criterion_combination_from_dict(self, mock_criteria): for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) - # @pytest.mark.skip( - # reason="the repr does not return arguments in a consistent manner" - # ) def test_repr(self, mock_criteria): expr = temporal_logic_util.MorningShift(mock_criteria[0]) @@ -155,7 +152,7 @@ def test_repr(self, mock_criteria): " end_time=None,\n" " interval_type=TimeIntervalType.MORNING_SHIFT,\n" " interval_criterion=None,\n" - " result_for_not_applicable=NOT_APPLICABLE,\n" + " result_for_not_applicable=NEGATIVE,\n" " threshold=1\n" ")" ) @@ -250,14 +247,18 @@ def test_expr_contains_criteria(self, mock_criteria): tidal_volume_measurement_without_forward_fill = Measurement( concept=concept_tidal_volume, - value=ValueNumber.parse("<=110", unit=concept_unit_cm), # TODO(jmoringe): copied; does not make sense + value=ValueNumber.parse( + "<=110", unit=concept_unit_cm + ), # TODO(jmoringe): copied; does not make sense static=False, - forward_fill=False + forward_fill=False, ) tidal_volume_measurement_with_forward_fill = Measurement( concept=concept_tidal_volume, - value=ValueNumber.parse("<=110", unit=concept_unit_cm), # TODO(jmoringe): copied; does not make sense + value=ValueNumber.parse( + "<=110", unit=concept_unit_cm + ), # TODO(jmoringe): copied; does not make sense static=False, ) @@ -2080,6 +2081,7 @@ def test_interval_ratio_on_database( ): assert 
result_tuple == expected_tuple + class TestIndicatorWindowsMulitplePatients(TestCriterionCombinationDatabase): """ This test ensures that the data TemporalCount operator works @@ -2093,7 +2095,9 @@ class TestIndicatorWindowsMulitplePatients(TestCriterionCombinationDatabase): @pytest.fixture def observation_window(self) -> TimeRange: return TimeRange( - name="observation", start="2025-02-18 14:55:00+01:00", end="2025-02-22 12:00:00+01:00" + name="observation", + start="2025-02-18 14:55:00+01:00", + end="2025-02-22 12:00:00+01:00", ) def patient_events(self, db_session, visit_occurrence): @@ -2213,7 +2217,7 @@ def test_multiple_patients_on_database( for person in persons: result = df.query(f"person_id=={person.person_id}") result_tuples = list( - result[ [ "interval_type", "interval_start", "interval_end" ] ] + result[["interval_type", "interval_start", "interval_end"]] .fillna("nan") .itertuples(index=False, name=None) ) @@ -2235,39 +2239,49 @@ class TestCountOnIndicatorWindows(TestCriterionCombinationDatabase): @pytest.fixture def observation_window(self) -> TimeRange: return TimeRange( - name="observation", start="2025-02-18 14:55:00+01:00", end="2025-02-22 12:00:00+01:00" + name="observation", + start="2025-02-18 14:55:00+01:00", + end="2025-02-22 12:00:00+01:00", ) def patient_events(self, db_session, visit_occurrence): person_id = visit_occurrence.person_id - events = [create_condition( - vo=visit_occurrence, - condition_concept_id=concept_covid19.concept_id, - condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), - condition_end_datetime=pendulum.parse("2025-02-21 02:00:00+01:00"), - )] - if person_id == 1: - events.append(create_measurement( - vo=visit_occurrence, - measurement_concept_id=concept_body_weight.concept_id, - measurement_datetime=pendulum.parse("2025-02-19 18:00:00+01:00"), - value_as_number=90, - unit_concept_id=concept_unit_kg.concept_id, - )) - events.append(create_measurement( + events = [ + create_condition( 
vo=visit_occurrence, - measurement_concept_id=concept_body_height.concept_id, - measurement_datetime=pendulum.parse("2025-02-20 07:00:00+01:00"), - value_as_number=90, - unit_concept_id=concept_unit_cm.concept_id, - )) - events.append(create_measurement( - vo=visit_occurrence, - measurement_concept_id=concept_tidal_volume.concept_id, - measurement_datetime=pendulum.parse("2025-02-20 18:00:00+01:00"), - value_as_number=90, - unit_concept_id=concept_unit_cm.concept_id, - )) + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), + condition_end_datetime=pendulum.parse("2025-02-21 02:00:00+01:00"), + ) + ] + if person_id == 1: + events.append( + create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_body_weight.concept_id, + measurement_datetime=pendulum.parse("2025-02-19 18:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_kg.concept_id, + ) + ) + events.append( + create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_body_height.concept_id, + measurement_datetime=pendulum.parse("2025-02-20 07:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_cm.concept_id, + ) + ) + events.append( + create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_tidal_volume.concept_id, + measurement_datetime=pendulum.parse("2025-02-20 18:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_cm.concept_id, + ) + ) db_session.add_all(events) db_session.commit() @@ -2277,16 +2291,22 @@ def patient_events(self, db_session, visit_occurrence): ( logic.And(c2), # population logic.MinCount( - temporal_logic_util.AnyTime(bodyweight_measurement_without_forward_fill), - temporal_logic_util.AnyTime(body_height_measurement_without_forward_fill), - temporal_logic_util.AnyTime(tidal_volume_measurement_without_forward_fill), + temporal_logic_util.AnyTime( + bodyweight_measurement_without_forward_fill + ), + 
temporal_logic_util.AnyTime( + body_height_measurement_without_forward_fill + ), + temporal_logic_util.AnyTime( + tidal_volume_measurement_without_forward_fill + ), threshold=1, ), { 1: [ ( IntervalType.NOT_APPLICABLE, - 'nan', + "nan", pendulum.parse("2025-02-18 17:55:00+01:00"), pendulum.parse("2025-02-19 07:59:59+01:00"), ), @@ -2298,158 +2318,158 @@ def patient_events(self, db_session, visit_occurrence): ), ( IntervalType.NOT_APPLICABLE, - 'nan', + "nan", pendulum.parse("2025-02-21 02:00:01+01:00"), pendulum.parse("2025-02-22 05:30:00+01:00"), ), ], 2: [ ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-18 17:55:00+01:00"), - pendulum.parse("2025-02-19 07:59:59+01:00"), + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), ), ( - IntervalType.NOT_APPLICABLE, - 0, - pendulum.parse("2025-02-19 08:00:00+01:00"), - pendulum.parse("2025-02-21 02:00:00+01:00"), + IntervalType.NEGATIVE, + 0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), ), ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-21 02:00:01+01:00"), - pendulum.parse("2025-02-22 05:30:00+01:00"), + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), ), ], }, ), ( - logic.And(c2), # population - logic.MinCount( - bodyweight_measurement_with_forward_fill, - body_height_measurement_with_forward_fill, - tidal_volume_measurement_with_forward_fill, - threshold=2, - ), - { - 1: [ - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-18 17:55:00+01:00"), - pendulum.parse("2025-02-19 07:59:59+01:00"), - ), - ( - IntervalType.NO_DATA, - 0.0, - pendulum.parse("2025-02-19 08:00:00+01:00"), - pendulum.parse("2025-02-19 17:59:59+01:00"), - ), - ( - IntervalType.NEGATIVE, - 0.5, - pendulum.parse("2025-02-19 18:00:00+01:00"), - pendulum.parse("2025-02-20 
06:59:59+01:00"), - ), - ( - IntervalType.POSITIVE, - 1, - pendulum.parse("2025-02-20 07:00:00+01:00"), - pendulum.parse("2025-02-20 17:59:59+01:00"), - ), - ( - IntervalType.POSITIVE, - 1.5, - pendulum.parse("2025-02-20 18:00:00+01:00"), - pendulum.parse("2025-02-21 02:00:00+01:00"), - ), - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-21 02:00:01+01:00"), - pendulum.parse("2025-02-22 05:30:00+01:00"), - ), - ], - 2: [ - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-18 17:55:00+01:00"), - pendulum.parse("2025-02-19 07:59:59+01:00"), - ), - ( - IntervalType.NO_DATA, - 0, - pendulum.parse("2025-02-19 08:00:00+01:00"), - pendulum.parse("2025-02-21 02:00:00+01:00"), - ), - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-21 02:00:01+01:00"), - pendulum.parse("2025-02-22 05:30:00+01:00"), - ), - ], - }, + logic.And(c2), # population + logic.MinCount( + bodyweight_measurement_with_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_with_forward_fill, + threshold=2, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NO_DATA, + 0.0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-19 17:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0.5, + pendulum.parse("2025-02-19 18:00:00+01:00"), + pendulum.parse("2025-02-20 06:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1, + pendulum.parse("2025-02-20 07:00:00+01:00"), + pendulum.parse("2025-02-20 17:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1.5, + pendulum.parse("2025-02-20 18:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 
17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NO_DATA, + 0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, ), ( - logic.And(c2), # population - logic.MaxCount( - bodyweight_measurement_with_forward_fill, - body_height_measurement_with_forward_fill, - tidal_volume_measurement_with_forward_fill, - threshold=2, - ), - { - 1: [ - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-18 17:55:00+01:00"), - pendulum.parse("2025-02-19 07:59:59+01:00"), - ), - ( - IntervalType.POSITIVE, - 'nan', - pendulum.parse("2025-02-19 08:00:00+01:00"), - pendulum.parse("2025-02-20 17:59:59+01:00"), - ), - ( - IntervalType.NEGATIVE, - 'nan', - pendulum.parse("2025-02-20 18:00:00+01:00"), - pendulum.parse("2025-02-21 02:00:00+01:00"), - ), - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-21 02:00:01+01:00"), - pendulum.parse("2025-02-22 05:30:00+01:00"), - ), - ], - 2: [ - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-18 17:55:00+01:00"), - pendulum.parse("2025-02-19 07:59:59+01:00"), - ), - ( - IntervalType.POSITIVE, - 'nan', - pendulum.parse("2025-02-19 08:00:00+01:00"), - pendulum.parse("2025-02-21 02:00:00+01:00"), - ), - ( - IntervalType.NOT_APPLICABLE, - 'nan', - pendulum.parse("2025-02-21 02:00:01+01:00"), - pendulum.parse("2025-02-22 05:30:00+01:00"), - ), - ], - }, + logic.And(c2), # population + logic.MaxCount( + bodyweight_measurement_with_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_with_forward_fill, + threshold=2, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + "nan", + 
pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-20 17:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + "nan", + pendulum.parse("2025-02-20 18:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + "nan", + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, ), ], ) @@ -2495,7 +2515,14 @@ def test_combination_on_database( for person in persons: result = df.query(f"person_id=={person.person_id}") result_tuples = list( - result[ [ "interval_type", "interval_ratio", "interval_start", "interval_end" ] ] + result[ + [ + "interval_type", + "interval_ratio", + "interval_start", + "interval_end", + ] + ] .fillna("nan") .itertuples(index=False, name=None) ) From 29da72cafcfc2ac78dbac5185d7732906a458725 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Thu, 27 Mar 2025 17:00:09 +0100 Subject: [PATCH 37/43] fix: re-set base node category --- execution_engine/execution_graph/graph.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index 792a034c..b4acdc85 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -104,6 +104,10 @@ def traverse( traverse(expr, category=category) + # make sure base node has still the correct category - it might have changed duing traversal if the base node + # is used in an interval_criterion of a TemporalCount operator + 
graph.nodes[base_node].update({"category": CohortCategory.BASE}) + if hash(expr) != expr_hash: raise ValueError("Expression has been modified during traversal") From b93e49ee3de8586e8fc24d4d78129a448851db7f Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Thu, 27 Mar 2025 17:39:38 +0100 Subject: [PATCH 38/43] fix: temporal count with interval criterion --- execution_engine/task/task.py | 42 ++++++++++++++--------- execution_engine/util/value/value.py | 2 +- tests/execution_engine/util/test_value.py | 5 +-- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 2c055f08..7d70f1d6 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -598,10 +598,21 @@ def handle_temporal_operator( :param observation_window: The observation window. :return: A DataFrame with the merged intervals. """ + assert isinstance(self.expr, logic.TemporalCount) - data_p = self.select_predecessor_result(self.expr.args[0], data) - # data_p = process.select_type(data[0], IntervalType.POSITIVE) - # data_p = {key: val for key, val in data_p.items() if val} + if self.expr.interval_criterion is not None: + assert ( + len(data) == 2 + ), f"TemporalCount with indicator criterion requires exactly two input streams, got {len(data)}" + indicator_personal_windows = data.pop( + self.get_predecessor_data_index(self.expr.interval_criterion) + ) + + assert ( + len(data) == 1 + ), f"TemporalCount requires exactly one input streams, got {len(data)}" + + data_arg = data[0] def get_start_end_from_interval_type( type_: TimeIntervalType, @@ -620,18 +631,10 @@ def get_start_end_from_interval_type( assert self.expr.count_max is None if self.expr.interval_criterion is not None: - - # last element is the indicator windows - assert ( - len(data) >= 2 - ), "TemporalCount with indicator criterion requires at least two inputs" - - indicator_personal_windows = data.pop( - 
self.get_predecessor_data_index(self.expr.interval_criterion) - ) + data_positive = process.select_type(data_arg, IntervalType.POSITIVE) result = process.find_overlapping_personal_windows( - indicator_personal_windows, data_p + indicator_personal_windows, data_positive ) else: @@ -668,11 +671,16 @@ def get_start_end_from_interval_type( # interval. Maps id of window interval -> interval type window_types: dict[int, IntervalType] = dict() - result_for_not_applicable = cast(logic.TemporalCount, self.expr).result_for_not_applicable + result_for_not_applicable = cast( + logic.TemporalCount, self.expr + ).result_for_not_applicable + def update_window_type( window_interval: GeneralizedInterval, data_interval: GeneralizedInterval ) -> IntervalType: - window_type = window_types.get(id(window_interval), result_for_not_applicable) + window_type = window_types.get( + id(window_interval), result_for_not_applicable + ) if data_interval is None or data_interval.type is IntervalType.NEGATIVE: if window_type is IntervalType.NOT_APPLICABLE: window_type = IntervalType.NEGATIVE @@ -730,10 +738,10 @@ def result_interval( # can be used as a dictionary key. person_indicator_windows = { key: [copy.copy(window) for window in indicator_windows] - for key in data_p.keys() + for key in data_arg.keys() } result = process.find_rectangles( - [person_indicator_windows, data_p], + [person_indicator_windows, data_arg], result_interval, is_same_result=is_same_interval, ) diff --git a/execution_engine/util/value/value.py b/execution_engine/util/value/value.py index fcda3dc6..5b6c8fb6 100644 --- a/execution_engine/util/value/value.py +++ b/execution_engine/util/value/value.py @@ -362,4 +362,4 @@ def __str__(self) -> str: """ Get the string representation of the value. 
""" - return f"Value == {repr(self.value)}" + return f"Value == {str(self.value)}" diff --git a/tests/execution_engine/util/test_value.py b/tests/execution_engine/util/test_value.py index e339be60..5a101f10 100644 --- a/tests/execution_engine/util/test_value.py +++ b/tests/execution_engine/util/test_value.py @@ -307,10 +307,7 @@ def test_to_sql(self, test_concept, test_table): def test_str(self, test_concept): value_concept = ValueConcept(value=test_concept) - assert ( - str(value_concept) - == "Value == Concept(concept_id=1, concept_name='Test Concept', concept_code='unit', domain_id='units', vocabulary_id='test', concept_class_id='test', standard_concept=None, invalid_reason=None)" - ) + assert str(value_concept) == "Value == Test Concept" def test_repr(self, test_concept): value_concept = ValueConcept(value=test_concept) From 63f8d5f1df074f2924f23679e8735db2f8e76fdf Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 28 Mar 2025 08:59:03 +0100 Subject: [PATCH 39/43] revert: remove result_for_not_applicable flag the intended behavior - which is :TemporalMinCount should return NEGATIVE (instead of NOT_APPLICABLE) if all inputs are NO_DATA) - can be achieved much simpler by treating NO_DATA in the data_intervals (in update_window_type) the same as NEGATIVE intervals. 
--- execution_engine/task/process/rectangle.py | 10 +++++++--- execution_engine/task/task.py | 16 ++++++++++------ execution_engine/util/logic.py | 11 ----------- execution_engine/util/temporal_logic_util.py | 2 -- .../combination/test_temporal_combination.py | 5 ----- 5 files changed, 17 insertions(+), 27 deletions(-) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 20c27081..58999a44 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -2,7 +2,7 @@ import importlib import logging import os -from typing import Callable, List, Set, cast +from typing import Callable, Dict, List, Set, cast import numpy as np import pendulum @@ -446,7 +446,7 @@ def intersect_intervals(data: list[PersonIntervals]) -> PersonIntervals: def mask_intervals( data: PersonIntervals, mask: PersonIntervals, -) -> PersonIntervals: +) -> Dict[int, List[GeneralizedInterval]]: """ Masks the intervals in the dict per key. @@ -483,6 +483,7 @@ def intersection_interval( return None result = find_rectangles([person_mask, data], intersection_interval) + return result @@ -696,7 +697,7 @@ def find_rectangles( data: list[PersonIntervals], interval_constructor: Callable, is_same_result: Callable | None = None, -) -> PersonIntervals: +) -> Dict[int, List[GeneralizedInterval]]: """ Iterates over intervals for each person across all items in `data` and constructs new intervals ("rectangles") by applying `interval_constructor` to the overlapping intervals in each time range. 
@@ -714,8 +715,10 @@ def find_rectangles( else: keys: Set[int] = set() result: Dict[int, List[GeneralizedInterval]] = dict() + for track in data: keys |= track.keys() + for key in keys: key_result = _impl.find_rectangles( [intervals.get(key, []) for intervals in data], @@ -724,4 +727,5 @@ def find_rectangles( ) if len(key_result) > 0: result[key] = key_result + return result diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 7d70f1d6..6459c03f 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -671,22 +671,26 @@ def get_start_end_from_interval_type( # interval. Maps id of window interval -> interval type window_types: dict[int, IntervalType] = dict() - result_for_not_applicable = cast( - logic.TemporalCount, self.expr - ).result_for_not_applicable - def update_window_type( window_interval: GeneralizedInterval, data_interval: GeneralizedInterval ) -> IntervalType: + window_type = window_types.get( - id(window_interval), result_for_not_applicable + id(window_interval), IntervalType.NOT_APPLICABLE ) - if data_interval is None or data_interval.type is IntervalType.NEGATIVE: + + if ( + data_interval is None + or data_interval.type is IntervalType.NO_DATA + or data_interval.type is IntervalType.NEGATIVE + ): if window_type is IntervalType.NOT_APPLICABLE: window_type = IntervalType.NEGATIVE elif data_interval.type is IntervalType.POSITIVE: window_type = IntervalType.POSITIVE + window_types[id(window_interval)] = window_type + return window_type # The boundaries of the result intervals are identical to diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index 715dfa23..84c82696 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -2,7 +2,6 @@ from typing import Any, Callable, Dict, Iterator, Self, cast from execution_engine.util.enum import TimeIntervalType -from execution_engine.util.interval import IntervalType from execution_engine.util.serializable 
import Serializable, SerializableABC @@ -617,7 +616,6 @@ class TemporalCount(CountOperator, SerializableABC): end_time: time | None = None interval_type: TimeIntervalType | None = None interval_criterion: BaseExpr | None = None - result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE def __new__( cls, @@ -628,7 +626,6 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, interval_criterion: BaseExpr | None = None, - result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> Self: """ @@ -663,7 +660,6 @@ def __new__( ) self.interval_type = interval_type self.interval_criterion = interval_criterion - self.result_for_not_applicable = result_for_not_applicable return self @@ -767,7 +763,6 @@ def dict(self, include_id: bool = False) -> dict: if self.interval_criterion else None ), - "result_for_not_applicable": self.result_for_not_applicable, } ) return data @@ -802,7 +797,6 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, interval_criterion: BaseExpr | None = None, - result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> "TemporalMinCount": """ @@ -819,7 +813,6 @@ def __new__( end_time=end_time, interval_type=interval_type, interval_criterion=interval_criterion, - result_for_not_applicable=result_for_not_applicable, ), ) return self @@ -846,7 +839,6 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, interval_criterion: BaseExpr | None = None, - result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> "TemporalMaxCount": """ @@ -863,7 +855,6 @@ def __new__( end_time=end_time, interval_type=interval_type, interval_criterion=interval_criterion, - result_for_not_applicable=result_for_not_applicable, ), ) return self @@ -890,7 +881,6 @@ def __new__( end_time: time | None = None, interval_type: TimeIntervalType | None = None, 
interval_criterion: BaseExpr | None = None, - result_for_not_applicable: IntervalType = IntervalType.NOT_APPLICABLE, **kwargs: Any, ) -> "TemporalExactCount": """ @@ -907,7 +897,6 @@ def __new__( end_time=end_time, interval_type=interval_type, interval_criterion=interval_criterion, - result_for_not_applicable=result_for_not_applicable, ), ) return self diff --git a/execution_engine/util/temporal_logic_util.py b/execution_engine/util/temporal_logic_util.py index d583a681..42c06e51 100644 --- a/execution_engine/util/temporal_logic_util.py +++ b/execution_engine/util/temporal_logic_util.py @@ -2,7 +2,6 @@ from execution_engine.util import logic from execution_engine.util.enum import TimeIntervalType -from execution_engine.util.interval import IntervalType def Presence( @@ -23,7 +22,6 @@ def Presence( start_time=start_time, end_time=end_time, interval_criterion=interval_criterion, - result_for_not_applicable=IntervalType.NEGATIVE, ) diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 1af0cddb..07a39bc1 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -83,7 +83,6 @@ def test_criterion_combination_dict(self, mock_criteria): "end_time": "16:00:00", "interval_type": None, "interval_criterion": None, - "result_for_not_applicable": IntervalType.NOT_APPLICABLE, "args": [criterion.dict() for criterion in mock_criteria], }, } @@ -110,7 +109,6 @@ def test_criterion_combination_from_dict(self, mock_criteria): assert expr.end_time == datetime.time(16, 0) assert expr.interval_type is None assert expr.interval_criterion is None - assert expr.result_for_not_applicable is IntervalType.NOT_APPLICABLE for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) @@ -135,7 +133,6 @@ def 
test_criterion_combination_from_dict(self, mock_criteria): assert expr.end_time is None assert expr.interval_type == TimeIntervalType.MORNING_SHIFT assert expr.interval_criterion is None - assert expr.result_for_not_applicable is IntervalType.NOT_APPLICABLE for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) @@ -152,7 +149,6 @@ def test_repr(self, mock_criteria): " end_time=None,\n" " interval_type=TimeIntervalType.MORNING_SHIFT,\n" " interval_criterion=None,\n" - " result_for_not_applicable=NEGATIVE,\n" " threshold=1\n" ")" ) @@ -173,7 +169,6 @@ def test_repr(self, mock_criteria): " end_time='16:00:00',\n" " interval_type=None,\n" " interval_criterion=None,\n" - " result_for_not_applicable=NOT_APPLICABLE,\n" " threshold=1\n" ")" ) From aaa89c377b4da9dc7bd54cc0acdc8577b7254de1 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 28 Mar 2025 10:15:59 +0100 Subject: [PATCH 40/43] refactor: using int instead of float time in intervals docs: add code comments --- execution_engine/omop/criterion/abstract.py | 3 +- .../omop/criterion/point_in_time.py | 3 +- execution_engine/omop/sqlclient.py | 29 ++ execution_engine/task/process/__init__.py | 14 + execution_engine/task/process/rectangle.py | 89 +++--- .../task/process/rectangle_cython.pyx | 298 ++++++++++-------- .../task/process/rectangle_python.py | 264 ++++++++++++---- execution_engine/task/task.py | 148 ++++++--- .../util/{types.py => types/__init__.py} | 83 ----- execution_engine/util/types/timerange.py | 88 ++++++ setup.py | 2 +- tests/_fixtures/omop_fixture.py | 2 +- .../combination/test_logical_combination.py | 18 +- .../combination/test_temporal_combination.py | 24 +- .../omop/criterion/test_criterion.py | 4 +- .../criterion/test_occurrence_criterion.py | 92 +++--- .../criterion/test_procedure_occurrence.py | 3 +- .../omop/db/celida/test_triggers.py | 2 +- .../task/process/test_rectangle.py | 6 +- tests/execution_engine/util/test_types.py | 3 +- 
.../test_recommendation_base.py | 2 +- .../test_recommendation_base_v2.py | 2 +- 22 files changed, 756 insertions(+), 423 deletions(-) rename execution_engine/util/{types.py => types/__init__.py} (68%) create mode 100644 execution_engine/util/types/timerange.py diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index 8ceaece8..bc7e5397 100644 --- a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -27,7 +27,8 @@ from execution_engine.util.interval import IntervalType from execution_engine.util.serializable import SerializableDataClassABC from execution_engine.util.sql import SelectInto, select_into -from execution_engine.util.types import PersonIntervals, TimeRange +from execution_engine.util.types import PersonIntervals +from execution_engine.util.types.timerange import TimeRange __all__ = [ "Criterion", diff --git a/execution_engine/omop/criterion/point_in_time.py b/execution_engine/omop/criterion/point_in_time.py index d79342a0..c493a567 100644 --- a/execution_engine/omop/criterion/point_in_time.py +++ b/execution_engine/omop/criterion/point_in_time.py @@ -10,7 +10,8 @@ from execution_engine.omop.criterion.concept import ConceptCriterion from execution_engine.task.process import get_processing_module from execution_engine.util.interval import IntervalType -from execution_engine.util.types import PersonIntervals, TimeRange, Timing +from execution_engine.util.types import PersonIntervals, Timing +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import Value process = get_processing_module() diff --git a/execution_engine/omop/sqlclient.py b/execution_engine/omop/sqlclient.py index b5488029..ee1e8e81 100644 --- a/execution_engine/omop/sqlclient.py +++ b/execution_engine/omop/sqlclient.py @@ -58,6 +58,35 @@ def _enable_database_triggers( cursor.close() +def datetime_cols_to_epoch(stmt: sqlalchemy.Select) -> sqlalchemy.Select: 
+ """ + Given a SQLAlchemy 2.0 Select that has columns labeled 'interval_start' + or 'interval_end', replace those column expressions with + EXTRACT(EPOCH FROM )::BIGINT so they become integer timestamps. + + Returns a new Select object with the replaced columns. + """ + new_columns = [] + + for col in stmt.selected_columns: + label = getattr(col, "name") + + if label in ("interval_start", "interval_end"): + # We'll wrap col in EXTRACT(EPOCH FROM col)::BIGINT, + new_col = ( + sqlalchemy.func.extract("epoch", col) + .cast(sqlalchemy.BigInteger) + .label(label) + ) + new_columns.append(new_col) + else: + new_columns.append(col) + + new_stmt = stmt.with_only_columns(*new_columns, maintain_column_froms=True) + + return new_stmt + + class OMOPSQLClient: """A client for the OMOP SQL database. diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index f0407782..eed88a7a 100644 --- a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -5,6 +5,9 @@ from collections import namedtuple from typing import TypeVar +from execution_engine.util.interval import IntervalType +from execution_engine.util.types.timerange import TimeRange + def get_processing_module( name: str = "rectangle", version: str = "auto" @@ -61,3 +64,14 @@ def interval_like(interval: TInterval, start: int, end: int) -> TInterval: """ return interval._replace(lower=start, upper=end) # type: ignore[return-value] + + +def timerange_to_interval(tr: TimeRange, type_: IntervalType) -> Interval: + """ + Converts a timerange to an interval with the supplied type. 
+ """ + return Interval( + lower=int(tr.start.timestamp()), + upper=int(tr.end.timestamp()), + type=type_, + ) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 58999a44..53236ff1 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -2,6 +2,7 @@ import importlib import logging import os +from collections import defaultdict from typing import Callable, Dict, List, Set, cast import numpy as np @@ -10,15 +11,21 @@ from sqlalchemy import CursorResult from execution_engine.util.interval import IntervalType, interval_datetime -from execution_engine.util.types import TimeRange +from ...util.types.timerange import TimeRange from . import ( GeneralizedInterval, Interval, IntervalWithCount, interval_like, + timerange_to_interval, ) +IntervalConstructor = Callable[ + [int, int, List[GeneralizedInterval]], GeneralizedInterval +] +SameResult = Callable[[List[GeneralizedInterval], List[GeneralizedInterval]], bool] + PROCESS_RECTANGLE_VERSION = os.getenv("PROCESS_RECTANGLE_VERSION", "auto") @@ -69,7 +76,7 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals: """ Converts the result of the interval operations to a list of intervals. 
""" - person_interval = {} + person_interval = defaultdict(list) for row in result: if row.interval_end < row.interval_start: @@ -81,15 +88,12 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals: raise ValueError("Interval end is None") interval = Interval( - row.interval_start.timestamp(), - row.interval_end.timestamp(), + row.interval_start, + row.interval_end, row.interval_type, ) - if row.person_id not in person_interval: - person_interval[row.person_id] = [interval] - else: - person_interval[row.person_id].append(interval) + person_interval[row.person_id].append(interval) for person_id in person_interval: person_interval[person_id] = _impl.union_rects(person_interval[person_id]) @@ -219,10 +223,10 @@ def forward_fill( if observation_window is not None: last_interval = result[person_id][-1] - if last_interval.upper < observation_window.end.timestamp(): + if last_interval.upper < int(observation_window.end.timestamp()): result[person_id][-1] = Interval( last_interval.lower, - observation_window.end.timestamp(), + int(observation_window.end.timestamp()), last_interval.type, ) @@ -307,15 +311,11 @@ def complementary_intervals( """ interval_type_missing_persons = interval_type - baseline_interval = Interval( - observation_window.start.timestamp(), - observation_window.end.timestamp(), - interval_type_missing_persons, + baseline_interval = timerange_to_interval( + observation_window, type_=interval_type_missing_persons ) - observation_window_mask = Interval( - observation_window.start.timestamp(), - observation_window.end.timestamp(), - IntervalType.least_intersection_priority(), + observation_window_mask = timerange_to_interval( + observation_window, type_=IntervalType.least_intersection_priority() ) result = {} @@ -595,8 +595,8 @@ def add_interval( assert previous_end < effective_start # type: ignore[unreachable] intervals.append( Interval( - lower=effective_start.timestamp(), - upper=effective_end.timestamp(), + 
def find_rectangles(
    data: list[PersonIntervals],
    interval_constructor: IntervalConstructor,
    is_same_result: SameResult | None = None,
) -> Dict[int, List[GeneralizedInterval]]:
    """
    Combines multiple parallel tracks of person intervals into new "time slices".

    Every start/end boundary of an interval on any track may begin a new slice;
    for each slice, ``interval_constructor(start, end, active_intervals)`` decides
    how the slice is represented (e.g., POSITIVE, NEGATIVE). When ``is_same_result``
    is supplied, it determines whether two adjacent slices are equivalent and can
    be merged; otherwise a default comparison of the constructed results is used.

    :param data: A list of dictionaries, each mapping a person ID to a list of intervals.
    :param interval_constructor: A callable that takes the time boundaries and the
        corresponding active intervals and returns the resulting interval (or None).
    :param is_same_result: Optional predicate deciding whether adjacent slices merge.
    :return: A dictionary mapping each person ID to the constructed intervals.
    """
    # TODO(jmoringe): can this use _process_interval?
    if not data:
        return {}

    # Every person ID that occurs on at least one track.
    person_ids: Set[int] = set().union(*(track.keys() for track in data))

    result: Dict[int, List[GeneralizedInterval]] = {}

    for person_id in person_ids:
        tracks = [track.get(person_id, []) for track in data]
        rectangles = _impl.find_rectangles(
            tracks,
            interval_constructor,
            is_same_result=is_same_result,
        )
        if rectangles:
            result[person_id] = rectangles

    return result
AnyInterval, int] def intervals_to_events( intervals: list[AnyInterval], closing_offset: int = 1, -) -> list[tuple[int, bool, AnyInterval]]: +) -> list[IntervalEvent]: """ Converts the intervals to a list of events. @@ -34,10 +42,10 @@ def intervals_to_events( def union_rects(list[Interval] intervals) -> list[Interval]: - cdef double last_x = -np.inf - cdef double last_x_closed = -np.inf - cdef double cur_x = -np.inf - cdef double first_x + cdef int64_t last_x = NEG_INF + cdef int64_t last_x_closed = NEG_INF + cdef int64_t cur_x = NEG_INF + cdef int64_t first_x cdef signed char max_open_y = SCHAR_MIN #cdef signed char open_y[len(intervals)] @@ -121,86 +129,6 @@ def union_rects(list[Interval] intervals) -> list[Interval]: last_x = cur_x # start new output rectangle return union -def union_rects_with_count(list[IntervalWithCount] intervals) -> list[IntervalWithCount]: - cdef double last_x_start = -np.inf - cdef double last_x_end; - cdef double previous_x_visited = -np.inf - cdef double first_x - cdef signed char max_open_y = SCHAR_MIN - - if not intervals: - return [] - - order = IntervalType.union_priority()[::-1] - - events = intervals_to_events(intervals) - - union = [] - - last_x_start = -np.inf # holds the x_min of the currently open output rectangle - last_x_end = events[0][0] # x variable of the last closed interval (we start with the first x, so we - # don't close the first rectangle at the first x) - previous_x_visited = -np.inf - open_y = SortedDict() - - def get_y_max() -> IntervalType | None: - max_key = None - for key in reversed(open_y): - if open_y[key] > 0: - max_key = key - break - return max_key - - for x, start_point, interval in events: - y_type, count_event = interval.type, interval.count - y = order.index(y_type) - if start_point: - y_max = get_y_max() - - if x > previous_x_visited and y_max is None: - # no currently open rectangles - last_x_start = x # start new output rectangle - elif y >= y_max: - if x == last_x_end or x == last_x_start: - # 
we already closed a rectangle at this x, so we don't need to start a new one - open_y[y] = open_y.get(y, 0) + count_event - continue - - union.append( - IntervalWithCount( - lower=last_x_start, upper=x - 1, type=order[y_max], count=open_y[y_max] - ) - ) - last_x_end = x - last_x_start = x - - open_y[y] = open_y.get(y, 0) + count_event - - else: - open_y[y] = max(open_y.get(y, 0) - count_event, 0) - - y_max = get_y_max() - - if (y_max is None or (open_y and y_max <= y)) and x > last_x_end: - if y_max is None or y_max < y: - # the closing rectangle has a higher y_max than the currently open ones - count = count_event - else: - # the closing rectangle has the same y_max as the currently open ones - count = open_y[y] + count_event - - union.append( - IntervalWithCount( - lower=last_x_start, upper=x - 1, type=order[y], count=count - ) - ) # close the previous rectangle at y_max - last_x_end = x - last_x_start = x # start new output rectangle - - previous_x_visited = x - - return merge_adjacent_intervals(union) - def merge_adjacent_intervals(intervals: list[IntervalWithCount]) -> list[IntervalWithCount]: """ Merges adjacent intervals in a list of IntervalWithCount namedtuples if they have the same 'type' and 'count'. 
@@ -261,9 +189,9 @@ def merge_adjacent_intervals(intervals: list[IntervalWithCount]) -> list[Interva def intersect_rects(list[Interval] intervals) -> list[Interval]: - cdef double x_min = -np.inf + cdef int64_t x_min = NEG_INF cdef signed char y_min = SCHAR_MAX - cdef double end_point = np.inf + cdef int64_t end_point = POS_INF if not len(intervals): return [] @@ -327,27 +255,57 @@ def union_interval_lists( return union_rects(left + right) -IntervalConstructor = Callable[[int, int, typing.List[AnyInterval]], AnyInterval] - -def default_is_same_result(interval_constructor): - def is_same_result(active_intervals1, active_intervals2): +def default_is_same_result(interval_constructor: IntervalConstructor): + """ + Creates an 'is_same_result' function that determines whether two sets of active intervals + produce the same resulting interval when passed to 'interval_constructor'. + + The returned function calls: + interval_constructor(0, 0, active_intervals1) + and + interval_constructor(0, 0, active_intervals2) + and checks if the results are equal. If they match, we say they represent the “same” result. + + :param interval_constructor: An interval constructor function. + :return: + A function 'is_same_result' that compares the results of two different sets of active + intervals by invoking 'interval_constructor' on each and checking for equality. + """ + def is_same_result( + active_intervals1: List[GeneralizedInterval], + active_intervals2: List[GeneralizedInterval], + ) -> bool: + """ + Compares the resulting intervals for two sets of active intervals. + + :param active_intervals1: + A list of intervals (or None) describing the first track’s active intervals. + :param active_intervals2: + A list of intervals (or None) describing the second track’s active intervals. + :return: + True if 'interval_constructor(0, 0, ...)' produces the same interval for + both sets, otherwise False. 
+ """ # When we have to decide whether to extend a result interval # or start a new one, we compare the state for the existing # result interval with the new state. The states are derived # from the respective lists of active intervals by calling - # interval_constructor (with fake points in time) . + # interval_constructor (with fake points in time). return (interval_constructor(0, 0, active_intervals1) == interval_constructor(0, 0, active_intervals2)) return is_same_result def find_rectangles(all_intervals: list[list[AnyInterval]], interval_constructor: IntervalConstructor, - is_same_result = None) \ + is_same_result: SameResult | None = None) \ -> list[AnyInterval]: - """For multiple parallel "tracks" of intervals, identify temporal - intervals in which no change occurs on any "track". For each such - interval, call interval_constructor to determine how the interval - should be represented in the overall result. To this end, + """ + Low-level engine for interval construction. + + For multiple parallel "tracks" of intervals, identify segments of time + in which no change occurs on any "track". For each such segment, + call `interval_constructor(start, end, active_intervals)` to determine + how to represent the interval in the overall result. To this end, interval_constructor receives a list "active" intervals the elements of which are either None or an interval from all_intervals and returns either None or an interval. The returned @@ -379,16 +337,29 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], # before the open event, otherwise our tracking of active # intervals would get confused. 
track_count = len(all_intervals) + events = [ (time, event, interval, j) for j, intervals in enumerate(all_intervals) - for interval in intervals # intervals_to_events(intervals, closing_offset=0) + for interval in intervals for (time,event) in [(interval.lower, True), (interval.upper, False)] ] + event_count = len(events) + if event_count == 0: return [] - def compare_events(event1, event2): + + def compare_events( + event1: IntervalEventWithCount, event2: IntervalEventWithCount + ) -> int: + """ + Sorting comparator to ensure we process events in the correct order: + - earlier time first + - if same time and same track, close events before open events + (so we don't incorrectly treat a consecutive interval on the same track + as overlapping). + """ if event1[0] < event2[0]: # event1 is earlier return -1 elif event2[0] < event1[0]: # event2 is earlier @@ -400,11 +371,32 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], return -1 if (event1[1] is False) else 1 # sort close events before open events else: # at the same time, but different tracks => any order is fine return 1 - events.sort(key = cmp_to_key(compare_events)) - active_intervals = [None] * track_count - def process_events_for_point_in_time(index, point_time): + + # Sort events chronologically according to compare_events + events.sort(key=cmp_to_key(compare_events)) + + active_intervals: list[GeneralizedInterval] = [None] * track_count + + def process_events_for_point_in_time( + index: int, point_time: int + ) -> Tuple[int, int, int] | None: + """ + Consumes events that occur at `point_time` (or effectively that boundary), + updating 'active_intervals' for whichever track is opening or closing + intervals at that time. 
+ + Returns: (new_index, new_time, copy_of_active_intervals, high_time) + - new_index: the index of the first event not processed (because it's after point_time) + - new_time: the time of that next event + - copy_of_active_intervals: a snapshot of 'active_intervals' after processing + - high_time: the highest time covered by these events (may be the same as point_time + or point_time + 1 if we consider inclusive boundaries). + + If we run out of events entirely, returns (None, None, None, None). + """ high_time = point_time any_open = False + for i in range(index, event_count): time, open_, interval, track = events[i] # Since points in time for intervals are quantized to whole @@ -418,38 +410,98 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], high_time = time any_open |= open_ else: - return i, time, active_intervals, high_time if any_open else high_time + 1 + # As soon as we find an event that’s clearly beyond the cluster at point_time, + # we break and return + return i, time, high_time if any_open else high_time + 1 + + # Opening => set this track’s active interval to the new interval + # Closing => set it to None active_intervals[track] = interval if open_ else None - return None, None, None, None + return None + result_intervals = [] # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". 
- index, time = 0, events[0][0] - interval_start_time = time - index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) - interval_start_state = interval_start_state.copy() if interval_start_state is not None else None - while index: - new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) + index: int | None = 0 + time: int | None = events[0][0] + interval_start_time: int = time + result_intervals: list[tuple[int, int, List[GeneralizedInterval]]] = [] + + if time is None: + # No events at all + return [] + + res = process_events_for_point_in_time(index, time) + + if res is None: + return [] + + index, time, high_time = res + + interval_start_state = active_intervals.copy() + + def finalize_interval( + interval_start_time: int, + current_time: int, + interval_start_state: List[GeneralizedInterval], + ) -> None: + """ + Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', + ensuring we don't create duplicate adjacency boundaries if the previous slice ends + exactly where the new one starts. 
+ """ + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + # Adjust the previous slice so it doesn't overlap or duplicate + result_intervals[-1] = ( + previous_result[0], + previous_result[1] - 1, + previous_result[2], + ) + + # Now finalize the current slice + result_intervals.append( + (interval_start_time, current_time, interval_start_state) + ) + + # The main loop: step through event clusters + while True: + res = process_events_for_point_in_time(index, time) + if res is None: + # No more events => finalize the last slice and break + finalize_interval(interval_start_time, time, interval_start_state) + break + + new_index, new_time, high_time = res + # Diagram for this program point: # |___potential_result_interval___|| | # index new_index # interval_start_time time new_time # interval_start_state maybe_end_state # high_time - if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): - # Add info for one result interval. - if len(result_intervals) > 0: - previous_result = result_intervals[-1] - if previous_result[1] == interval_start_time: - result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) - result_intervals.append((interval_start_time, time, interval_start_state)) + + # We have a region from [interval_start_time, time) or [interval_start_time, time] + # with 'interval_start_state' as the active intervals. + # Decide if we finalize that region or if we can merge with the next region. + if not is_same_result(interval_start_state, active_intervals): + # If the active intervals changed, finalize the old slice + finalize_interval(interval_start_time, time, interval_start_state) + # Update interval start info. 
def intervals_to_events(
    intervals: list[Interval], closing_offset: int = 1
) -> list[IntervalEvent]:
    """
    Converts the intervals to a list of events.

    Each interval contributes two events: an opening event at its lower bound
    and a closing event at its upper bound plus ``closing_offset``.

    :param intervals: The intervals.
    :param closing_offset: Offset added to each interval's upper bound when
        emitting the closing event. Defaults to 1, i.e. the first point in
        time *after* the (inclusive) interval end.
    :return: The events, sorted by time; each event is a tuple of
        (time, is_opening_event, interval).
    """
    events = [(i.lower, True, i) for i in intervals] + [
        (i.upper + closing_offset, False, i) for i in intervals
    ]
    return sorted(events, key=lambda i: i[0])
+ + :param active_intervals1: + A list of intervals (or None) describing the first track’s active intervals. + :param active_intervals2: + A list of intervals (or None) describing the second track’s active intervals. + :return: + True if 'interval_constructor(0, 0, ...)' produces the same interval for + both sets, otherwise False. + """ # When we have to decide whether to extend a result interval # or start a new one, we compare the state for the existing # result interval with the new state. The states are derived # from the respective lists of active intervals by calling - # interval_constructor (with fake points in time) . - return (interval_constructor(0, 0, active_intervals1) - == interval_constructor(0, 0, active_intervals2)) + # interval_constructor (with fake points in time). + return interval_constructor(0, 0, active_intervals1) == interval_constructor( + 0, 0, active_intervals2 + ) + return is_same_result -def find_rectangles(all_intervals: list[list[AnyInterval]], - interval_constructor: IntervalConstructor, - is_same_result = None) \ - -> list[AnyInterval]: - """For multiple parallel "tracks" of intervals, identify temporal - intervals in which no change occurs on any "track". For each such - interval, call interval_constructor to determine how the interval - should be represented in the overall result. To this end, + +def find_rectangles( + all_intervals: list[list[AnyInterval]], + interval_constructor: IntervalConstructor, + is_same_result: SameResult | None = None, +) -> list[AnyInterval]: + """ + Low-level engine for interval construction. + + For multiple parallel "tracks" of intervals, identify segments of time + in which no change occurs on any "track". For each such segment, + call `interval_constructor(start, end, active_intervals)` to determine + how to represent the interval in the overall result. 
To this end, interval_constructor receives a list "active" intervals the elements of which are either None or an interval from all_intervals and returns either None or an interval. The returned @@ -289,32 +331,69 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], # before the open event, otherwise our tracking of active # intervals would get confused. track_count = len(all_intervals) - events = [ + + events: list[IntervalEventWithCount] = [ (time, event, interval, j) for j, intervals in enumerate(all_intervals) - for interval in intervals # intervals_to_events(intervals, closing_offset=0) - for (time,event) in [(interval.lower, True), (interval.upper, False)] + for interval in intervals + for (time, event) in [(interval.lower, True), (interval.upper, False)] ] event_count = len(events) + if event_count == 0: return [] - def compare_events(event1, event2): - if event1[0] < event2[0]: # event1 is earlier + + def compare_events( + event1: IntervalEventWithCount, event2: IntervalEventWithCount + ) -> int: + """ + Sorting comparator to ensure we process events in the correct order: + - earlier time first + - if same time and same track, close events before open events + (so we don't incorrectly treat a consecutive interval on the same track + as overlapping). 
+ """ + if event1[0] < event2[0]: # event1 is earlier return -1 - elif event2[0] < event1[0]: # event2 is earlier + elif event2[0] < event1[0]: # event2 is earlier return 1 - elif event1[3] == event2[3]: # at the same time and on same track, - if event1[2] is event2[2]: # same interval - return -1 if (event1[1] is True) else 1 # sort open events before open events - else: # different intervals - return -1 if (event1[1] is False) else 1 # sort close events before open events - else: # at the same time, but different tracks => any order is fine + elif event1[3] == event2[3]: # at the same time and on same track, + if event1[2] is event2[2]: # same interval + return ( + -1 if (event1[1] is True) else 1 + ) # sort open events before open events + else: # different intervals + return ( + -1 if (event1[1] is False) else 1 + ) # sort close events before open events + else: # at the same time, but different tracks => any order is fine return 1 - events.sort(key = cmp_to_key(compare_events)) - active_intervals = [None] * track_count - def process_events_for_point_in_time(index, point_time): + + # Sort events chronologically according to compare_events + events.sort(key=cmp_to_key(compare_events)) + + active_intervals: list[GeneralizedInterval] = [None] * track_count + + def process_events_for_point_in_time( + index: int, point_time: int + ) -> Tuple[int, int, int] | None: + """ + Consumes events that occur at `point_time` (or effectively that boundary), + updating 'active_intervals' for whichever track is opening or closing + intervals at that time. + + Returns: (new_index, new_time, copy_of_active_intervals, high_time) + - new_index: the index of the first event not processed (because it's after point_time) + - new_time: the time of that next event + - copy_of_active_intervals: a snapshot of 'active_intervals' after processing + - high_time: the highest time covered by these events (may be the same as point_time + or point_time + 1 if we consider inclusive boundaries). 
+ + If we run out of events entirely, returns (None, None, None, None). + """ high_time = point_time any_open = False + for i in range(index, event_count): time, open_, interval, track = events[i] # Since points in time for intervals are quantized to whole @@ -328,38 +407,103 @@ def process_events_for_point_in_time(index, point_time): high_time = time any_open |= open_ else: - return i, time, active_intervals, high_time if any_open else high_time + 1 + # As soon as we find an event that’s clearly beyond the cluster at point_time, + # we break and return + return ( + i, + time, + high_time if any_open else high_time + 1, + ) + + # Opening => set this track’s active interval to the new interval + # Closing => set it to None active_intervals[track] = interval if open_ else None - return None, None, None, None - result_intervals = [] + + # If we exit the loop fully, we used all events + return None + # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". 
- index, time = 0, events[0][0] - interval_start_time = time - index, time, interval_start_state, high_time = process_events_for_point_in_time(index, time) - interval_start_state = interval_start_state.copy() if interval_start_state is not None else None - while index: - new_index, new_time, maybe_end_state, high_time = process_events_for_point_in_time(index, time) + index: int | None = 0 + time: int | None = events[0][0] + interval_start_time: int = cast(int, time) + result_intervals: list[tuple[int, int, List[GeneralizedInterval]]] = [] + + if time is None: + # No events at all + return [] + + res = process_events_for_point_in_time(cast(int, index), cast(int, time)) + + if res is None: + return [] + + index, time, high_time = res + + interval_start_state = active_intervals.copy() + + def finalize_interval( + interval_start_time: int, + current_time: int, + interval_start_state: List[GeneralizedInterval], + ) -> None: + """ + Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', + ensuring we don't create duplicate adjacency boundaries if the previous slice ends + exactly where the new one starts. 
+ """ + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + # Adjust the previous slice so it doesn't overlap or duplicate + result_intervals[-1] = ( + previous_result[0], + previous_result[1] - 1, + previous_result[2], + ) + + # Now finalize the current slice + result_intervals.append( + (interval_start_time, current_time, interval_start_state) + ) + + # The main loop: step through event clusters + while True: + res = process_events_for_point_in_time(index, time) + if res is None: + # No more events => finalize the last slice and break + finalize_interval(interval_start_time, time, interval_start_state) + break + + new_index, new_time, high_time = res + # Diagram for this program point: # |___potential_result_interval___|| | # index new_index # interval_start_time time new_time # interval_start_state maybe_end_state # high_time - if (maybe_end_state is None) or (not is_same_result(interval_start_state, maybe_end_state)): - # Add info for one result interval. - if len(result_intervals) > 0: - previous_result = result_intervals[-1] - if previous_result[1] == interval_start_time: - result_intervals[-1] = (previous_result[0], previous_result[1] - 1, previous_result[2]) - result_intervals.append((interval_start_time, time, interval_start_state)) + + # We have a region from [interval_start_time, time) or [interval_start_time, time] + # with 'interval_start_state' as the active intervals. + # Decide if we finalize that region or if we can merge with the next region. + if not is_same_result(interval_start_state, active_intervals): + # If the active intervals changed, finalize the old slice + finalize_interval(interval_start_time, time, interval_start_state) + # Update interval start info. 
interval_start_time = high_time - interval_start_state = maybe_end_state.copy() if maybe_end_state is not None else None + interval_start_state = active_intervals.copy() + index, time = new_index, new_time + result = [] - for (start, end, intervals) in result_intervals: + + # Finally, convert the (start, end, intervals) slices into actual Interval objects + for start, end, intervals in result_intervals: interval = interval_constructor(start, end, intervals) + if interval is not None: result.append(interval) + return result diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 6459c03f..082d5813 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -13,7 +13,7 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.db.celida.tables import ResultInterval -from execution_engine.omop.sqlclient import OMOPSQLClient +from execution_engine.omop.sqlclient import OMOPSQLClient, datetime_cols_to_epoch from execution_engine.settings import get_config from execution_engine.task.process import ( AnyInterval, @@ -22,10 +22,12 @@ IntervalWithCount, get_processing_module, interval_like, + timerange_to_interval, ) from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType -from execution_engine.util.types import PersonIntervals, TimeRange +from execution_engine.util.types import PersonIntervals +from execution_engine.util.types.timerange import TimeRange process = get_processing_module() @@ -309,8 +311,10 @@ def handle_criterion( :param observation_window: The observation window. :return: A DataFrame with the result of the query. 
""" + engine = get_engine() query = criterion.create_query() + query = datetime_cols_to_epoch(query) engine.log_query(query, params=bind_params) logging.debug(f"Running query - '{criterion.description()}'") @@ -542,11 +546,7 @@ def handle_left_dependent_toggle( # range; Its type is not important. # use a tuple for windows to make sure it is immutable (and can be shared by all persons) windows = ( - Interval( - observation_window.start.timestamp(), - observation_window.end.timestamp(), - IntervalType.POSITIVE, - ), + timerange_to_interval(observation_window, type_=IntervalType.POSITIVE), ) window_intervals = {key: windows for key in left.keys()} @@ -589,21 +589,39 @@ def handle_temporal_operator( self, data: list[PersonIntervals], observation_window: TimeRange ) -> PersonIntervals: """ - Handles a TemporalCount operator. + Handles a TemporalCount operator, which checks whether a certain condition + (e.g., an event or observation) occurs a specific number of times in each defined test interval. - May be used to aggregate multiple criteria in a temporal manner, e.g. to count the number of times a certain - condition is met within a certain time frame (e.g. morning shift). + Note: Currently, only TemporalMinCount(*, threshold=1) is supported. - :param data: The input data. - :param observation_window: The observation window. - :return: A DataFrame with the merged intervals. + This method can do one of two main things: + 1) If `self.expr.interval_criterion` is set, it will intersect ("find overlapping") + a set of personal indicator windows with your data, returning intervals that + represent when the criterion is met within those windows. + 2) If `interval_criterion` is not set, it will create time slices (intervals) for + the specified `TimeIntervalType` (e.g., morning shift) within `observation_window` + and determine whether the data meets the condition inside each slice. 
+ + For instance, if your expression says 'TemporalCount(ANY_TIME, count_min=1)', then + the entire observation_window is treated as one big interval, and we only need to + check if we see at least one positive data interval in that timeframe. + + :param data: A list of dictionaries mapping person ID -> list of intervals. + Typically, the intervals from your workflow or dataset. + :param observation_window: The overall time range we’re interested in. + :return: A dictionary mapping each person to a list of resulting intervals + (e.g., each labeled with POSITIVE, NEGATIVE, NOT_APPLICABLE, etc.). """ assert isinstance(self.expr, logic.TemporalCount) if self.expr.interval_criterion is not None: + # If we have an interval_criterion, we expect exactly two data streams: + # (1) The "main" data + # (2) The "indicator" data (e.g., personal windows or ICU periods) assert ( len(data) == 2 ), f"TemporalCount with indicator criterion requires exactly two input streams, got {len(data)}" + indicator_personal_windows = data.pop( self.get_predecessor_data_index(self.expr.interval_criterion) ) @@ -618,7 +636,9 @@ def get_start_end_from_interval_type( type_: TimeIntervalType, ) -> tuple[datetime.time, datetime.time]: """ - Returns the start and end time for a given TimeIntervalType, read from the configuration. + Reads the config for the given TimeIntervalType, returning its start and end times. + + For example, 'MORNING_SHIFT' might map to 06:00 - 14:00, if configured that way. """ try: cnf = getattr(get_config().time_intervals, type_.value) @@ -627,26 +647,33 @@ def get_start_end_from_interval_type( return cnf.start, cnf.end assert isinstance(self.expr, logic.TemporalCount), "Invalid expression type" - assert self.expr.count_min == 1 - assert self.expr.count_max is None + + if self.expr.count_min != 1 or self.expr.count_max is not None: + raise NotImplementedError( + "Currently, only TemporalMinCount(*, threshold=1) is supported." 
+ ) if self.expr.interval_criterion is not None: + # Filter out only the POSITIVE intervals from the data data_positive = process.select_type(data_arg, IntervalType.POSITIVE) + # Overlap the personal windows with the data intervals to see if the condition + # is met for each relevant chunk in the personal window. result = process.find_overlapping_personal_windows( indicator_personal_windows, data_positive ) else: - + # If we have no `interval_criterion`, we must create the intervals ourselves + # from the observation_window (e.g. ANY_TIME vs MORNING_SHIFT). if self.expr.interval_type == TimeIntervalType.ANY_TIME: - indicator_windows = [ - Interval( - lower=observation_window.start.timestamp(), - upper=observation_window.end.timestamp(), - type=IntervalType.POSITIVE, - ) - ] + # Just one interval covering the entire observation window + indicator_windows = ( + timerange_to_interval( + observation_window, type_=IntervalType.POSITIVE + ), + ) else: + # If we do have a known interval type, or explicit start/end times, build them: if self.expr.interval_type is not None: start_time, end_time = get_start_end_from_interval_type( self.expr.interval_type @@ -658,6 +685,8 @@ def get_start_end_from_interval_type( else: raise ValueError("Invalid time interval settings") + # Create repeated intervals for each day from observation_window.start + # up to observation_window.end, using e.g. "06:00 - 14:00" if it's morning, etc. indicator_windows = process.create_time_intervals( start_datetime=observation_window.start, end_datetime=observation_window.end, @@ -667,42 +696,65 @@ def get_start_end_from_interval_type( timezone=get_config().timezone, ) - # Incrementally compute the interval type for each window - # interval. Maps id of window interval -> interval type + # We'll track each "window interval" by its object ID, storing the best + # result type found so far (NEGATIVE, POSITIVE, or NOT_APPLICABLE). 
window_types: dict[int, IntervalType] = dict() def update_window_type( window_interval: GeneralizedInterval, data_interval: GeneralizedInterval ) -> IntervalType: + """ + Called whenever we cross a boundary in time. + + window_interval: The current "indicator" interval (e.g., a morning slice). + data_interval: The data interval from data_arg that overlaps with this moment, + could be POSITIVE, NEGATIVE, NO_DATA, NOT_APPLICABLE (or None, + which is interpreted as NEGATIVE). - window_type = window_types.get( + This function decides how to update the 'window interval type' based on new info. + """ + current_type = window_types.get( id(window_interval), IntervalType.NOT_APPLICABLE ) + # If the data interval is NEGATIVE (or equivalently, None) or NO_DATA, we treat it as NEGATIVE if ( data_interval is None or data_interval.type is IntervalType.NO_DATA or data_interval.type is IntervalType.NEGATIVE ): - if window_type is IntervalType.NOT_APPLICABLE: - window_type = IntervalType.NEGATIVE + # Set current_type to NEGATIVE if it is not already POSITIVE (because a POSITIVE data interval + # was passed earlier) + if current_type is IntervalType.NOT_APPLICABLE: + current_type = IntervalType.NEGATIVE + elif data_interval.type is IntervalType.POSITIVE: - window_type = IntervalType.POSITIVE + # If the data is POSITIVE, set the window to POSITIVE, overriding any negative state. + current_type = IntervalType.POSITIVE - window_types[id(window_interval)] = window_type + window_types[id(window_interval)] = current_type - return window_type + return current_type - # The boundaries of the result intervals are identical to - # those of the window intervals. In addition, update the - # result interval window types based on the data - # intervals. def is_same_interval( left_intervals: List[GeneralizedInterval], right_intervals: List[GeneralizedInterval], ) -> bool: + """ + Helper used by ` process.find_rectangles()`. 
+ + The framework calls this to decide if two adjacent intervals share + the same "payload" and can be merged into a single interval. + + 'left_intervals' and 'right_intervals' are each a pair [window_interval, data_interval], + for the previous block vs. the new block. We update the 'window_types' dict with + whatever is found in the new block, and return True if we can keep merging them. + """ left_window_interval, left_data_interval = left_intervals right_window_interval, right_data_interval = right_intervals + + # If the next block's window interval is None, it means we're off-limits + # or there's no overlap. We update the left block's type and return False. if right_window_interval is None: if left_window_interval is None: return True @@ -710,21 +762,34 @@ def is_same_interval( update_window_type(left_window_interval, left_data_interval) return False else: + # We do have a right_window_interval, so update it with the right_data_interval update_window_type(right_window_interval, right_data_interval) + + # If the left window is None, can't be the same interval if left_window_interval is None: return False else: + # If left_window_interval and right_window_interval are literally + # the same object, we can treat them as the same interval if left_window_interval is right_window_interval: return True else: + # They’re different intervals => finalize left and move on update_window_type(left_window_interval, left_data_interval) return False - # Create result intervals based on the computed interval - # types. def result_interval( start: int, end: int, intervals: List[AnyInterval] ) -> AnyInterval: + """ + Called at the end of building each interval, to produce the final interval + object that will go into the result. + + 'intervals' is [window_interval, data_interval], where either can be None. + + We look up window_interval in window_types to see whether we've determined + it is POSITIVE, NEGATIVE, or NOT_APPLICABLE. 
+ """ window_interval, data_interval = intervals if ( window_interval is None @@ -744,6 +809,7 @@ def result_interval( key: [copy.copy(window) for window in indicator_windows] for key in data_arg.keys() } + result = process.find_rectangles( [person_indicator_windows, data_arg], result_interval, @@ -775,11 +841,7 @@ def insert_negative_intervals( # window_intervals are not important. # use a tuple for windows to make sure it is immutable (and can be shared by all persons) windows = ( - Interval( - observation_window.start.timestamp(), - observation_window.end.timestamp(), - IntervalType.POSITIVE, - ), + timerange_to_interval(observation_window, type_=IntervalType.POSITIVE), ) all_keys = data.keys() | base_data.keys() window_intervals = {key: windows for key in all_keys} diff --git a/execution_engine/util/types.py b/execution_engine/util/types/__init__.py similarity index 68% rename from execution_engine/util/types.py rename to execution_engine/util/types/__init__.py index 817b9407..5fffe8b2 100644 --- a/execution_engine/util/types.py +++ b/execution_engine/util/types/__init__.py @@ -1,99 +1,16 @@ -from datetime import date, datetime, timedelta from typing import Any -import pendulum -import pytz from pydantic import BaseModel, ConfigDict, field_validator, model_validator from execution_engine.task.process import AnyInterval from execution_engine.util import serializable from execution_engine.util.enum import TimeUnit -from execution_engine.util.interval import ( - DateTimeInterval, - IntervalType, - interval_datetime, -) from execution_engine.util.value import ValueNumber, ValueNumeric from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod PersonIntervals = dict[int, AnyInterval] -@serializable.register_class -class TimeRange(BaseModel): - """ - A time range. 
- """ - - start: datetime - end: datetime - name: str | None = None - - @field_validator("start", "end") - def check_timezone(cls, v: datetime) -> datetime: - """ - Check that the start, end parameters are timezone-aware. - """ - if not v.tzinfo: - raise ValueError("Datetime object must be timezone-aware") - - # workaround to fix pd.testing.assert_frame_equal errors when tzinfo is of type - # pydantic_core._pydantic_core.TzInfo, which somehow triggers an error when comparing a dataframe with a that - # tz in a datetime column vs a pytz.UTC tz column. - if v.tzinfo == pytz.UTC: - return v.astimezone(pytz.utc) - - return v - - @classmethod - def from_tuple( - cls, dt: tuple[datetime | str, datetime | str], name: str | None = None - ) -> "TimeRange": - """ - Create a time range from a tuple of datetimes. - """ - return cls(start=dt[0], end=dt[1], name=name) - - @property - def period(self) -> pendulum.Interval: - """ - Get the period of the time range. - """ - return pendulum.interval(start=self.start.date(), end=self.end.date()) - - def date_range(self) -> set[date]: - """ - Get the date range of the time range. - """ - return set(self.period.range("days")) - - @property - def duration(self) -> timedelta: - """ - Get the duration of the time range. - """ - return self.end - self.start - - def interval(self, type_: IntervalType) -> DateTimeInterval: - """ - Get the interval of the time range. - - :param type_: The type of interval to get. - :return: The interval. - """ - return interval_datetime(self.start, self.end, type_=type_) - - def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, datetime]: - """ - Get the dictionary representation of the time range. 
- """ - prefix = self.name + "_" if self.name else "" - return { - prefix + "start_datetime": self.start, - prefix + "end_datetime": self.end, - } - - @serializable.register_class class Timing(BaseModel): """ diff --git a/execution_engine/util/types/timerange.py b/execution_engine/util/types/timerange.py new file mode 100644 index 00000000..b1ed30bd --- /dev/null +++ b/execution_engine/util/types/timerange.py @@ -0,0 +1,88 @@ +from datetime import date, datetime, timedelta +from typing import Any + +import pendulum +import pytz +from pydantic import BaseModel, field_validator + +from execution_engine.util import serializable +from execution_engine.util.interval import ( + DateTimeInterval, + IntervalType, + interval_datetime, +) + + +@serializable.register_class +class TimeRange(BaseModel): + """ + A time range. + """ + + start: datetime + end: datetime + name: str | None = None + + @field_validator("start", "end") + def check_timezone(cls, v: datetime) -> datetime: + """ + Check that the start, end parameters are timezone-aware. + """ + if not v.tzinfo: + raise ValueError("Datetime object must be timezone-aware") + + # workaround to fix pd.testing.assert_frame_equal errors when tzinfo is of type + # pydantic_core._pydantic_core.TzInfo, which somehow triggers an error when comparing a dataframe with a that + # tz in a datetime column vs a pytz.UTC tz column. + if v.tzinfo == pytz.UTC: + return v.astimezone(pytz.utc) + + return v + + @classmethod + def from_tuple( + cls, dt: tuple[datetime | str, datetime | str], name: str | None = None + ) -> "TimeRange": + """ + Create a time range from a tuple of datetimes. + """ + return cls(start=dt[0], end=dt[1], name=name) + + @property + def period(self) -> pendulum.Interval: + """ + Get the period of the time range. + """ + return pendulum.interval(start=self.start.date(), end=self.end.date()) + + def date_range(self) -> set[date]: + """ + Get the date range of the time range. 
+ """ + return set(self.period.range("days")) + + @property + def duration(self) -> timedelta: + """ + Get the duration of the time range. + """ + return self.end - self.start + + def interval(self, type_: IntervalType) -> DateTimeInterval: + """ + Get the interval of the time range. + + :param type_: The type of interval to get. + :return: The interval. + """ + return interval_datetime(self.start, self.end, type_=type_) + + def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, datetime]: + """ + Get the dictionary representation of the time range. + """ + prefix = self.name + "_" if self.name else "" + return { + prefix + "start_datetime": self.start, + prefix + "end_datetime": self.end, + } diff --git a/setup.py b/setup.py index 67dac9a2..1d65c572 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,6 @@ setup( name="Interval Extension", - ext_modules=cythonize(ext_modules), + ext_modules=cythonize(ext_modules, compiler_directives={"language_level": "3"}), include_dirs=[numpy.get_include()], ) diff --git a/tests/_fixtures/omop_fixture.py b/tests/_fixtures/omop_fixture.py index c682bc55..49707d8d 100644 --- a/tests/_fixtures/omop_fixture.py +++ b/tests/_fixtures/omop_fixture.py @@ -10,7 +10,7 @@ from sqlalchemy import create_engine, event, text from sqlalchemy.orm.session import sessionmaker -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange logging.basicConfig() logger = logging.getLogger() diff --git a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py index 6c83d4b2..f57bb20f 100644 --- a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py @@ -13,7 +13,8 @@ from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.task.process import 
get_processing_module from execution_engine.util import logic -from execution_engine.util.types import Dosage, TimeRange +from execution_engine.util.types import Dosage +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( concept_artificial_respiration, @@ -48,6 +49,19 @@ def intervals_to_df(result, by=None): return df +@pytest.fixture(params=["cython", "python"], scope="session") +def process_module(request): + module = get_processing_module("rectangle", version=request.param) + assert module._impl.MODULE_IMPLEMENTATION == request.param + return module + + +class ProcessTest: + @pytest.fixture(autouse=True) + def setup_method(self, process_module): + self.process = process_module + + class TestExpr: """ Test class for testing Expr @@ -127,7 +141,7 @@ def test_expr_contains_criteria(self, mock_criteria): assert expr.args[i] == mock_criteria[i] -class TestCriterionCombinationDatabase(TestCriterion): +class TestCriterionCombinationDatabase(TestCriterion, ProcessTest): """ Test class for testing criterion combinations on the database. 
""" diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 07a39bc1..6d6878ca 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -1,6 +1,5 @@ import datetime -import pandas as pd import pendulum import pytest import sqlalchemy as sa @@ -18,7 +17,8 @@ from execution_engine.util import logic, temporal_logic_util from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType -from execution_engine.util.types import Dosage, TimeRange +from execution_engine.util.types import Dosage +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( concept_artificial_respiration, @@ -42,18 +42,20 @@ create_procedure, create_visit, ) -from tests.functions import intervals_to_df as intervals_to_df_orig from tests.mocks.criterion import MockCriterion -process = get_processing_module() +@pytest.fixture(params=["cython", "python"], scope="session") +def process_module(request): + module = get_processing_module("rectangle", version=request.param) + assert module._impl.MODULE_IMPLEMENTATION == request.param + return module -def intervals_to_df(result, by=None): - df = intervals_to_df_orig(result, by, process.normalize_interval) - for col in df.columns: - if isinstance(df[col].dtype, pd.DatetimeTZDtype): - df[col] = df[col].dt.tz_convert("Europe/Berlin") - return df + +class ProcessTest: + @pytest.fixture(autouse=True) + def setup_method(self, process_module): + self.process = process_module class TestFixedWindowTemporalIndicatorCombination: @@ -258,7 +260,7 @@ def test_expr_contains_criteria(self, mock_criteria): ) -class TestCriterionCombinationDatabase(TestCriterion): +class 
TestCriterionCombinationDatabase(TestCriterion, ProcessTest): """ Test class for testing criterion combinations on the database. """ diff --git a/tests/execution_engine/omop/criterion/test_criterion.py b/tests/execution_engine/omop/criterion/test_criterion.py index 092e2e9c..ee96c222 100644 --- a/tests/execution_engine/omop/criterion/test_criterion.py +++ b/tests/execution_engine/omop/criterion/test_criterion.py @@ -21,6 +21,7 @@ partial_day_coverage, ) from execution_engine.omop.db.omop.tables import Person +from execution_engine.omop.sqlclient import datetime_cols_to_epoch from execution_engine.task import ( # noqa: F401 -- required for the mock.patch below runner, task, @@ -29,7 +30,7 @@ from execution_engine.util import datetime_converter, logic from execution_engine.util.db import add_result_insert from execution_engine.util.interval import IntervalType -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueConcept, ValueNumber from tests._fixtures.omop_fixture import celida_recommendation from tests._testdata import concepts @@ -281,6 +282,7 @@ def insert_criterion(self, db_session, criterion, observation_window: TimeRange) self.register_criterion(criterion, db_session) query = criterion.create_query() + query = datetime_cols_to_epoch(query) result = db_session.connection().execute( query, parameters=observation_window.model_dump() | {"run_id": self.run_id} diff --git a/tests/execution_engine/omop/criterion/test_occurrence_criterion.py b/tests/execution_engine/omop/criterion/test_occurrence_criterion.py index 029a9ece..07ac99ed 100644 --- a/tests/execution_engine/omop/criterion/test_occurrence_criterion.py +++ b/tests/execution_engine/omop/criterion/test_occurrence_criterion.py @@ -4,7 +4,7 @@ import pytest from execution_engine.util.interval import IntervalType, TypedInterval -from execution_engine.util.types import TimeRange +from 
execution_engine.util.types.timerange import TimeRange from tests._fixtures.omop_fixture import disable_postgres_trigger from tests.execution_engine.omop.criterion.test_criterion import TestCriterion, date_set @@ -160,55 +160,53 @@ def test_multiple_occurrences_multiple_days( @pytest.mark.parametrize( "test_cases", [ - ( - [ - { - "time_range": [ # non-overlapping - ("2023-03-03 08:00:00Z", "2023-03-03 16:00:00Z"), - ("2023-03-04 09:00:00Z", "2023-03-06 15:00:00Z"), - ("2023-03-08 10:00:00Z", "2023-03-09 18:00:00Z"), - ], - "expected": { - "2023-03-03", - "2023-03-04", - "2023-03-05", - "2023-03-06", - "2023-03-08", - "2023-03-09", - }, + [ + { + "time_range": [ # non-overlapping + ("2023-03-03 08:00:00Z", "2023-03-03 16:00:00Z"), + ("2023-03-04 09:00:00Z", "2023-03-06 15:00:00Z"), + ("2023-03-08 10:00:00Z", "2023-03-09 18:00:00Z"), + ], + "expected": { + "2023-03-03", + "2023-03-04", + "2023-03-05", + "2023-03-06", + "2023-03-08", + "2023-03-09", }, - { - "time_range": [ # exact overlap - ("2023-03-01 08:00:00Z", "2023-03-02 16:00:00Z"), - ("2023-03-02 16:00:00Z", "2023-03-03 23:59:00Z"), - ("2023-03-03 23:59:00Z", "2023-03-04 11:00:00Z"), - ], - "expected": { - "2023-03-01", - "2023-03-02", - "2023-03-03", - "2023-03-04", - }, + }, + { + "time_range": [ # exact overlap + ("2023-03-01 08:00:00Z", "2023-03-02 16:00:00Z"), + ("2023-03-02 16:00:00Z", "2023-03-03 23:59:00Z"), + ("2023-03-03 23:59:00Z", "2023-03-04 11:00:00Z"), + ], + "expected": { + "2023-03-01", + "2023-03-02", + "2023-03-03", + "2023-03-04", }, - { - "time_range": [ # overlap by some margin - ("2023-03-01 08:00:00Z", "2023-03-03 16:00:00Z"), - ("2023-03-03 12:00:00Z", "2023-03-05 20:00:00Z"), - ("2023-03-06 10:00:00Z", "2023-03-08 18:00:00Z"), - ], - "expected": { - "2023-03-01", - "2023-03-02", - "2023-03-03", - "2023-03-04", - "2023-03-05", - "2023-03-06", - "2023-03-07", - "2023-03-08", - }, + }, + { + "time_range": [ # overlap by some margin + ("2023-03-01 08:00:00Z", "2023-03-03 16:00:00Z"), + 
("2023-03-03 12:00:00Z", "2023-03-05 20:00:00Z"), + ("2023-03-06 10:00:00Z", "2023-03-08 18:00:00Z"), + ], + "expected": { + "2023-03-01", + "2023-03-02", + "2023-03-03", + "2023-03-04", + "2023-03-05", + "2023-03-06", + "2023-03-07", + "2023-03-08", }, - ] - ) + }, + ] ], ) def test_multiple_persons( diff --git a/tests/execution_engine/omop/criterion/test_procedure_occurrence.py b/tests/execution_engine/omop/criterion/test_procedure_occurrence.py index 0caa9861..aa2a9c05 100644 --- a/tests/execution_engine/omop/criterion/test_procedure_occurrence.py +++ b/tests/execution_engine/omop/criterion/test_procedure_occurrence.py @@ -4,7 +4,8 @@ from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.util.enum import TimeUnit -from execution_engine.util.types import TimeRange, Timing +from execution_engine.util.types import Timing +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from execution_engine.util.value.time import ValueDuration from tests.execution_engine.omop.criterion.test_occurrence_criterion import Occurrence diff --git a/tests/execution_engine/omop/db/celida/test_triggers.py b/tests/execution_engine/omop/db/celida/test_triggers.py index 911f6577..ad75dea2 100644 --- a/tests/execution_engine/omop/db/celida/test_triggers.py +++ b/tests/execution_engine/omop/db/celida/test_triggers.py @@ -8,7 +8,7 @@ from execution_engine.omop.db.omop.schema import SCHEMA_NAME as OMOP_SCHEMA_NAME from execution_engine.omop.db.omop.tables import Person from execution_engine.util.interval import IntervalType as T -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from tests._fixtures.omop_fixture import celida_recommendation diff --git a/tests/execution_engine/task/process/test_rectangle.py b/tests/execution_engine/task/process/test_rectangle.py index 
cd94c7b2..961afd44 100644 --- a/tests/execution_engine/task/process/test_rectangle.py +++ b/tests/execution_engine/task/process/test_rectangle.py @@ -1,6 +1,3 @@ -import random -from datetime import time - import pandas as pd import pendulum import pytest @@ -12,7 +9,8 @@ get_processing_module, ) from execution_engine.util.interval import IntervalType as T -from execution_engine.util.types import PersonIntervals, TimeRange +from execution_engine.util.types import PersonIntervals +from execution_engine.util.types.timerange import TimeRange from tests.functions import df_from_str from tests.functions import intervals_to_df as intervals_to_df_original from tests.functions import parse_dt diff --git a/tests/execution_engine/util/test_types.py b/tests/execution_engine/util/test_types.py index 350fd66a..bcff9e02 100644 --- a/tests/execution_engine/util/test_types.py +++ b/tests/execution_engine/util/test_types.py @@ -6,7 +6,8 @@ from pydantic import ValidationError from execution_engine.util.enum import TimeUnit -from execution_engine.util.types import Dosage, TimeRange, Timing +from execution_engine.util.types import Dosage, Timing +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod from tests._fixtures.concept import concept_unit_mg diff --git a/tests/recommendation/test_recommendation_base.py b/tests/recommendation/test_recommendation_base.py index a78af845..54580628 100644 --- a/tests/recommendation/test_recommendation_base.py +++ b/tests/recommendation/test_recommendation_base.py @@ -18,7 +18,7 @@ ) from execution_engine.omop.db.omop.tables import Person from execution_engine.util.interval import IntervalType -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from tests._testdata import concepts, parameter from tests.functions import ( create_condition, diff --git 
a/tests/recommendation/test_recommendation_base_v2.py b/tests/recommendation/test_recommendation_base_v2.py index 5bc1a584..ec322d32 100644 --- a/tests/recommendation/test_recommendation_base_v2.py +++ b/tests/recommendation/test_recommendation_base_v2.py @@ -8,7 +8,7 @@ from execution_engine.omop.db.omop.tables import Person from execution_engine.util.interval import IntervalType -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from tests._testdata import concepts from tests._testdata.generator.composite import ( AndGenerator, From f54c5d261058323b49bab73c16d76bfb62bcea69 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 28 Mar 2025 18:26:02 +0100 Subject: [PATCH 41/43] fix(datetime): truncate to second precision to avoid rounding issues with PostgreSQL Remove microseconds from start and end datetimes to ensure consistent behavior across Python and PostgreSQL. Python uses floor when casting to int, while PostgreSQL may round, potentially causing off-by-one-second bugs. Added assertions to enforce the truncation. --- execution_engine/execution_engine.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 9bef1e7f..87a9ad30 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -434,6 +434,23 @@ def execute_recommendation( assert start_datetime.tzinfo, "start_datetime must be timezone-aware" assert end_datetime.tzinfo, "end_datetime must be timezone-aware" + # We cut off microseconds to ensure second-level precision. + # In Python, timestamps are usually floored when cast to int, + # but in PostgreSQL they might be rounded. To avoid subtle bugs + # where a 1-second difference appears due to different rounding + # strategies, we explicitly strip microseconds and assert that + # the timestamp is exactly an integer. 
+ start_datetime = start_datetime.replace(microsecond=0) + end_datetime = end_datetime.replace(microsecond=0) + + assert start_datetime.timestamp() == int( + start_datetime.timestamp() + ), f"start_datetime still contains sub-second precision: {start_datetime}" + + assert end_datetime.timestamp() == int( + end_datetime.timestamp() + ), f"end_datetime still contains sub-second precision: {end_datetime}" + date_format = "%Y-%m-%d %H:%M:%S %z" logging.info( From 9ae95b081c2d1de431c1bb2c016b9324cc719f92 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 28 Mar 2025 22:35:45 +0100 Subject: [PATCH 42/43] fix: disappearing 0sec intervals in LeftDependentToggle --- .../task/process/rectangle_cython.pyx | 110 ++-- .../task/process/rectangle_python.py | 80 +-- .../combination/test_logical_combination.py | 487 +++++++++++++++++- .../task/process/test_rectangle.py | 74 +++ 4 files changed, 663 insertions(+), 88 deletions(-) diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index bdeee748..caad697e 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -22,7 +22,7 @@ DEF SCHAR_MAX = 127 MODULE_IMPLEMENTATION = "cython" IntervalEvent = typing.Tuple[int, bool, AnyInterval] -IntervalEventWithCount = typing.Tuple[int, bool, AnyInterval, int] +IntervalEventOnTrack = typing.Tuple[int, bool, AnyInterval, int] def intervals_to_events( intervals: list[AnyInterval], @@ -295,10 +295,11 @@ def default_is_same_result(interval_constructor: IntervalConstructor): == interval_constructor(0, 0, active_intervals2)) return is_same_result -def find_rectangles(all_intervals: list[list[AnyInterval]], - interval_constructor: IntervalConstructor, - is_same_result: SameResult | None = None) \ - -> list[AnyInterval]: +def find_rectangles( + all_intervals: list[list[AnyInterval]], + interval_constructor: IntervalConstructor, + is_same_result: SameResult | None = 
None, +) -> list[AnyInterval]: """ Low-level engine for interval construction. @@ -342,16 +343,15 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], (time, event, interval, j) for j, intervals in enumerate(all_intervals) for interval in intervals - for (time,event) in [(interval.lower, True), (interval.upper, False)] + for (time, event) in [(interval.lower, True), (interval.upper, False)] ] - event_count = len(events) if event_count == 0: return [] def compare_events( - event1: IntervalEventWithCount, event2: IntervalEventWithCount + event1: IntervalEventOnTrack, event2: IntervalEventOnTrack ) -> int: """ Sorting comparator to ensure we process events in the correct order: @@ -359,17 +359,27 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], - if same time and same track, close events before open events (so we don't incorrectly treat a consecutive interval on the same track as overlapping). + + Index of event1 and event2: + - [0]: time of event + - [1]: opening (True) or closing (False) + - [2]: the interval to which the event belongs + - [3]: track index """ - if event1[0] < event2[0]: # event1 is earlier + if event1[0] < event2[0]: # event1 is earlier return -1 - elif event2[0] < event1[0]: # event2 is earlier + elif event2[0] < event1[0]: # event2 is earlier return 1 - elif event1[3] == event2[3]: # at the same time and on same track, - if event1[2] is event2[2]: # same interval - return -1 if (event1[1] is True) else 1 # sort open events before open events - else: # different intervals - return -1 if (event1[1] is False) else 1 # sort close events before open events - else: # at the same time, but different tracks => any order is fine + elif event1[3] == event2[3]: # at the same time and on same track, + if event1[2] == event2[2]: # same interval (we don't check for "is" because they might be different objects, but still represent the same interval) + return ( + -1 if (event1[1] is True) else 1 + ) # sort open events before open events + 
else: # different intervals + return ( + -1 if (event1[1] is False) else 1 + ) # sort close events before open events + else: # at the same time, but different tracks => any order is fine return 1 # Sort events chronologically according to compare_events @@ -377,6 +387,32 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], active_intervals: list[GeneralizedInterval] = [None] * track_count + def finalize_interval( + interval_start_time: int, + current_time: int, + interval_start_state: List[GeneralizedInterval], + ) -> None: + """ + Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', + ensuring we don't create duplicate adjacency boundaries if the previous slice ends + exactly where the new one starts. + """ + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + # Adjust the previous slice so it doesn't overlap or duplicate + result_intervals[-1] = ( + previous_result[0], + previous_result[1] - 1, + previous_result[2], + ) + + # Now finalize the current slice + result_intervals.append( + (interval_start_time, current_time, interval_start_state) + ) + + def process_events_for_point_in_time( index: int, point_time: int ) -> Tuple[int, int, int] | None: @@ -405,25 +441,33 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] # have no gap between them and can be considered a single # continuous interval [START_TIME1, END_TIME2]. 
- if (point_time == time) or (open_ and (point_time == time - 1)): + + point_interval_closing = any_open and not open_ and interval.lower == interval.upper == point_time + + if ((point_time == time) and not point_interval_closing) or (open_ and (point_time == time - 1)): if time > high_time: high_time = time any_open |= open_ else: # As soon as we find an event that’s clearly beyond the cluster at point_time, # we break and return - return i, time, high_time if any_open else high_time + 1 + return ( + i, + time, + high_time if any_open else high_time + 1, + ) # Opening => set this track’s active interval to the new interval # Closing => set it to None active_intervals[track] = interval if open_ else None + + # If we exit the loop fully, we used all events return None - result_intervals = [] # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". index: int | None = 0 - time: int | None = events[0][0] + time: int | None = events[index][0] interval_start_time: int = time result_intervals: list[tuple[int, int, List[GeneralizedInterval]]] = [] @@ -431,6 +475,7 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], # No events at all return [] + # process the event at index 0 at the first timepoint res = process_events_for_point_in_time(index, time) if res is None: @@ -440,31 +485,6 @@ def find_rectangles(all_intervals: list[list[AnyInterval]], interval_start_state = active_intervals.copy() - def finalize_interval( - interval_start_time: int, - current_time: int, - interval_start_state: List[GeneralizedInterval], - ) -> None: - """ - Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', - ensuring we don't create duplicate adjacency boundaries if the previous slice ends - exactly where the new one starts. 
- """ - if len(result_intervals) > 0: - previous_result = result_intervals[-1] - if previous_result[1] == interval_start_time: - # Adjust the previous slice so it doesn't overlap or duplicate - result_intervals[-1] = ( - previous_result[0], - previous_result[1] - 1, - previous_result[2], - ) - - # Now finalize the current slice - result_intervals.append( - (interval_start_time, current_time, interval_start_state) - ) - # The main loop: step through event clusters while True: res = process_events_for_point_in_time(index, time) diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index 50aa2e1f..8d92c83e 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -16,7 +16,7 @@ MODULE_IMPLEMENTATION = "python" IntervalEvent = Tuple[int, bool, AnyInterval] -IntervalEventWithCount = Tuple[int, bool, AnyInterval, int] +IntervalEventOnTrack = Tuple[int, bool, AnyInterval, int] def intervals_to_events( @@ -332,7 +332,7 @@ def find_rectangles( # intervals would get confused. track_count = len(all_intervals) - events: list[IntervalEventWithCount] = [ + events: list[IntervalEventOnTrack] = [ (time, event, interval, j) for j, intervals in enumerate(all_intervals) for interval in intervals @@ -344,7 +344,7 @@ def find_rectangles( return [] def compare_events( - event1: IntervalEventWithCount, event2: IntervalEventWithCount + event1: IntervalEventOnTrack, event2: IntervalEventOnTrack ) -> int: """ Sorting comparator to ensure we process events in the correct order: @@ -352,13 +352,21 @@ def compare_events( - if same time and same track, close events before open events (so we don't incorrectly treat a consecutive interval on the same track as overlapping). 
+ + Index of event1 and event2: + - [0]: time of event + - [1]: opening (True) or closing (False) + - [2]: the interval to which the event belongs + - [3]: track index """ if event1[0] < event2[0]: # event1 is earlier return -1 elif event2[0] < event1[0]: # event2 is earlier return 1 elif event1[3] == event2[3]: # at the same time and on same track, - if event1[2] is event2[2]: # same interval + if ( + event1[2] == event2[2] + ): # same interval (we don't check for "is" because they might be different objects, but still represent the same interval) return ( -1 if (event1[1] is True) else 1 ) # sort open events before open events @@ -374,6 +382,31 @@ def compare_events( active_intervals: list[GeneralizedInterval] = [None] * track_count + def finalize_interval( + interval_start_time: int, + current_time: int, + interval_start_state: List[GeneralizedInterval], + ) -> None: + """ + Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', + ensuring we don't create duplicate adjacency boundaries if the previous slice ends + exactly where the new one starts. + """ + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + # Adjust the previous slice so it doesn't overlap or duplicate + result_intervals[-1] = ( + previous_result[0], + previous_result[1] - 1, + previous_result[2], + ) + + # Now finalize the current slice + result_intervals.append( + (interval_start_time, current_time, interval_start_state) + ) + def process_events_for_point_in_time( index: int, point_time: int ) -> Tuple[int, int, int] | None: @@ -402,7 +435,16 @@ def process_events_for_point_in_time( # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] # have no gap between them and can be considered a single # continuous interval [START_TIME1, END_TIME2]. 
- if (point_time == time) or (open_ and (point_time == time - 1)): + + point_interval_closing = ( + any_open + and not open_ + and interval.lower == interval.upper == point_time + ) + + if ((point_time == time) and not point_interval_closing) or ( + open_ and (point_time == time - 1) + ): if time > high_time: high_time = time any_open |= open_ @@ -425,7 +467,7 @@ def process_events_for_point_in_time( # Step through event "clusters" with a common point in time and # emit result intervals with unchanged interval "payload". index: int | None = 0 - time: int | None = events[0][0] + time: int | None = events[index][0] # type: ignore[index] interval_start_time: int = cast(int, time) result_intervals: list[tuple[int, int, List[GeneralizedInterval]]] = [] @@ -433,6 +475,7 @@ def process_events_for_point_in_time( # No events at all return [] + # process the event at index 0 at the first timepoint res = process_events_for_point_in_time(cast(int, index), cast(int, time)) if res is None: @@ -442,31 +485,6 @@ def process_events_for_point_in_time( interval_start_state = active_intervals.copy() - def finalize_interval( - interval_start_time: int, - current_time: int, - interval_start_state: List[GeneralizedInterval], - ) -> None: - """ - Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', - ensuring we don't create duplicate adjacency boundaries if the previous slice ends - exactly where the new one starts. 
- """ - if len(result_intervals) > 0: - previous_result = result_intervals[-1] - if previous_result[1] == interval_start_time: - # Adjust the previous slice so it doesn't overlap or duplicate - result_intervals[-1] = ( - previous_result[0], - previous_result[1] - 1, - previous_result[2], - ) - - # Now finalize the current slice - result_intervals.append( - (interval_start_time, current_time, interval_start_state) - ) - # The main loop: step through event clusters while True: res = process_events_for_point_in_time(index, time) diff --git a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py index f57bb20f..0f5d2ec2 100644 --- a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py @@ -11,8 +11,9 @@ from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.criterion.noop import NoopCriterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence -from execution_engine.task.process import get_processing_module +from execution_engine.task.process import Interval, get_processing_module from execution_engine.util import logic +from execution_engine.util.interval import IntervalType from execution_engine.util.types import Dosage from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber @@ -191,6 +192,7 @@ def run_criteria_test( base_criterion, observation_window, persons, + result_mode: str = "full_day", ): c = sympy.parse_expr(combination) @@ -222,6 +224,8 @@ def run_criteria_test( cls = lambda *args: logic.AllOrNone(*args) elif c.func.name == "ConditionalFilter": cls = lambda *args: logic.ConditionalFilter(*args) + elif c.func.name == "LeftDependentToggle": + cls = lambda *args: logic.LeftDependentToggle(*args) else: raise 
ValueError(f"Unknown operator {c.func}") else: @@ -254,6 +258,11 @@ def run_criteria_test( right=symbols[str(c.args[1])], ) + elif hasattr(c.func, "name") and c.func.name == "LeftDependentToggle": + comb = logic.LeftDependentToggle( + left=symbols[str(c.args[0])], + right=symbols[str(c.args[1])], + ) else: comb = cls( *[symbols[str(symbol)] for symbol in c.atoms() if not symbol.is_number] @@ -275,18 +284,47 @@ def run_criteria_test( observation_window=observation_window, ) - df = self.fetch_full_day_result( - db_session, - pi_pair_id=self.pi_pair_id, - criterion_id=None, - category=CohortCategory.POPULATION, - ) - - for person in persons: - result = df.query(f"person_id=={person.person_id}") - expected_result = date_set(expected[person.person_id]) + match result_mode: + case "full_day": + df = self.fetch_full_day_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION, + ) + case "partial_day": + df = self.fetch_partial_day_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION, + ) + case "exact": + df = self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION, + ) + + if result_mode in ["full_day", "partial_day"]: + for person in persons: + result = df.query(f"person_id=={person.person_id}") + expected_result = date_set(expected[person.person_id]) + + assert ( + set(pd.to_datetime(result["valid_date"]).dt.date) == expected_result + ) + else: + for person in persons: + result = df.query(f"person_id=={person.person_id}") + result_intervals = [ + Interval(row.interval_start, row.interval_end, row.interval_type) + for _, row in result.iterrows() + ] + expected_result = expected[person.person_id] - assert set(pd.to_datetime(result["valid_date"]).dt.date) == expected_result + assert result_intervals == expected_result class TestCriterionCombinationResultShortObservationWindow( @@ -936,3 
+974,428 @@ def test_combination_on_database( observation_window, persons, ) + + +class TestCriterionCombinationLeftDependentToggle(TestCriterionCombinationDatabase): + """ + Test class for testing criterion combinations on the database. + """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + start="2023-03-01 04:00:00Z", end="2023-03-04 18:00:00Z", name="observation" + ) + + @pytest.fixture + def criteria(self, db_session): + c1 = ConditionOccurrence( + concept=concept_covid19, + ) + + c2 = ProcedureOccurrence(concept=concept_artificial_respiration) + + c3 = Measurement( + concept=concept_body_weight, + value=ValueNumber(value_min=70, unit=concept_unit_kg), + ) + + c1.set_id(1) + c2.set_id(2) + c3.set_id(3) + + self.register_criterion(c1, db_session) + self.register_criterion(c2, db_session) + self.register_criterion(c3, db_session) + + return [c1, c2, c3] + + @pytest.fixture + def patient_events(self, db_session, person_visit): + _, visit_occurrence = person_visit[0] + + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e2 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2023-03-02 06:00:00+01:00"), + condition_end_datetime=pendulum.parse("2023-03-03 18:00:00+01:00"), + ) + + db_session.add_all([e1, e2]) + + db_session.commit() + + @pytest.mark.parametrize( + "combination,expected", + [ + ( + "LeftDependentToggle(c1, c2)", + { + 1: [ + Interval( + pendulum.parse("2023-03-01 10:36:24+0100", tz="CET"), + pendulum.parse("2023-03-02 05:59:59+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + Interval( + pendulum.parse("2023-03-02 06:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 11:59:59+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), 
+ Interval( + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + type=IntervalType.POSITIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:01+0100", tz="CET"), + pendulum.parse("2023-03-03 18:00:00+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-03 18:00:01+0100", tz="CET"), + pendulum.parse("2023-03-04 19:00:00+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + ], + 2: [ + Interval( + lower=pendulum.parse("2023-03-02 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + 3: [ + Interval( + lower=pendulum.parse("2023-03-03 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + }, + ), + ], + ) + def test_combination_on_database( + self, + person_visit, + db_session, + base_criterion, + patient_events, + criteria, + combination, + expected, + observation_window, + ): + persons = [pv[0] for pv in person_visit] + self.run_criteria_test( + combination, + expected, + db_session, + criteria, + base_criterion, + observation_window, + persons, + result_mode="exact", + ) + + +class TestCriterionCombinationLeftDependentToggleMultipleSameTime( + TestCriterionCombinationDatabase +): + """ + Test class for testing criterion combinations on the database. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + start="2023-03-01 04:00:00Z", end="2023-03-04 18:00:00Z", name="observation" + ) + + @pytest.fixture + def criteria(self, db_session): + c1 = ConditionOccurrence( + concept=concept_covid19, + ) + + c2 = ProcedureOccurrence(concept=concept_artificial_respiration) + + c3 = Measurement( + concept=concept_body_weight, + value=ValueNumber(value_min=70, unit=concept_unit_kg), + ) + + c1.set_id(1) + c2.set_id(2) + c3.set_id(3) + + self.register_criterion(c1, db_session) + self.register_criterion(c2, db_session) + self.register_criterion(c3, db_session) + + return [c1, c2, c3] + + @pytest.fixture + def patient_events(self, db_session, person_visit): + _, visit_occurrence = person_visit[0] + + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e2 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e3 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e4 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2023-03-02 06:00:00+01:00"), + condition_end_datetime=pendulum.parse("2023-03-03 18:00:00+01:00"), + ) + + db_session.add_all([e1, e2, e3, e4]) + + db_session.commit() + + @pytest.mark.parametrize( + "combination,expected", + [ + ( + "LeftDependentToggle(c1, c2)", + { + 1: [ + Interval( + pendulum.parse("2023-03-01 10:36:24+0100", tz="CET"), + 
pendulum.parse("2023-03-02 05:59:59+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + Interval( + pendulum.parse("2023-03-02 06:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 11:59:59+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + type=IntervalType.POSITIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:01+0100", tz="CET"), + pendulum.parse("2023-03-03 18:00:00+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-03 18:00:01+0100", tz="CET"), + pendulum.parse("2023-03-04 19:00:00+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + ], + 2: [ + Interval( + lower=pendulum.parse("2023-03-02 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + 3: [ + Interval( + lower=pendulum.parse("2023-03-03 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + }, + ), + ], + ) + def test_combination_on_database( + self, + person_visit, + db_session, + base_criterion, + patient_events, + criteria, + combination, + expected, + observation_window, + ): + persons = [pv[0] for pv in person_visit] + self.run_criteria_test( + combination, + expected, + db_session, + criteria, + base_criterion, + observation_window, + persons, + result_mode="exact", + ) + + +class TestCriterionCombinationLeftDependentToggleMultipleOverlapping( + TestCriterionCombinationDatabase +): + """ + Test class for testing criterion combinations on the database. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + start="2023-03-01 04:00:00Z", end="2023-03-04 18:00:00Z", name="observation" + ) + + @pytest.fixture + def criteria(self, db_session): + c1 = ConditionOccurrence( + concept=concept_covid19, + ) + + c2 = ProcedureOccurrence(concept=concept_artificial_respiration) + + c3 = Measurement( + concept=concept_body_weight, + value=ValueNumber(value_min=70, unit=concept_unit_kg), + ) + + c1.set_id(1) + c2.set_id(2) + c3.set_id(3) + + self.register_criterion(c1, db_session) + self.register_criterion(c2, db_session) + self.register_criterion(c3, db_session) + + return [c1, c2, c3] + + @pytest.fixture + def patient_events(self, db_session, person_visit): + _, visit_occurrence = person_visit[0] + + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e2 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:01+01:00"), + ) + + e3 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:02+01:00"), + ) + + e4 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2023-03-02 06:00:00+01:00"), + condition_end_datetime=pendulum.parse("2023-03-03 18:00:00+01:00"), + ) + + db_session.add_all([e1, e2, e3, e4]) + + db_session.commit() + + @pytest.mark.parametrize( + "combination,expected", + [ + ( + "LeftDependentToggle(c1, c2)", + { + 1: [ + Interval( + pendulum.parse("2023-03-01 10:36:24+0100", tz="CET"), + 
pendulum.parse("2023-03-02 05:59:59+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + Interval( + pendulum.parse("2023-03-02 06:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 11:59:59+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 12:00:02+0100", tz="CET"), + type=IntervalType.POSITIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:03+0100", tz="CET"), + pendulum.parse("2023-03-03 18:00:00+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-03 18:00:01+0100", tz="CET"), + pendulum.parse("2023-03-04 19:00:00+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + ], + 2: [ + Interval( + lower=pendulum.parse("2023-03-02 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + 3: [ + Interval( + lower=pendulum.parse("2023-03-03 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + }, + ), + ], + ) + def test_combination_on_database( + self, + person_visit, + db_session, + base_criterion, + patient_events, + criteria, + combination, + expected, + observation_window, + ): + persons = [pv[0] for pv in person_visit] + self.run_criteria_test( + combination, + expected, + db_session, + criteria, + base_criterion, + observation_window, + persons, + result_mode="exact", + ) diff --git a/tests/execution_engine/task/process/test_rectangle.py b/tests/execution_engine/task/process/test_rectangle.py index 961afd44..5203cf21 100644 --- a/tests/execution_engine/task/process/test_rectangle.py +++ b/tests/execution_engine/task/process/test_rectangle.py @@ -7,7 +7,9 @@ Interval, IntervalWithCount, get_processing_module, + interval_like, ) +from execution_engine.util.interval import IntervalType from execution_engine.util.interval import IntervalType as T from 
execution_engine.util.types import PersonIntervals from execution_engine.util.types.timerange import TimeRange @@ -1770,3 +1772,75 @@ def test_union_intervals_no_data_negative(self): result = self.intervals_to_df(result, ["person_id"]) pd.testing.assert_frame_equal(result, expected_df) + + +class TestFindRectangles(ProcessTest): + + @pytest.mark.parametrize( + "right_intervals, expected", + ( + ( + [Interval(4, 4, IntervalType.POSITIVE)], + Interval(4, 4, IntervalType.POSITIVE), + ), + ( + [Interval(4, 5, IntervalType.POSITIVE)], + Interval(4, 5, IntervalType.POSITIVE), + ), + ( + [Interval(4, 6, IntervalType.POSITIVE)], + Interval(4, 6, IntervalType.POSITIVE), + ), + ( + [ + Interval(4, 4, IntervalType.POSITIVE), + Interval(4, 4, IntervalType.POSITIVE), + ], + Interval(4, 4, IntervalType.POSITIVE), + ), + # the following test currently (2025-03-28) fails, but this input is artificial, as only criteria could yield + # multiple overlapping intervals "on the same track", but these are always union-ed before propagation anyway, + # so find_rectangles with the new_interval function below should never receive such intervals. 
+ # ([Interval(4, 5, IntervalType.POSITIVE), Interval(4, 6, IntervalType.POSITIVE)], Interval(4, 6, IntervalType.POSITIVE)), + ( + [ + Interval(4, 5, IntervalType.POSITIVE), + Interval(5, 6, IntervalType.POSITIVE), + ], + Interval(4, 6, IntervalType.POSITIVE), + ), + ( + [ + Interval(4, 5, IntervalType.POSITIVE), + Interval(6, 6, IntervalType.POSITIVE), + ], + Interval(4, 6, IntervalType.POSITIVE), + ), + ), + ) + def test_counting(self, right_intervals, expected): + def new_interval(start: int, end: int, intervals): + left_interval, right_interval, observation_window_ = intervals + if (left_interval is None) or left_interval.type != IntervalType.POSITIVE: + # no left_interval or not positive -> use fill type + return Interval(start, end, IntervalType.NOT_APPLICABLE) + elif right_interval is not None: + return interval_like(right_interval, start, end) + else: # left_interval but not right_interval -> implicit negative + return None + + left = {1: [Interval(2, 8, IntervalType.POSITIVE)]} + right = {1: right_intervals} + window_intervals = {1: [Interval(0, 10, IntervalType.POSITIVE)]} + + result = self.process.find_rectangles( + [left, right, window_intervals], new_interval + ) + + assert result == { + 1: [ + Interval(lower=0, upper=1, type=IntervalType.NOT_APPLICABLE), + expected, + Interval(lower=9, upper=10, type=IntervalType.NOT_APPLICABLE), + ] + } From 65e815c7e11a4980b3f1eaf229b6849cf7f896c2 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Fri, 28 Mar 2025 23:01:15 +0100 Subject: [PATCH 43/43] feat: add custom count handling for logic.And --- execution_engine/task/task.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 082d5813..78d541c4 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -73,6 +73,15 @@ def default_interval_union_with_count( return IntervalWithCount(start, end, result_type, result_count) +def 
default_interval_intersect_with_count( + start: int, end: int, intervals: List[GeneralizedInterval] +) -> IntervalWithCount: + """ + Default interval counting function to be used in logic.Or + """ + raise NotImplementedError() + + def get_engine() -> OMOPSQLClient: """ Returns a OMOPSQLClient object. @@ -382,9 +391,26 @@ def handle_binary_logical_operator( return data[0] if isinstance(self.expr, (logic.And, logic.NonSimplifiableAnd)): - result = process.intersect_intervals(data) + + if self.receives_only_count_inputs() and hasattr( + self.expr, "count_intervals" + ): + # we check if there are custom data preparatory and interval counting functions and use these + prepare_func = getattr(self.expr, "prepare_data", None) + if prepare_func: + data = prepare_func(self, data) + func = getattr( + self.expr, "count_intervals", default_interval_intersect_with_count + ) + result = process.find_rectangles(data, func) + + else: + result = process.intersect_intervals(data) + elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): - if self.receives_only_count_inputs(): + if self.receives_only_count_inputs() and hasattr( + self.expr, "count_intervals" + ): # we check if there are custom data preparatory and interval counting functions and use these prepare_func = getattr(self.expr, "prepare_data", None) if prepare_func: